Skip to content

Commit 2de4f11

Browse files
committed
refactor: allow retrieving DB instance from context
1 parent 08b2fb1 commit 2de4f11

File tree

3 files changed

+38
-10
lines changed

3 files changed

+38
-10
lines changed

db/postgres/postgres.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ func DeleteForced(ctx context.Context, object interface{}) {
119119

120120
func DBCtx(ctx context.Context) *gorm.DB {
121121
if ctx != nil {
122+
dbCtx := DBFromContext(ctx)
123+
if dbCtx != nil {
124+
return dbCtx
125+
}
126+
122127
return db.WithContext(ctx)
123128
}
124129

db/postgres/utils.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package postgres
2+
3+
import (
4+
"context"
5+
"gorm.io/gorm"
6+
)
7+
8+
type (
9+
ContextDB struct{}
10+
)
11+
12+
func ContextWithDB(ctx context.Context, db *gorm.DB) context.Context {
13+
return context.WithValue(ctx, ContextDB{}, db)
14+
}
15+
16+
func DBFromContext(ctx context.Context) *gorm.DB {
17+
value := ctx.Value(ContextDB{})
18+
19+
if value == nil {
20+
return nil
21+
}
22+
23+
return value.(*gorm.DB)
24+
}

db/statistics.go

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,16 @@ func RunAsyncStatisticLoop(ctx context.Context) {
3838
}
3939
}
4040

41-
updateTx := postgres.DBCtx(ctx).Begin()
42-
4341
for entityType, entityValue := range resultMap {
4442
for action, actionValue := range entityValue {
4543
for entityID, count := range actionValue {
44+
updateTx := postgres.DBCtx(ctx).Begin()
45+
ctxWithTx := postgres.ContextWithDB(ctx, updateTx)
4646
switch entityType {
4747
case "mod":
4848
switch action {
4949
case "view":
50-
mod := postgres.GetModByID(ctx, entityID)
50+
mod := postgres.GetModByID(ctxWithTx, entityID)
5151
if mod != nil {
5252
currentHotness := mod.Hotness
5353
if currentHotness > 4 {
@@ -60,7 +60,7 @@ func RunAsyncStatisticLoop(ctx context.Context) {
6060
case "version":
6161
switch action {
6262
case "download":
63-
version := postgres.GetVersion(ctx, entityID)
63+
version := postgres.GetVersion(ctxWithTx, entityID)
6464
if version != nil {
6565
currentHotness := version.Hotness
6666
if currentHotness > 4 {
@@ -71,13 +71,11 @@ func RunAsyncStatisticLoop(ctx context.Context) {
7171
}
7272
}
7373
}
74+
updateTx.Commit()
7475
}
7576
}
7677
}
7778

78-
updateTx.Commit()
79-
updateTx = postgres.DBCtx(ctx).Begin()
80-
8179
type Result struct {
8280
ModID string
8381
Hotness uint
@@ -89,7 +87,9 @@ func RunAsyncStatisticLoop(ctx context.Context) {
8987
postgres.DBCtx(ctx).Raw("SELECT mod_id, SUM(hotness) AS hotness, SUM(downloads) AS downloads FROM versions GROUP BY mod_id").Scan(&resultRows)
9088

9189
for _, row := range resultRows {
92-
mod := postgres.GetModByID(ctx, row.ModID)
90+
updateTx := postgres.DBCtx(ctx).Begin()
91+
ctxWithTx := postgres.ContextWithDB(ctx, updateTx)
92+
mod := postgres.GetModByID(ctxWithTx, row.ModID)
9393
if mod != nil {
9494
currentPopularity := mod.Popularity
9595
if currentPopularity > 4 {
@@ -101,10 +101,9 @@ func RunAsyncStatisticLoop(ctx context.Context) {
101101
Downloads: row.Downloads,
102102
})
103103
}
104+
updateTx.Commit()
104105
}
105106

106-
updateTx.Commit()
107-
108107
log.Ctx(ctx).Info().Msgf("Statistics Updated! Took %s", time.Since(start).String())
109108
time.Sleep(time.Minute)
110109
}

0 commit comments

Comments
 (0)