Skip to content

Commit e16d176

Browse files
Mak MufticMatija Petrunic
authored andcommitted
Move signature verification to middleware
1 parent b48838b commit e16d176

File tree

15 files changed

+115
-72
lines changed

15 files changed

+115
-72
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
package controllers
1+
package constants
22

33
const StatsSignedData = "loadbalancer-request"

internal/controllers/api.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,16 @@ type ApiController struct {
99
whitelistEnabled bool
1010
repositories repositories.Repos
1111
actions actions.Actions
12-
privateKey string
1312
}
1413

1514
func NewApiController(
1615
whitelistEnabled bool,
1716
repositories repositories.Repos,
1817
actions actions.Actions,
19-
privateKey string,
2018
) *ApiController {
2119
return &ApiController{
2220
whitelistEnabled: whitelistEnabled,
2321
repositories: repositories,
2422
actions: actions,
25-
privateKey: privateKey,
2623
}
2724
}

internal/controllers/metrics_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ func TestApiController_SaveMetricsHandler(t *testing.T) {
245245
MetricsRepo: &metricsRepoMock,
246246
RecordRepo: &recordRepoMock,
247247
DowntimeRepo: &downtimeRepoMock,
248-
}, nil, "")
248+
}, nil)
249249

250250
handler := http.HandlerFunc(apiController.SaveMetricsHandler)
251251

internal/controllers/ping_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func TestApiController_PingHandler(t *testing.T) {
106106
MetricsRepo: &metricsRepoMock,
107107
RecordRepo: &recordRepoMock,
108108
DowntimeRepo: &downtimeRepoMock,
109-
}, nil, "")
109+
}, nil)
110110
handler := http.HandlerFunc(apiController.PingHandler)
111111

112112
// create test request and populate context

internal/controllers/register_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ func TestApiController_RegisterHandler(t *testing.T) {
146146
MetricsRepo: &metricsRepoMock,
147147
RecordRepo: &recordRepoMock,
148148
DowntimeRepo: &downtimeRepoMock,
149-
}, nil, "")
149+
}, nil)
150150

151151
handler := http.HandlerFunc(apiController.RegisterHandler)
152152

internal/controllers/rpc_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func TestApiController_RPCHandler(t *testing.T) {
5656
apiController := NewApiController(false, repositories.Repos{
5757
NodeRepo: &nodeRepoMock,
5858
RecordRepo: &recordRepoMock,
59-
}, actionsMockObject, "")
59+
}, actionsMockObject)
6060

6161
handler := http.HandlerFunc(apiController.RPCHandler)
6262

@@ -165,7 +165,7 @@ func TestApiController_BatchRPCHandler(t *testing.T) {
165165
apiController := NewApiController(false, repositories.Repos{
166166
NodeRepo: &nodeRepoMock,
167167
RecordRepo: &recordRepoMock,
168-
}, actionsMockObject, "")
168+
}, actionsMockObject)
169169

170170
handler := http.HandlerFunc(apiController.RPCHandler)
171171

@@ -244,7 +244,7 @@ func TestApiController_RPCHandler_InvalidBody(t *testing.T) {
244244
actionsMockObject := new(actionMocks.Actions)
245245
actionsMockObject.On("PenalizeNode", mock.Anything, mock.Anything).Return()
246246

247-
apiController := NewApiController(false, repositories.Repos{}, actionsMockObject, "")
247+
apiController := NewApiController(false, repositories.Repos{}, actionsMockObject)
248248

249249
handler := http.HandlerFunc(apiController.RPCHandler)
250250

internal/controllers/stats.go

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ import (
1313
"github.com/NodeFactoryIo/vedran/internal/models"
1414
"github.com/NodeFactoryIo/vedran/internal/stats"
1515

16-
"github.com/NodeFactoryIo/go-substrate-rpc-client/signature"
17-
"github.com/ethereum/go-ethereum/common/hexutil"
1816
muxhelpper "github.com/gorilla/mux"
1917
log "github.com/sirupsen/logrus"
2018
)
@@ -50,19 +48,8 @@ type LoadbalancerStatsRequest struct {
5048
TotalReward string `json:"total_reward"`
5149
}
5250

53-
// handler for `POST /api/v1/stats`
51+
// handler for `POST /api/v1/stats` - signature verification in middleware
5452
func (c *ApiController) StatisticsHandlerAllStatsForLoadbalancer(w http.ResponseWriter, r *http.Request) {
55-
verified, httpStatusCode, err := verifySignatureInHeader(r, c.privateKey)
56-
if err != nil {
57-
http.Error(w, http.StatusText(httpStatusCode), httpStatusCode)
58-
return
59-
}
60-
if !verified {
61-
log.Errorf("Invalid request signature")
62-
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
63-
return
64-
}
65-
6653
totalRewardAsFloat, err := getTotalRewardFromRequest(r)
6754
if err != nil {
6855
log.Error(err)
@@ -110,25 +97,6 @@ func (c *ApiController) StatisticsHandlerAllStatsForLoadbalancer(w http.Response
11097
})
11198
}
11299

113-
func verifySignatureInHeader(r *http.Request, privateKey string) (bool, int, error) {
114-
sig := r.Header.Get("X-Signature")
115-
if sig == "" {
116-
log.Error("Missing signature header")
117-
return false, http.StatusBadRequest, nil
118-
}
119-
sigInBytes, err := hexutil.Decode(sig)
120-
if err != nil {
121-
log.Errorf("Unable to decode signature, because of: %v", err)
122-
return false, http.StatusBadRequest, err
123-
}
124-
verified, err := signature.Verify([]byte(StatsSignedData), sigInBytes, privateKey)
125-
if err != nil {
126-
log.Errorf("Failed to verify signature, because %v", err)
127-
return false, http.StatusInternalServerError, err
128-
}
129-
return verified, 0, err
130-
}
131-
132100
func getTotalRewardFromRequest(r *http.Request) (float64, error) {
133101
var statsRequest LoadbalancerStatsRequest
134102
reqBody, err := ioutil.ReadAll(r.Body)

internal/controllers/stats_test.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"fmt"
99
"github.com/NodeFactoryIo/go-substrate-rpc-client/signature"
1010
"github.com/NodeFactoryIo/vedran/internal/configuration"
11+
"github.com/NodeFactoryIo/vedran/internal/constants"
12+
"github.com/NodeFactoryIo/vedran/internal/middleware"
1113
"github.com/NodeFactoryIo/vedran/internal/models"
1214
"github.com/NodeFactoryIo/vedran/internal/repositories"
1315
mocks "github.com/NodeFactoryIo/vedran/mocks/repositories"
@@ -135,7 +137,7 @@ func TestApiController_StatisticsHandlerAllStats(t *testing.T) {
135137
RecordRepo: &recordRepoMock,
136138
DowntimeRepo: &downtimeRepoMock,
137139
PayoutRepo: &payoutRepoMock,
138-
}, nil, "")
140+
}, nil)
139141
handler := http.HandlerFunc(apiController.StatisticsHandlerAllStats)
140142
req, _ := http.NewRequest("GET", "/api/v1/stats", bytes.NewReader(nil))
141143
rr := httptest.NewRecorder()
@@ -225,7 +227,7 @@ func TestApiController_StatisticsHandlerAllStatsForLoadbalancer(t *testing.T) {
225227
requestContent: `{"total_reward":"1000000"}`,
226228
//
227229
secret: "0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
228-
signatureData: StatsSignedData,
230+
signatureData: constants.StatsSignedData,
229231
},
230232
{
231233
name: "missing signature, 400 bad request",
@@ -260,7 +262,7 @@ func TestApiController_StatisticsHandlerAllStatsForLoadbalancer(t *testing.T) {
260262
nodeNumberOfPings: float64(8640),
261263
//
262264
secret: "",
263-
signatureData: StatsSignedData,
265+
signatureData: constants.StatsSignedData,
264266
},
265267
{
266268
name: "invalid signature, 400 bad request",
@@ -303,7 +305,7 @@ func TestApiController_StatisticsHandlerAllStatsForLoadbalancer(t *testing.T) {
303305
payoutRepoFindLatestPayoutError: errors.New("db-error"),
304306
secret: "0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
305307
requestContent: `{"total_reward":"1000000"}`,
306-
signatureData: StatsSignedData,
308+
signatureData: constants.StatsSignedData,
307309
},
308310
}
309311
configuration.Config.Fee = 0.1
@@ -357,8 +359,12 @@ func TestApiController_StatisticsHandlerAllStatsForLoadbalancer(t *testing.T) {
357359
DowntimeRepo: &downtimeRepoMock,
358360
PayoutRepo: &payoutRepoMock,
359361
FeeRepo: &feeRepoMock,
360-
}, nil, test.secret)
361-
handler := http.HandlerFunc(apiController.StatisticsHandlerAllStatsForLoadbalancer)
362+
}, nil)
363+
364+
handler := middleware.VerifySignatureMiddleware(
365+
http.HandlerFunc(apiController.StatisticsHandlerAllStatsForLoadbalancer),
366+
test.secret,
367+
)
362368

363369
req, _ := http.NewRequest("POST", "/api/v1/stats", bytes.NewReader([]byte(test.requestContent)))
364370

@@ -486,7 +492,7 @@ func TestApiController_StatisticsHandlerStatsForNode(t *testing.T) {
486492
RecordRepo: &recordRepoMock,
487493
DowntimeRepo: &downtimeRepoMock,
488494
PayoutRepo: &payoutRepoMock,
489-
}, nil, "")
495+
}, nil)
490496
type ContextKey string
491497
req, _ := http.NewRequest("GET", "/api/v1/stats/node/1", bytes.NewReader(nil))
492498
req = req.WithContext(context.WithValue(req.Context(), ContextKey(test.contextKey), "1"))

internal/controllers/ws_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ func TestApiController_WSHandler(t *testing.T) {
8686
apiController := NewApiController(false, repositories.Repos{
8787
NodeRepo: &nodeRepoMock,
8888
RecordRepo: &recordRepoMock,
89-
}, actionsMockObject, "")
89+
}, actionsMockObject)
9090

9191
// start test loadbalancer ws server
9292
router := mm.NewRouter()

internal/loadbalancer/server.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ func StartLoadBalancerServer(
9393
// start server
9494
log.Infof("Starting vedran load balancer on port :%d...", props.Port)
9595
apiController := controllers.NewApiController(
96-
props.WhitelistEnabled, *repos, actions.NewActions(), privateKey,
96+
props.WhitelistEnabled, *repos, actions.NewActions(),
9797
)
98-
r := router.CreateNewApiRouter(apiController)
98+
r := router.CreateNewApiRouter(apiController, privateKey)
9999
prometheus.RecordMetrics(*repos)
100100
if props.CertFile != "" {
101101
err = http.ListenAndServeTLS(

0 commit comments

Comments
 (0)