diff --git a/.gitignore b/.gitignore index eb5b1f3..02b4e5a 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,7 @@ cover.out /internal/cases/internal/integration/logs /internal/skill/internal/integration/logs /internal/feedback/internal/integration/logs +/internal/ai/internal/integration/logs /config/cert/ local_test.go diff --git a/go.sum b/go.sum index 8a2f8e0..86a1896 100644 --- a/go.sum +++ b/go.sum @@ -290,7 +290,6 @@ github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hf github.com/google/pprof v0.0.0-20211214055906-6f57359322fd h1:1FjCyPC+syAzJ5/2S8fqdZK1R22vvA0J7JZKcuOIQ7Y= github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -785,7 +784,6 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1004,7 +1002,6 @@ golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= -golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/ai/internal/domain/llm.go b/internal/ai/internal/domain/llm.go index 9da5327..ea55f01 100644 --- a/internal/ai/internal/domain/llm.go +++ b/internal/ai/internal/domain/llm.go @@ -43,7 +43,7 @@ type LLMResponse struct { } type BizConfig struct { - Id int64 + Id int64 Biz string // 使用的模型 Model string diff --git a/internal/ai/internal/integration/llm_config_test.go b/internal/ai/internal/integration/llm_config_test.go new file mode 100644 index 0000000..4cfe596 --- /dev/null +++ b/internal/ai/internal/integration/llm_config_test.go @@ -0,0 +1,305 @@ +package integration + +import ( + "context" + "fmt" + "net/http" + "testing" + + "github.com/ecodeclub/ekit/iox" + "github.com/ecodeclub/ginx/session" + "github.com/ecodeclub/webook/internal/ai/internal/integration/startup" + "github.com/ecodeclub/webook/internal/ai/internal/repository/dao" + "github.com/ecodeclub/webook/internal/ai/internal/web" + "github.com/ecodeclub/webook/internal/credit" + "github.com/ecodeclub/webook/internal/test" + testioc "github.com/ecodeclub/webook/internal/test/ioc" + "github.com/gin-gonic/gin" + "github.com/gotomicro/ego/core/econf" + "github.com/gotomicro/ego/server/egin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "gorm.io/gorm" +) + +type ConfigSuite struct { + suite.Suite + db *gorm.DB + adminHandler *web.AdminHandler + server *egin.Component +} + +func (s *ConfigSuite) SetupSuite() { + db := testioc.InitDB() + s.db = db + err := dao.InitTables(db) + s.NoError(err) + // 先插入 BizConfig + mou, err := startup.InitModule(s.db, nil, &credit.Module{}) + require.NoError(s.T(), err) + s.adminHandler = mou.AdminHandler + econf.Set("server", map[string]any{"contextTimeout": "1s"}) + server := egin.Load("server").Build() + server.Use(func(ctx *gin.Context) { + ctx.Set("_session", session.NewMemorySession(session.Claims{ + Uid: 123, + })) + }) + s.adminHandler.RegisterRoutes(server.Engine) + s.server = server +} + +func (s *ConfigSuite) TestConfig_Save() { + testCases := []struct { + name string + config web.ConfigRequest + before func(t *testing.T) + after func(t *testing.T, id int64) + wantCode int + id int64 + }{ + { + name: "新增", + config: web.ConfigRequest{ + Config: web.Config{ + Biz: "test", + MaxInput: 10, + Model: "testModel", + Price: 100, + Temperature: 0.5, + TopP: 0.5, + SystemPrompt: "testPrompt", + PromptTemplate: "testTemplate", + KnowledgeId: "testKnowledgeId", + }, + }, + before: func(t *testing.T) { + + }, + wantCode: 200, + id: 1, + after: func(t *testing.T, id int64) { + var conf dao.BizConfig + err := s.db.WithContext(context.Background()). + Where("id = ?", id).First(&conf).Error + require.NoError(t, err) + s.assertBizConfig(dao.BizConfig{ + Id: 1, + Biz: "test", + MaxInput: 10, + Model: "testModel", + Price: 100, + Temperature: 0.5, + TopP: 0.5, + SystemPrompt: "testPrompt", + PromptTemplate: "testTemplate", + KnowledgeId: "testKnowledgeId", + }, conf) + }, + }, + { + name: "更新", + config: web.ConfigRequest{ + Config: web.Config{ + Id: 2, + Biz: "2_test", + MaxInput: 102, + Model: "2_testModel", + Price: 102, + Temperature: 2.5, + TopP: 2.5, + SystemPrompt: "testPrompt2", + PromptTemplate: "testTemplate2", + KnowledgeId: "testKnowledgeId2", + }, + }, + before: func(t *testing.T) { + err := s.db.WithContext(context.Background()). + Table("ai_biz_configs"). + Create(dao.BizConfig{ + Id: 2, + Biz: "test_2", + MaxInput: 100, + Model: "testModel", + Price: 100, + Temperature: 0.5, + TopP: 0.5, + SystemPrompt: "testPrompt", + PromptTemplate: "testTemplate", + KnowledgeId: "testKnowledgeId", + Ctime: 11, + Utime: 22, + }).Error + require.NoError(t, err) + }, + after: func(t *testing.T, id int64) { + var conf dao.BizConfig + err := s.db.WithContext(context.Background()). + Where("id = ?", id). + Model(&dao.BizConfig{}). + First(&conf).Error + require.NoError(t, err) + s.assertBizConfig(dao.BizConfig{ + Id: 2, + Biz: "2_test", + MaxInput: 102, + Model: "2_testModel", + Price: 102, + Temperature: 2.5, + TopP: 2.5, + SystemPrompt: "testPrompt2", + PromptTemplate: "testTemplate2", + KnowledgeId: "testKnowledgeId2", + }, conf) + }, + wantCode: 200, + id: 2, + }, + } + for _, tc := range testCases { + s.T().Run(tc.name, func(t *testing.T) { + tc.before(t) + req, err := http.NewRequest(http.MethodPost, + "/ai/config/save", iox.NewJSONReader(tc.config)) + req.Header.Set("content-type", "application/json") + require.NoError(t, err) + recorder := test.NewJSONResponseRecorder[int64]() + s.server.ServeHTTP(recorder, req) + require.Equal(t, tc.wantCode, recorder.Code) + id := recorder.MustScan().Data + assert.Equal(t, tc.id, id) + tc.after(t, id) + err = s.db.Exec("TRUNCATE TABLE `ai_biz_configs`").Error + require.NoError(s.T(), err) + }) + } +} + +func (s *ConfigSuite) TestConfig_List() { + configs := make([]dao.BizConfig, 0, 32) + for i := 1; i < 10; i++ { + cfg := dao.BizConfig{ + Id: int64(i), + Biz: fmt.Sprintf("biz_%d", i), + MaxInput: 100, + Model: fmt.Sprintf("test_model_%d", i), + Price: 1000, + Temperature: 37.5, + TopP: 0.8, + SystemPrompt: "test_prompt", + PromptTemplate: "test_template", + KnowledgeId: "test_knowledge", + } + configs = append(configs, cfg) + } + err := s.db.WithContext(context.Background()).Create(&configs).Error + require.NoError(s.T(), err) + req, err := http.NewRequest(http.MethodGet, + "/ai/config/list", iox.NewJSONReader(nil)) + req.Header.Set("content-type", "application/json") + require.NoError(s.T(), err) + recorder := test.NewJSONResponseRecorder[[]web.Config]() + s.server.ServeHTTP(recorder, req) + require.Equal(s.T(), 200, recorder.Code) + confs := recorder.MustScan().Data + assert.Equal(s.T(), getWantConfigs(), confs) + err = s.db.Exec("TRUNCATE TABLE `ai_biz_configs`").Error +} + +func (s *ConfigSuite) Test_Detail() { + testcases := []struct { + name string + req web.ConfigInfoReq + before func(t *testing.T) + wantCode int + wantData web.Config + }{ + { + name: "获取配置", + wantCode: 200, + req: web.ConfigInfoReq{ + Id: 3, + }, + before: func(t *testing.T) { + err := s.db.WithContext(context.Background()). + Table("ai_biz_configs"). + Create(dao.BizConfig{ + Id: 3, + Biz: "test_3", + MaxInput: 100, + Model: "testModel", + Price: 100, + Temperature: 0.5, + TopP: 0.5, + SystemPrompt: "testPrompt", + PromptTemplate: "testTemplate", + KnowledgeId: "testKnowledgeId", + Ctime: 11, + Utime: 22, + }).Error + require.NoError(t, err) + }, + wantData: web.Config{ + Id: 3, + Biz: "test_3", + MaxInput: 100, + Model: "testModel", + Price: 100, + Temperature: 0.5, + TopP: 0.5, + SystemPrompt: "testPrompt", + PromptTemplate: "testTemplate", + KnowledgeId: "testKnowledgeId", + }, + }, + } + for _, tc := range testcases { + s.T().Run(tc.name, func(t *testing.T) { + tc.before(t) + req, err := http.NewRequest(http.MethodPost, + "/ai/config/detail", iox.NewJSONReader(tc.req)) + req.Header.Set("content-type", "application/json") + require.NoError(t, err) + recorder := test.NewJSONResponseRecorder[web.Config]() + s.server.ServeHTTP(recorder, req) + require.Equal(s.T(), 200, recorder.Code) + conf := recorder.MustScan().Data + assert.Equal(t, tc.wantData, conf) + err = s.db.Exec("TRUNCATE TABLE `ai_biz_configs`").Error + require.NoError(s.T(), err) + }) + } +} + +func getWantConfigs() []web.Config { + configs := make([]web.Config, 0, 32) + for i := 9; i >= 1; i-- { + cfg := web.Config{ + Id: int64(i), + Biz: fmt.Sprintf("biz_%d", i), + MaxInput: 100, + Model: fmt.Sprintf("test_model_%d", i), + Price: 1000, + Temperature: 37.5, + TopP: 0.8, + SystemPrompt: "test_prompt", + PromptTemplate: "test_template", + KnowledgeId: "test_knowledge", + } + configs = append(configs, cfg) + } + return configs +} + +func (s *ConfigSuite) assertBizConfig(wantConfig dao.BizConfig, actualConfig dao.BizConfig) { + assert.True(s.T(), actualConfig.Ctime > 0) + assert.True(s.T(), actualConfig.Utime > 0) + actualConfig.Ctime = 0 + actualConfig.Utime = 0 + assert.Equal(s.T(), wantConfig, actualConfig) +} + +func TestConfigSuite(t *testing.T) { + suite.Run(t, new(ConfigSuite)) +} diff --git a/internal/ai/internal/integration/startup/wire_gen.go b/internal/ai/internal/integration/startup/wire_gen.go index e1f7a99..8adcb6e 100644 --- a/internal/ai/internal/integration/startup/wire_gen.go +++ b/internal/ai/internal/integration/startup/wire_gen.go @@ -7,6 +7,8 @@ package startup import ( + "sync" + "github.com/ecodeclub/webook/internal/ai" "github.com/ecodeclub/webook/internal/ai/internal/repository" "github.com/ecodeclub/webook/internal/ai/internal/repository/dao" @@ -16,13 +18,12 @@ import ( "github.com/ecodeclub/webook/internal/ai/internal/service/llm/handler/config" credit2 "github.com/ecodeclub/webook/internal/ai/internal/service/llm/handler/credit" "github.com/ecodeclub/webook/internal/ai/internal/service/llm/handler/log" - "github.com/ecodeclub/webook/internal/ai/internal/service/llm/handler/mocks" + hdlmocks "github.com/ecodeclub/webook/internal/ai/internal/service/llm/handler/mocks" "github.com/ecodeclub/webook/internal/ai/internal/service/llm/handler/record" "github.com/ecodeclub/webook/internal/ai/internal/web" "github.com/ecodeclub/webook/internal/credit" "github.com/ego-component/egorm" "gorm.io/gorm" - "sync" ) // Injectors from wire.go: @@ -50,7 +51,7 @@ func InitModule(db *gorm.DB, hdl *hdlmocks.MockHandler, creditSvc *credit.Module module := &ai.Module{ Svc: llmService, Hdl: webHandler, - ADMINHandler: adminHandler, + AdminHandler: adminHandler, } return module, nil } diff --git a/internal/ai/internal/repository/dao/config.go b/internal/ai/internal/repository/dao/config.go index 359ec45..6bd858e 100644 --- a/internal/ai/internal/repository/dao/config.go +++ b/internal/ai/internal/repository/dao/config.go @@ -45,7 +45,10 @@ func (dao *GORMConfigDAO) GetById(ctx context.Context, id int64) (BizConfig, err func (dao *GORMConfigDAO) List(ctx context.Context) ([]BizConfig, error) { var configs []BizConfig - err := dao.db.WithContext(ctx).Find(&configs).Error + err := dao.db.WithContext(ctx). + Model(&BizConfig{}). + Order("id desc"). + Find(&configs).Error return configs, err } func (dao *GORMConfigDAO) GetConfig(ctx context.Context, biz string) (BizConfig, error) { diff --git a/internal/ai/module.go b/internal/ai/module.go index 3843dfa..941289c 100644 --- a/internal/ai/module.go +++ b/internal/ai/module.go @@ -3,5 +3,5 @@ package ai type Module struct { Svc LLMService Hdl *LLMHandler - ADMINHandler *ADMINHandler + AdminHandler *AdminHandler } diff --git a/internal/ai/type.go b/internal/ai/type.go index dd675e7..66a813b 100644 --- a/internal/ai/type.go +++ b/internal/ai/type.go @@ -9,5 +9,5 @@ import ( type LLMRequest = domain.LLMRequest type LLMResponse = domain.LLMResponse type LLMService = llm.Service -type ADMINHandler = web.AdminHandler +type AdminHandler = web.AdminHandler type LLMHandler = web.Handler diff --git a/internal/ai/wire_gen.go b/internal/ai/wire_gen.go index 3a942da..360f42e 100644 --- a/internal/ai/wire_gen.go +++ b/internal/ai/wire_gen.go @@ -7,6 +7,8 @@ package ai import ( + "sync" + "github.com/ecodeclub/webook/internal/ai/internal/repository" "github.com/ecodeclub/webook/internal/ai/internal/repository/dao" "github.com/ecodeclub/webook/internal/ai/internal/service" @@ -19,7 +21,6 @@ import ( "github.com/ecodeclub/webook/internal/credit" "github.com/ego-component/egorm" "gorm.io/gorm" - "sync" ) // Injectors from wire.go: @@ -48,7 +49,7 @@ func InitModule(db *gorm.DB, creditSvc *credit.Module) (*Module, error) { module := &Module{ Svc: llmService, Hdl: webHandler, - ADMINHandler: adminHandler, + AdminHandler: adminHandler, } return module, nil }