Skip to content

Commit da711cf

Browse files
committed
refactor core.grpc
1 parent 58ba0e0 commit da711cf

File tree

1 file changed

+64
-58
lines changed

1 file changed

+64
-58
lines changed

lualib/core/grpc.lua

Lines changed: 64 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ local code = require "core.grpc.code"
33
local codename = require "core.grpc.codename"
44
local transport = require "core.http.transport"
55
local pb = require "pb"
6+
67
local assert = assert
78
local pack = string.pack
89
local unpack = string.unpack
@@ -11,22 +12,23 @@ local sub = string.sub
1112
local concat = table.concat
1213
local tonumber = tonumber
1314
local setmetatable = setmetatable
15+
1416
local M = {}
1517

1618
local HDR_SIZE<const> = 5
1719
local BODY_START<const> = HDR_SIZE+1
1820
local MAX_LEN<const> = 4*1024*1024
1921

20-
---@param stream core.http.h2stream
21-
---@param read fun(stream:core.http.h2stream, timeout:number):string?, string?
22+
---@param h2stream core.http.h2stream
2223
---@param is_server boolean
23-
---@param timeout number
24+
---@param timeout number?
2425
---@return string?, string? error
25-
local function read_body(stream, read, is_server, timeout)
26+
local function read_body(h2stream, is_server, timeout)
2627
local data = ""
2728
--read header
29+
local read = h2stream.read
2830
for i = 1, HDR_SIZE do
29-
local d, err = read(stream, timeout)
31+
local d, err = read(h2stream, timeout)
3032
if not d or d == "" then
3133
return nil, err
3234
end
@@ -38,7 +40,7 @@ local function read_body(stream, read, is_server, timeout)
3840
local compress, frame_size = unpack(">I1I4", data)
3941
assert(compress == 0, "grpc: compression not supported")
4042
if is_server and frame_size > MAX_LEN then
41-
stream:respond(200, {
43+
h2stream:respond(200, {
4244
['content-type'] = 'application/grpc',
4345
['grpc-status'] = code.ResourceExhausted,
4446
}, true)
@@ -49,7 +51,7 @@ local function read_body(stream, read, is_server, timeout)
4951
local buf = {data}
5052
frame_size = frame_size - #data
5153
while frame_size > 0 do
52-
local d, err = read(stream, timeout)
54+
local d, err = read(h2stream, timeout)
5355
if not d or d == "" then
5456
return nil, err
5557
end
@@ -66,11 +68,11 @@ local function dispatch(registrar)
6668
local output_name = registrar.output_name
6769
local handlers = registrar.handlers
6870
--use closure for less hash
69-
---@param stream core.http.h2stream
70-
return function(stream)
71-
local status, header = stream:readheader()
71+
---@param h2stream core.http.h2stream
72+
return function(h2stream)
73+
local status, header = h2stream:readheader()
7274
if status ~= 200 then
73-
stream:respond(200, {
75+
h2stream:respond(200, {
7476
['content-type'] = 'application/grpc',
7577
['grpc-status'] = code.Unknown,
7678
['grpc-message'] = "grpc: invalid header"
@@ -80,9 +82,9 @@ local function dispatch(registrar)
8082
local method = header[':path']
8183
local itype = input_name[method]
8284
local otype = output_name[method]
83-
local data, err = read_body(stream, stream.read, true, nil)
85+
local data, err = read_body(h2stream, true, nil)
8486
if not data then
85-
stream:close()
87+
h2stream:close()
8688
logger.warn("[core.grpc] read body failed", err)
8789
return
8890
end
@@ -91,10 +93,10 @@ local function dispatch(registrar)
9193
local outdata = pb.encode(otype, output)
9294
--payloadFormat, length, data
9395
outdata = pack(">I1I4", 0, #outdata) .. outdata
94-
stream:respond(200, {
96+
h2stream:respond(200, {
9597
['content-type'] = 'application/grpc',
9698
})
97-
stream:close(outdata, {
99+
h2stream:close(outdata, {
98100
['grpc-status'] = code.OK,
99101
})
100102
end
@@ -156,55 +158,59 @@ end
156158

157159
local alpn_protos = {"h2"}
158160

159-
---@param stream core.http.h2stream
160-
local function streaming_write_wrapper(stream, method, timeout)
161-
local itype = method.input_type
162-
local write = stream.write
163-
---@param stream core.http.h2stream
164-
---@param req table
165-
return function(stream, req)
166-
local reqdat = pb.encode(itype, req)
167-
reqdat = pack(">I1I4", 0, #reqdat) .. reqdat
168-
return write(stream, reqdat)
169-
end
161+
---@class core.grpc.streaming
162+
---@field h2stream core.http.h2stream
163+
---@field need_header boolean
164+
---@field input_type string
165+
---@field output_type string
166+
local grpc_streaming = {}
167+
local grpc_streaming_mt = { __index = grpc_streaming }
168+
169+
---@param self core.grpc.streaming
170+
function grpc_streaming:write(req)
171+
local h2stream = self.h2stream
172+
local reqdat = pb.encode(self.input_type, req)
173+
reqdat = pack(">I1I4", 0, #reqdat) .. reqdat
174+
return h2stream:write(reqdat)
170175
end
171176

172-
---@param stream core.http.h2stream
173-
local function streaming_read_wrapper(stream, method, timeout)
174-
local need_header = true
175-
local read = stream.read
176-
local otype = method.output_type
177-
return function(steam)
178-
if need_header then
179-
local status, header = stream:readheader(timeout)
180-
if not status then
181-
return nil, header
182-
end
183-
need_header = false
184-
end
185-
local data, err = read_body(stream, read, false, timeout)
186-
if not data then
187-
return nil, err
188-
end
189-
local resp = pb.decode(otype, data)
190-
if not resp then
191-
return nil, "decode error"
177+
---@param self core.grpc.streaming
178+
---@param timeout number?
179+
function grpc_streaming:read(timeout)
180+
local h2stream = self.h2stream
181+
if self.need_header then
182+
local status, header = h2stream:readheader(timeout)
183+
if not status then
184+
return nil, header
192185
end
193-
return resp, nil
186+
self.need_header = false
194187
end
188+
local data, err = read_body(h2stream, false, timeout)
189+
if not data then
190+
return nil, err
191+
end
192+
local resp = pb.decode(self.output_type, data)
193+
if not resp then
194+
return nil, "decode error"
195+
end
196+
return resp, nil
195197
end
196198

197199
---@return core.grpc.stream|nil, string|nil
198200
local function stream_call(timeout, connect, method, fullname)
199201
return function()
200202
---@class core.grpc.stream:core.http.h2stream
201-
local stream, err = connect(fullname)
202-
if not stream then
203+
local h2stream, err = connect(fullname)
204+
if not h2stream then
203205
return nil, err
204206
end
205-
stream.write = streaming_write_wrapper(stream, method, timeout)
206-
stream.read = streaming_read_wrapper(stream, method, timeout)
207-
return stream, nil
207+
local streaming = setmetatable({
208+
h2stream = h2stream,
209+
input_type = method.input_type,
210+
output_type = method.output_type,
211+
need_header = true,
212+
}, grpc_streaming_mt)
213+
return streaming, nil
208214
end
209215
end
210216

@@ -213,17 +219,17 @@ local function general_call(timeout, connect, method, fullname)
213219
local itype = method.input_type
214220
local otype = method.output_type
215221
return function(req)
216-
local stream<close>, err = connect(fullname)
217-
if not stream then
222+
local h2stream<close>, err = connect(fullname)
223+
if not h2stream then
218224
return nil, err
219225
end
220226
local reqdat = pb.encode(itype, req)
221227
reqdat = pack(">I1I4", 0, #reqdat) .. reqdat
222-
local ok, err = stream:write(reqdat)
228+
local ok, err = h2stream:write(reqdat)
223229
if not ok then
224230
return nil, err
225231
end
226-
local status, header = stream:readheader(timeout)
232+
local status, header = h2stream:readheader(timeout)
227233
if not status then
228234
return nil, header
229235
end
@@ -232,11 +238,11 @@ local function general_call(timeout, connect, method, fullname)
232238
local grpc_status = header['grpc-status']
233239
if not grpc_status then --normal header
234240
local reason
235-
body, reason = read_body(stream, stream.read, false, timeout)
241+
body, reason = read_body(h2stream, false, timeout)
236242
if not body then
237243
return nil, reason
238244
end
239-
local trailer, reason = stream:readtrailer(timeout)
245+
local trailer, reason = h2stream:readtrailer(timeout)
240246
if not trailer then
241247
return nil, reason
242248
end

0 commit comments

Comments
 (0)