Skip to content

Commit

Permalink
添加llm配置的测试
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuwenliang committed Nov 28, 2024
1 parent e3b3be3 commit f11f468
Show file tree
Hide file tree
Showing 9 changed files with 320 additions and 12 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ cover.out
/internal/cases/internal/integration/logs
/internal/skill/internal/integration/logs
/internal/feedback/internal/integration/logs
/internal/ai/internal/integration/logs
/config/cert/

local_test.go
Expand Down
3 changes: 0 additions & 3 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,6 @@ github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hf
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd h1:1FjCyPC+syAzJ5/2S8fqdZK1R22vvA0J7JZKcuOIQ7Y=
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
Expand Down Expand Up @@ -785,7 +784,6 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0=
golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
Expand Down Expand Up @@ -1004,7 +1002,6 @@ golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc=
golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps=
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
Expand Down
2 changes: 1 addition & 1 deletion internal/ai/internal/domain/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type LLMResponse struct {
}

type BizConfig struct {
Id int64
Id int64
Biz string
// 使用的模型
Model string
Expand Down
305 changes: 305 additions & 0 deletions internal/ai/internal/integration/llm_config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
package integration

import (
"context"
"fmt"
"net/http"
"testing"

"github.com/ecodeclub/ekit/iox"
"github.com/ecodeclub/ginx/session"
"github.com/ecodeclub/webook/internal/ai/internal/integration/startup"
"github.com/ecodeclub/webook/internal/ai/internal/repository/dao"
"github.com/ecodeclub/webook/internal/ai/internal/web"
"github.com/ecodeclub/webook/internal/credit"
"github.com/ecodeclub/webook/internal/test"
testioc "github.com/ecodeclub/webook/internal/test/ioc"
"github.com/gin-gonic/gin"
"github.com/gotomicro/ego/core/econf"
"github.com/gotomicro/ego/server/egin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)

type ConfigSuite struct {
suite.Suite
db *gorm.DB
adminHandler *web.AdminHandler
server *egin.Component
}

func (s *ConfigSuite) SetupSuite() {
db := testioc.InitDB()
s.db = db
err := dao.InitTables(db)
s.NoError(err)
// 先插入 BizConfig
mou, err := startup.InitModule(s.db, nil, &credit.Module{})
require.NoError(s.T(), err)
s.adminHandler = mou.AdminHandler
econf.Set("server", map[string]any{"contextTimeout": "1s"})
server := egin.Load("server").Build()
server.Use(func(ctx *gin.Context) {
ctx.Set("_session", session.NewMemorySession(session.Claims{
Uid: 123,
}))
})
s.adminHandler.RegisterRoutes(server.Engine)
s.server = server
}

func (s *ConfigSuite) TestConfig_Save() {
testCases := []struct {
name string
config web.ConfigRequest
before func(t *testing.T)
after func(t *testing.T, id int64)
wantCode int
id int64
}{
{
name: "新增",
config: web.ConfigRequest{
Config: web.Config{
Biz: "test",
MaxInput: 10,
Model: "testModel",
Price: 100,
Temperature: 0.5,
TopP: 0.5,
SystemPrompt: "testPrompt",
PromptTemplate: "testTemplate",
KnowledgeId: "testKnowledgeId",
},
},
before: func(t *testing.T) {

},
wantCode: 200,
id: 1,
after: func(t *testing.T, id int64) {
var conf dao.BizConfig
err := s.db.WithContext(context.Background()).
Where("id = ?", id).First(&conf).Error
require.NoError(t, err)
s.assertBizConfig(dao.BizConfig{
Id: 1,
Biz: "test",
MaxInput: 10,
Model: "testModel",
Price: 100,
Temperature: 0.5,
TopP: 0.5,
SystemPrompt: "testPrompt",
PromptTemplate: "testTemplate",
KnowledgeId: "testKnowledgeId",
}, conf)
},
},
{
name: "更新",
config: web.ConfigRequest{
Config: web.Config{
Id: 2,
Biz: "2_test",
MaxInput: 102,
Model: "2_testModel",
Price: 102,
Temperature: 2.5,
TopP: 2.5,
SystemPrompt: "testPrompt2",
PromptTemplate: "testTemplate2",
KnowledgeId: "testKnowledgeId2",
},
},
before: func(t *testing.T) {
err := s.db.WithContext(context.Background()).
Table("ai_biz_configs").
Create(dao.BizConfig{
Id: 2,
Biz: "test_2",
MaxInput: 100,
Model: "testModel",
Price: 100,
Temperature: 0.5,
TopP: 0.5,
SystemPrompt: "testPrompt",
PromptTemplate: "testTemplate",
KnowledgeId: "testKnowledgeId",
Ctime: 11,
Utime: 22,
}).Error
require.NoError(t, err)
},
after: func(t *testing.T, id int64) {
var conf dao.BizConfig
err := s.db.WithContext(context.Background()).
Where("id = ?", id).
Model(&dao.BizConfig{}).
First(&conf).Error
require.NoError(t, err)
s.assertBizConfig(dao.BizConfig{
Id: 2,
Biz: "2_test",
MaxInput: 102,
Model: "2_testModel",
Price: 102,
Temperature: 2.5,
TopP: 2.5,
SystemPrompt: "testPrompt2",
PromptTemplate: "testTemplate2",
KnowledgeId: "testKnowledgeId2",
}, conf)
},
wantCode: 200,
id: 2,
},
}
for _, tc := range testCases {
s.T().Run(tc.name, func(t *testing.T) {
tc.before(t)
req, err := http.NewRequest(http.MethodPost,
"/ai/config/save", iox.NewJSONReader(tc.config))
req.Header.Set("content-type", "application/json")
require.NoError(t, err)
recorder := test.NewJSONResponseRecorder[int64]()
s.server.ServeHTTP(recorder, req)
require.Equal(t, tc.wantCode, recorder.Code)
id := recorder.MustScan().Data
assert.Equal(t, tc.id, id)
tc.after(t, id)
err = s.db.Exec("TRUNCATE TABLE `ai_biz_configs`").Error
require.NoError(s.T(), err)
})
}
}

func (s *ConfigSuite) TestConfig_List() {
configs := make([]dao.BizConfig, 0, 32)
for i := 1; i < 10; i++ {
cfg := dao.BizConfig{
Id: int64(i),
Biz: fmt.Sprintf("biz_%d", i),
MaxInput: 100,
Model: fmt.Sprintf("test_model_%d", i),
Price: 1000,
Temperature: 37.5,
TopP: 0.8,
SystemPrompt: "test_prompt",
PromptTemplate: "test_template",
KnowledgeId: "test_knowledge",
}
configs = append(configs, cfg)
}
err := s.db.WithContext(context.Background()).Create(&configs).Error
require.NoError(s.T(), err)
req, err := http.NewRequest(http.MethodGet,
"/ai/config/list", iox.NewJSONReader(nil))
req.Header.Set("content-type", "application/json")
require.NoError(s.T(), err)
recorder := test.NewJSONResponseRecorder[[]web.Config]()
s.server.ServeHTTP(recorder, req)
require.Equal(s.T(), 200, recorder.Code)
confs := recorder.MustScan().Data
assert.Equal(s.T(), getWantConfigs(), confs)
err = s.db.Exec("TRUNCATE TABLE `ai_biz_configs`").Error
}

func (s *ConfigSuite) Test_Detail() {
testcases := []struct {
name string
req web.ConfigInfoReq
before func(t *testing.T)
wantCode int
wantData web.Config
}{
{
name: "获取配置",
wantCode: 200,
req: web.ConfigInfoReq{
Id: 3,
},
before: func(t *testing.T) {
err := s.db.WithContext(context.Background()).
Table("ai_biz_configs").
Create(dao.BizConfig{
Id: 3,
Biz: "test_3",
MaxInput: 100,
Model: "testModel",
Price: 100,
Temperature: 0.5,
TopP: 0.5,
SystemPrompt: "testPrompt",
PromptTemplate: "testTemplate",
KnowledgeId: "testKnowledgeId",
Ctime: 11,
Utime: 22,
}).Error
require.NoError(t, err)
},
wantData: web.Config{
Id: 3,
Biz: "test_3",
MaxInput: 100,
Model: "testModel",
Price: 100,
Temperature: 0.5,
TopP: 0.5,
SystemPrompt: "testPrompt",
PromptTemplate: "testTemplate",
KnowledgeId: "testKnowledgeId",
},
},
}
for _, tc := range testcases {
s.T().Run(tc.name, func(t *testing.T) {
tc.before(t)
req, err := http.NewRequest(http.MethodPost,
"/ai/config/detail", iox.NewJSONReader(tc.req))
req.Header.Set("content-type", "application/json")
require.NoError(t, err)
recorder := test.NewJSONResponseRecorder[web.Config]()
s.server.ServeHTTP(recorder, req)
require.Equal(s.T(), 200, recorder.Code)
conf := recorder.MustScan().Data
assert.Equal(t, tc.wantData, conf)
err = s.db.Exec("TRUNCATE TABLE `ai_biz_configs`").Error
require.NoError(s.T(), err)
})
}
}

func getWantConfigs() []web.Config {
configs := make([]web.Config, 0, 32)
for i := 9; i >= 1; i-- {
cfg := web.Config{
Id: int64(i),
Biz: fmt.Sprintf("biz_%d", i),
MaxInput: 100,
Model: fmt.Sprintf("test_model_%d", i),
Price: 1000,
Temperature: 37.5,
TopP: 0.8,
SystemPrompt: "test_prompt",
PromptTemplate: "test_template",
KnowledgeId: "test_knowledge",
}
configs = append(configs, cfg)
}
return configs
}

func (s *ConfigSuite) assertBizConfig(wantConfig dao.BizConfig, actualConfig dao.BizConfig) {
assert.True(s.T(), actualConfig.Ctime > 0)
assert.True(s.T(), actualConfig.Utime > 0)
actualConfig.Ctime = 0
actualConfig.Utime = 0
assert.Equal(s.T(), wantConfig, actualConfig)
}

func TestConfigSuite(t *testing.T) {
suite.Run(t, new(ConfigSuite))
}
7 changes: 4 additions & 3 deletions internal/ai/internal/integration/startup/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion internal/ai/internal/repository/dao/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ func (dao *GORMConfigDAO) GetById(ctx context.Context, id int64) (BizConfig, err

func (dao *GORMConfigDAO) List(ctx context.Context) ([]BizConfig, error) {
var configs []BizConfig
err := dao.db.WithContext(ctx).Find(&configs).Error
err := dao.db.WithContext(ctx).
Model(&BizConfig{}).
Order("id desc").
Find(&configs).Error
return configs, err
}
func (dao *GORMConfigDAO) GetConfig(ctx context.Context, biz string) (BizConfig, error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/ai/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ package ai
type Module struct {
Svc LLMService
Hdl *LLMHandler
ADMINHandler *ADMINHandler
AdminHandler *AdminHandler
}
Loading

0 comments on commit f11f468

Please sign in to comment.