Skip to content

Commit

Permalink
case模块增加同步到知识库的操作
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuwenliang committed Dec 9, 2024
1 parent b237003 commit f2b7f6f
Show file tree
Hide file tree
Showing 18 changed files with 355 additions and 45 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ require (
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/cel-go v0.11.3 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gotomicro/logrotate v0.0.0-20211108034117-46d53eedc960 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,8 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
Expand Down
130 changes: 130 additions & 0 deletions internal/cases/internal/integration/knowledge_base_handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package integration

import (
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"testing"
"time"

"github.com/ecodeclub/ekit/iox"
"github.com/ecodeclub/ginx/session"
"github.com/ecodeclub/webook/internal/ai"
aimocks "github.com/ecodeclub/webook/internal/ai/mocks"
"github.com/ecodeclub/webook/internal/cases/internal/domain"
eveMocks "github.com/ecodeclub/webook/internal/cases/internal/event/mocks"
"github.com/ecodeclub/webook/internal/cases/internal/integration/startup"
"github.com/ecodeclub/webook/internal/cases/internal/repository/dao"
"github.com/ecodeclub/webook/internal/cases/internal/service"
"github.com/ecodeclub/webook/internal/interactive"
intrmocks "github.com/ecodeclub/webook/internal/interactive/mocks"
"github.com/ecodeclub/webook/internal/test"
testioc "github.com/ecodeclub/webook/internal/test/ioc"
"github.com/ego-component/egorm"
"github.com/gin-gonic/gin"
"github.com/gotomicro/ego/core/econf"
"github.com/gotomicro/ego/server/egin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
)

// KnowledgeBaseTestSuite is the integration test suite for syncing cases
// to the AI knowledge base.
type KnowledgeBaseTestSuite struct {
suite.Suite
// db is the database handle assigned in SetupSuite.
db *egorm.Component
// caseSvc is the case service taken from the module built inside the test.
caseSvc service.Service
}

// SetupSuite stores the database handle returned by testioc.InitDB on the suite.
func (k *KnowledgeBaseTestSuite) SetupSuite() {
k.db = testioc.InitDB()
}

// getWantCase builds the domain.Case fixture for the given id. The id is used
// for Id and BizId and is embedded in the text fields; Ctime and Utime are
// left at their zero values.
func (k *KnowledgeBaseTestSuite) getWantCase(id int64) domain.Case {
	// numbered renders a format string with the case id as its only argument.
	numbered := func(format string) string {
		return fmt.Sprintf(format, id)
	}
	return domain.Case{
		Id:           id,
		Uid:          123,
		Labels:       []string{"label"},
		Introduction: numbered("intro %d"),
		Title:        numbered("标题%d"),
		Content:      numbered("内容%d"),
		GiteeRepo:    "gitee",
		GithubRepo:   "github",
		Shorthand:    "速记",
		Keywords:     numbered("关键字 %d"),
		Highlight:    numbered("亮点 %d"),
		Guidance:     numbered("引导点 %d"),
		Biz:          domain.BizCase,
		BizId:        id,
		Status:       domain.PublishedStatus,
	}
}

// TestKnowledgeBaseSync publishes a single case, calls the
// /case/knowledgeBase/syncAll route and verifies, inside the mocked
// UploadFile, that the uploaded file carries the expected name, biz and
// JSON-encoded case.
func (k *KnowledgeBaseTestSuite) TestKnowledgeBaseSync() {
	ctrl := gomock.NewController(k.T())
	svc := aimocks.NewMockRepositoryBaseSvc(ctrl)
	producer := eveMocks.NewMockSyncEventProducer(ctrl)
	producer.EXPECT().Produce(gomock.Any(), gomock.Any()).AnyTimes()
	svc.EXPECT().UploadFile(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, file ai.KnowledgeBaseFile) error {
		assert.Equal(k.T(), fmt.Sprintf("case_%d", file.BizID), file.Name)
		assert.Equal(k.T(), domain.BizCase, file.Biz)
		wantCase := k.getWantCase(file.BizID)
		var actualCa domain.Case
		if err := json.Unmarshal(file.Data, &actualCa); err != nil {
			return err
		}
		// Compare against the decoded payload; wantCase.Id is derived from
		// file.BizID, so checking it against file.BizID would always pass.
		assert.Equal(k.T(), file.BizID, actualCa.Id)
		// The fixture does not set Ctime/Utime, so copy them over before the
		// full comparison.
		actualCa.Ctime = wantCase.Ctime
		actualCa.Utime = wantCase.Utime
		assert.Equal(k.T(), wantCase, actualCa)
		return nil
	}).AnyTimes()
	// Initialize the tables.
	err := dao.InitTables(k.db)
	require.NoError(k.T(), err)

	intrSvc := intrmocks.NewMockService(ctrl)
	intrModule := &interactive.Module{
		Svc: intrSvc,
	}
	module, err := startup.InitModule(producer, &ai.Module{
		KnowledgeBaseSvc: svc,
	}, intrModule)
	require.NoError(k.T(), err)
	k.caseSvc = module.Svc
	wantCa := k.getWantCase(1)
	_, err = k.caseSvc.Publish(context.Background(), wantCa)
	require.NoError(k.T(), err)

	econf.Set("server", map[string]any{"contextTimeout": "1s"})
	server := egin.Load("server").Build()
	// Inject an in-memory session for uid 123 on every request.
	server.Use(func(ctx *gin.Context) {
		ctx.Set("_session", session.NewMemorySession(session.Claims{
			Uid: 123,
			Data: map[string]string{
				"creator":   "true",
				"memberDDL": strconv.FormatInt(time.Now().Add(time.Hour).UnixMilli(), 10),
			},
		}))
	})
	module.KnowledgeBaseHandler.PrivateRoutes(server.Engine)

	req, err := http.NewRequest(http.MethodGet,
		"/case/knowledgeBase/syncAll", iox.NewJSONReader(nil))
	// Check the error before touching req: it is nil when NewRequest fails.
	require.NoError(k.T(), err)
	req.Header.Set("content-type", "application/json")
	recorder := test.NewJSONResponseRecorder[int64]()
	server.ServeHTTP(recorder, req)
	// Wait for the sync to finish.
	// NOTE(review): presumably the handler runs the sync asynchronously — confirm.
	time.Sleep(3 * time.Second)
}

func TestKnowledgeBaseTestSuite(t *testing.T) {
suite.Run(t, new(KnowledgeBaseTestSuite))
}
14 changes: 11 additions & 3 deletions internal/cases/internal/integration/startup/wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,14 @@ func InitModule(
service.NewService,
service.NewCaseSetService,
service.NewLLMExamineService,
initKnowledgeBaseSvc,
web.NewHandler,
web.NewAdminCaseSetHandler,
web.NewAdminCaseHandler,
web.NewKnowledgeBaseHandler,
wire.FieldsOf(new(*interactive.Module), "Svc"),
wire.FieldsOf(new(*ai.Module), "Svc"),
wire.Struct(new(cases.Module), "AdminHandler", "ExamineSvc", "Svc", "Hdl", "AdminSetHandler"),
wire.FieldsOf(new(*ai.Module), "Svc", "KnowledgeBaseSvc"),
wire.Struct(new(cases.Module), "AdminHandler", "ExamineSvc", "Svc", "Hdl", "AdminSetHandler", "KnowledgeBaseHandler"),
)
return new(cases.Module), nil
}
Expand All @@ -70,14 +72,20 @@ func InitExamModule(
service.NewCaseSetService,
service.NewService,
service.NewLLMExamineService,
initKnowledgeBaseSvc,
web.NewHandler,
web.NewAdminCaseSetHandler,
web.NewAdminCaseHandler,
web.NewExamineHandler,
web.NewCaseSetHandler,
web.NewKnowledgeBaseHandler,
wire.FieldsOf(new(*interactive.Module), "Svc"),
wire.FieldsOf(new(*ai.Module), "Svc"),
wire.FieldsOf(new(*ai.Module), "Svc", "KnowledgeBaseSvc"),
wire.Struct(new(cases.Module), "*"),
)
return new(cases.Module), nil
}

// initKnowledgeBaseSvc is the provider that builds the cases
// KnowledgeBaseService from the AI knowledge base service and the case
// repository, using a fixed knowledge base id.
func initKnowledgeBaseSvc(svc ai.KnowledgeBaseService, caRepo repository.CaseRepo) service.KnowledgeBaseService {
	const knowledgeBaseID = "knowledge_id"
	return service.NewKnowledgeBaseService(caRepo, svc, knowledgeBaseID)
}
42 changes: 28 additions & 14 deletions internal/cases/internal/integration/startup/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions internal/cases/internal/repository/cases.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type CaseRepo interface {
// Sync 保存到制作库,而后同步到线上库
Sync(ctx context.Context, ca domain.Case) (int64, error)
// 管理端接口
Ids(ctx context.Context) ([]int64, error)
List(ctx context.Context, offset int, limit int) ([]domain.Case, error)
Total(ctx context.Context) (int64, error)
Save(ctx context.Context, ca domain.Case) (int64, error)
Expand All @@ -34,6 +35,10 @@ type caseRepo struct {
caseDao dao.CaseDAO
}

// Ids returns the case ids reported by the underlying DAO, passing its
// result and error through unchanged.
func (c *caseRepo) Ids(ctx context.Context) ([]int64, error) {
	ids, err := c.caseDao.Ids(ctx)
	return ids, err
}

func (c *caseRepo) Exclude(ctx context.Context, ids []int64, offset int, limit int) ([]domain.Case, int64, error) {
var (
eg errgroup.Group
Expand Down
15 changes: 14 additions & 1 deletion internal/cases/internal/repository/dao/cases.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"time"

"github.com/ecodeclub/webook/internal/cases/internal/domain"

"gorm.io/gorm/clause"

"github.com/ego-component/egorm"
Expand All @@ -18,7 +20,8 @@ type CaseDAO interface {
Count(ctx context.Context) (int64, error)

Sync(ctx context.Context, c Case) (int64, error)

// 提供给同步到知识库用
Ids(ctx context.Context) ([]int64, error)
// 线上库
PublishCaseList(ctx context.Context, offset, limit int) ([]PublishCase, error)
PublishCaseCount(ctx context.Context) (int64, error)
Expand All @@ -35,6 +38,16 @@ type caseDAO struct {
updateColumns []string
}

// Ids selects the id column of every Case row whose status equals
// domain.PublishedStatus. It is used by the knowledge base sync.
func (ca *caseDAO) Ids(ctx context.Context) ([]int64, error) {
	var publishedIDs []int64
	query := ca.db.WithContext(ctx).
		Model(&Case{}).
		Select("id").
		Where("status = ?", domain.PublishedStatus)
	res := query.Scan(&publishedIDs)
	return publishedIDs, res.Error
}

func (ca *caseDAO) NotInTotal(ctx context.Context, ids []int64) (int64, error) {
var res int64
err := ca.db.WithContext(ctx).
Expand Down
76 changes: 76 additions & 0 deletions internal/cases/internal/service/knowlege_base.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package service

import (
"context"
"encoding/json"
"fmt"
"time"

"github.com/ecodeclub/webook/internal/ai"
"github.com/ecodeclub/webook/internal/cases/internal/domain"
"github.com/ecodeclub/webook/internal/cases/internal/repository"
"github.com/gotomicro/ego/core/elog"
)

// KnowledgeBaseService synchronizes cases to the AI knowledge base.
type KnowledgeBaseService interface {
// FullSync synchronizes cases to the knowledge base. It returns nothing,
// so failures are not reported to the caller.
FullSync()
}

// knowledgeBaseSvc is the KnowledgeBaseService implementation returned by
// NewKnowledgeBaseService.
type knowledgeBaseSvc struct {
// caseRepo supplies the case ids and the case data to upload.
caseRepo repository.CaseRepo
// knowledgeBaseSvc is the AI-side service that receives the uploaded files.
knowledgeBaseSvc ai.KnowledgeBaseService
// logger records the outcome of each sync.
logger *elog.Component
// knowledgeBaseId is sent as KnowledgeBaseID with every uploaded file.
knowledgeBaseId string
}

// NewKnowledgeBaseService returns a KnowledgeBaseService that reads cases from
// repo and uploads them through svc into the knowledge base identified by
// knowledgeBaseId. It logs through elog.DefaultLogger.
func NewKnowledgeBaseService(repo repository.CaseRepo, svc ai.KnowledgeBaseService, knowledgeBaseId string) KnowledgeBaseService {
	s := new(knowledgeBaseSvc)
	s.caseRepo = repo
	s.knowledgeBaseSvc = svc
	s.logger = elog.DefaultLogger
	s.knowledgeBaseId = knowledgeBaseId
	return s
}

// FullSync fetches the case ids from the repository (3-second timeout) and
// syncs each case in turn. A failed id lookup is logged and aborts the run;
// a failure on a single case is logged and the loop moves on to the next one.
func (k *knowledgeBaseSvc) FullSync() {
	listCtx, cancelList := context.WithTimeout(context.Background(), 3*time.Second)
	caseIDs, listErr := k.caseRepo.Ids(listCtx)
	// Release the lookup context right away; each case gets its own timeout.
	cancelList()
	if listErr != nil {
		k.logger.Error("查找案例列表失败", elog.FieldErr(listErr))
		return
	}
	for i := range caseIDs {
		caseID := caseIDs[i]
		if syncErr := k.syncOne(caseID); syncErr != nil {
			k.logger.Error(fmt.Sprintf("同步案例 %d失败", caseID), elog.FieldErr(syncErr))
			continue
		}
		k.logger.Info(fmt.Sprintf("同步案例 %d成功", caseID))
	}
}

// syncOne loads the case with the given id, serializes it to JSON and uploads
// it to the AI knowledge base as a file named "case_<id>". The lookup and the
// upload share one 3-second timeout.
//
// It returns nil on success, or an error wrapping the cause and naming the
// step that failed (lookup, serialization or upload).
func (k *knowledgeBaseSvc) syncOne(id int64) error {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	ca, err := k.caseRepo.GetById(ctx, id)
	if err != nil {
		// This step fetches a single case, not the id list, so say so.
		return fmt.Errorf("获取案例失败 %w", err)
	}
	data, err := json.Marshal(ca)
	if err != nil {
		return fmt.Errorf("序列化案例数据失败 %w", err)
	}
	err = k.knowledgeBaseSvc.UploadFile(ctx, ai.KnowledgeBaseFile{
		Biz:             domain.BizCase,
		BizID:           ca.Id,
		Name:            fmt.Sprintf("case_%d", ca.Id),
		Data:            data,
		Type:            ai.RepositoryBaseTypeRetrieval,
		KnowledgeBaseID: k.knowledgeBaseId,
	})
	if err != nil {
		return fmt.Errorf("上传到ai的知识库失败 %w", err)
	}
	return nil
}
Loading

0 comments on commit f2b7f6f

Please sign in to comment.