diff --git a/internal/datacoord/import_job.go b/internal/datacoord/import_job.go index 08f5503d6875e..3feea7144f028 100644 --- a/internal/datacoord/import_job.go +++ b/internal/datacoord/import_job.go @@ -39,6 +39,12 @@ func WithCollectionID(collectionID int64) ImportJobFilter { } } +func WithDbID(DbID int64) ImportJobFilter { + return func(job ImportJob) bool { + return job.GetDbID() == DbID + } +} + func WithJobStates(states ...internalpb.ImportJobState) ImportJobFilter { return func(job ImportJob) bool { for _, state := range states { @@ -100,6 +106,7 @@ func UpdateJobCompleteTime(completeTime string) UpdateJobAction { type ImportJob interface { GetJobID() int64 + GetDbID() int64 GetCollectionID() int64 GetCollectionName() string GetPartitionIDs() []int64 diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index 68cd9fe617447..64983d7b729b7 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -1676,7 +1676,9 @@ func (s *Server) ImportV2(ctx context.Context, in *internalpb.ImportRequestInter Status: merr.Success(), } - log := log.With(zap.Int64("collection", in.GetCollectionID()), + log := log.With( + zap.Int64("dbID", in.GetDbID()), + zap.Int64("collection", in.GetCollectionID()), zap.Int64s("partitions", in.GetPartitionIDs()), zap.Strings("channels", in.GetChannelNames())) log.Info("receive import request", zap.Any("files", in.GetFiles())) @@ -1742,6 +1744,7 @@ func (s *Server) ImportV2(ctx context.Context, in *internalpb.ImportRequestInter job := &importJob{ ImportJob: &datapb.ImportJob{ JobID: idStart, + DbID: in.GetDbID(), CollectionID: in.GetCollectionID(), CollectionName: in.GetCollectionName(), PartitionIDs: in.GetPartitionIDs(), @@ -1768,7 +1771,7 @@ func (s *Server) ImportV2(ctx context.Context, in *internalpb.ImportRequestInter } func (s *Server) GetImportProgress(ctx context.Context, in *internalpb.GetImportProgressRequest) (*internalpb.GetImportProgressResponse, error) { - log := log.With(zap.String("jobID", in.GetJobID())) + log := log.With(zap.String("jobID", in.GetJobID()), zap.Int64("dbID", in.GetDbID())) if err := merr.CheckHealthy(s.GetStateCode()); err != nil { return &internalpb.GetImportProgressResponse{ Status: merr.Status(err), @@ -1788,6 +1791,10 @@ func (s *Server) GetImportProgress(ctx context.Context, in *internalpb.GetImport resp.Status = merr.Status(merr.WrapErrImportFailed(fmt.Sprintf("import job does not exist, jobID=%d", jobID))) return resp, nil } + if job.GetDbID() != 0 && job.GetDbID() != in.GetDbID() { + resp.Status = merr.Status(merr.WrapErrImportFailed(fmt.Sprintf("import job does not exist, jobID=%d, dbID=%d", jobID, in.GetDbID()))) + return resp, nil + } progress, state, importedRows, totalRows, reason := GetJobProgress(jobID, s.importMeta, s.meta, s.jobManager) resp.State = state resp.Reason = reason @@ -1818,11 +1825,14 @@ func (s *Server) ListImports(ctx context.Context, req *internalpb.ListImportsReq } var jobs []ImportJob + filters := make([]ImportJobFilter, 0) + if req.GetDbID() != 0 { + filters = append(filters, WithDbID(req.GetDbID())) + } if req.GetCollectionID() != 0 { - jobs = s.importMeta.GetJobBy(WithCollectionID(req.GetCollectionID())) - } else { - jobs = s.importMeta.GetJobBy() + filters = append(filters, WithCollectionID(req.GetCollectionID())) } + jobs = s.importMeta.GetJobBy(filters...) for _, job := range jobs { progress, state, _, _, reason := GetJobProgress(job.GetJobID(), s.importMeta, s.meta, s.jobManager) @@ -1832,5 +1842,7 @@ func (s *Server) ListImports(ctx context.Context, req *internalpb.ListImportsReq resp.Progresses = append(resp.Progresses, progress) resp.CollectionNames = append(resp.CollectionNames, job.GetCollectionName()) } + log.Info("ListImports done", zap.Int64("collectionID", req.GetCollectionID()), + zap.Int64("dbID", req.GetDbID()), zap.Any("resp", resp)) return resp, nil } diff --git a/internal/datacoord/services_test.go b/internal/datacoord/services_test.go index 2ce5478fb1ec9..23c620181c58c 100644 --- a/internal/datacoord/services_test.go +++ b/internal/datacoord/services_test.go @@ -1412,9 +1412,10 @@ func TestImportV2(t *testing.T) { assert.NoError(t, err) assert.True(t, errors.Is(merr.Error(resp.GetStatus()), merr.ErrImportFailed)) - // normal case + // db does not exist var job ImportJob = &importJob{ ImportJob: &datapb.ImportJob{ + DbID: 1, JobID: 0, Schema: &schemapb.CollectionSchema{}, State: internalpb.ImportJobState_Failed, @@ -1423,12 +1424,31 @@ func TestImportV2(t *testing.T) { err = s.importMeta.AddJob(job) assert.NoError(t, err) resp, err = s.GetImportProgress(ctx, &internalpb.GetImportProgressRequest{ + DbID: 2, + JobID: "0", + }) + assert.NoError(t, err) + assert.True(t, errors.Is(merr.Error(resp.GetStatus()), merr.ErrImportFailed)) + + // normal case + job = &importJob{ + ImportJob: &datapb.ImportJob{ + DbID: 1, + JobID: 0, + Schema: &schemapb.CollectionSchema{}, + State: internalpb.ImportJobState_Pending, + }, + } + err = s.importMeta.AddJob(job) + assert.NoError(t, err) + resp, err = s.GetImportProgress(ctx, &internalpb.GetImportProgressRequest{ + DbID: 1, JobID: "0", }) assert.NoError(t, err) assert.Equal(t, int32(0), resp.GetStatus().GetCode()) - assert.Equal(t, int64(0), resp.GetProgress()) - assert.Equal(t, internalpb.ImportJobState_Failed, resp.GetState()) + assert.Equal(t, int64(10), resp.GetProgress()) + assert.Equal(t, internalpb.ImportJobState_Pending, resp.GetState()) }) t.Run("ListImports", func(t *testing.T) { @@ -1451,6 +1471,7 @@ func TestImportV2(t *testing.T) { assert.NoError(t, err) var job ImportJob = &importJob{ ImportJob: &datapb.ImportJob{ + DbID: 2, JobID: 0, CollectionID: 1, Schema: &schemapb.CollectionSchema{}, @@ -1467,7 +1488,20 @@ func TestImportV2(t *testing.T) { } err = s.importMeta.AddTask(task) assert.NoError(t, err) + // db id not match + resp, err = s.ListImports(ctx, &internalpb.ListImportsRequestInternal{ + DbID: 3, + CollectionID: 1, + }) + assert.NoError(t, err) + assert.Equal(t, int32(0), resp.GetStatus().GetCode()) + assert.Equal(t, 0, len(resp.GetJobIDs())) + assert.Equal(t, 0, len(resp.GetStates())) + assert.Equal(t, 0, len(resp.GetReasons())) + assert.Equal(t, 0, len(resp.GetProgresses())) + // db id match resp, err = s.ListImports(ctx, &internalpb.ListImportsRequestInternal{ + DbID: 2, CollectionID: 1, }) assert.NoError(t, err) diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index c57ee26c10c26..789c02faf249e 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -139,8 +139,8 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) { router.POST(ImportJobCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &OptionalCollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listImportJob))))) router.POST(ImportJobCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &ImportReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createImportJob))))) - router.POST(ImportJobCategory+GetProgressAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess))))) - router.POST(ImportJobCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess))))) + router.POST(ImportJobCategory+GetProgressAction, timeoutMiddleware(wrapperPost(func() any { return &GetImportReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess))))) + router.POST(ImportJobCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &GetImportReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess))))) } type ( diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index 31bf6f0d585b6..0ea2cc42a36c0 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -94,11 +94,14 @@ func (req *ImportReq) GetOptions() map[string]string { return req.Options } -type JobIDReq struct { - JobID string `json:"jobId" binding:"required"` +type GetImportReq struct { + DbName string `json:"dbName"` + JobID string `json:"jobId" binding:"required"` } -func (req *JobIDReq) GetJobID() string { return req.JobID } +func (req *GetImportReq) GetJobID() string { return req.JobID } + +func (req *GetImportReq) GetDbName() string { return req.DbName } type QueryReqV2 struct { DbName string `json:"dbName"` diff --git a/internal/proto/internal.proto b/internal/proto/internal.proto index d882ff007a478..148cf1c3899cc 100644 --- a/internal/proto/internal.proto +++ b/internal/proto/internal.proto @@ -345,6 +345,7 @@ message ImportResponse { message GetImportProgressRequest { string db_name = 1; string jobID = 2; + int64 dbID = 3; } message ImportTaskProgress { diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index ed7ab17641a81..0ab12b12280d0 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -6242,6 +6242,7 @@ func (node *Proxy) ImportV2(ctx context.Context, req *internalpb.ImportRequest) return &internalpb.ImportResponse{Status: merr.Status(err)}, nil } log := log.Ctx(ctx).With( + zap.String("dbName", req.GetDbName()), zap.String("collectionName", req.GetCollectionName()), zap.String("partition name", req.GetPartitionName()), zap.Any("files", req.GetFiles()), @@ -6267,6 +6268,11 @@ func (node *Proxy) ImportV2(ctx context.Context, req *internalpb.ImportRequest) } }() + dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, req.GetDbName()) + if err != nil { + resp.Status = merr.Status(err) + return resp, nil + } collectionID, err := globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName()) if err != nil { resp.Status = merr.Status(err) @@ -6377,6 +6383,7 @@ func (node *Proxy) ImportV2(ctx context.Context, req *internalpb.ImportRequest) } } importRequest := &internalpb.ImportRequestInternal{ + DbID: dbInfo.dbID, CollectionID: collectionID, CollectionName: req.GetCollectionName(), PartitionIDs: partitionIDs, @@ -6401,14 +6408,28 @@ func (node *Proxy) GetImportProgress(ctx context.Context, req *internalpb.GetImp }, nil } log := log.Ctx(ctx).With( + zap.String("dbName", req.GetDbName()), zap.String("jobID", req.GetJobID()), ) + + resp := &internalpb.GetImportProgressResponse{ + Status: merr.Success(), + } + method := "GetImportProgress" tr := timerecord.NewTimeRecorder(method) log.Info(rpcReceived(method)) + // Fill db id for datacoord. + dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, req.GetDbName()) + if err != nil { + resp.Status = merr.Status(err) + return resp, nil + } + req.DbID = dbInfo.dbID + nodeID := fmt.Sprint(paramtable.GetNodeID()) - resp, err := node.dataCoord.GetImportProgress(ctx, req) + resp, err = node.dataCoord.GetImportProgress(ctx, req) if resp.GetStatus().GetCode() != 0 || err != nil { log.Warn("get import progress failed", zap.String("reason", resp.GetStatus().GetReason()), zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(nodeID, method, metrics.FailLabel, req.GetDbName(), "").Inc() @@ -6445,6 +6466,11 @@ func (node *Proxy) ListImports(ctx context.Context, req *internalpb.ListImportsR err error collectionID UniqueID ) + dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, req.GetDbName()) + if err != nil { + resp.Status = merr.Status(err) + return resp, nil + } if req.GetCollectionName() != "" { collectionID, err = globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName()) if err != nil { @@ -6453,7 +6479,9 @@ func (node *Proxy) ListImports(ctx context.Context, req *internalpb.ListImportsR return resp, nil } } + resp, err = node.dataCoord.ListImports(ctx, &internalpb.ListImportsRequestInternal{ + DbID: dbInfo.dbID, CollectionID: collectionID, }) if resp.GetStatus().GetCode() != 0 || err != nil { diff --git a/internal/proxy/impl_test.go b/internal/proxy/impl_test.go index 70171fe3b84cd..b2ab3cde1a2d4 100644 --- a/internal/proxy/impl_test.go +++ b/internal/proxy/impl_test.go @@ -1616,8 +1616,17 @@ func TestProxy_ImportV2(t *testing.T) { assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode()) node.UpdateStateCode(commonpb.StateCode_Healthy) - // no such collection + // no such database mc := NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, mockErr) + globalMetaCache = mc + rsp, err = node.ImportV2(ctx, &internalpb.ImportRequest{CollectionName: "aaa"}) + assert.NoError(t, err) + assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode()) + + // no such collection + mc = NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, mockErr) globalMetaCache = mc rsp, err = node.ImportV2(ctx, &internalpb.ImportRequest{CollectionName: "aaa"}) @@ -1626,6 +1635,7 @@ func TestProxy_ImportV2(t *testing.T) { // get schema failed mc = NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil) mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(nil, mockErr) globalMetaCache = mc @@ -1635,6 +1645,7 @@ func TestProxy_ImportV2(t *testing.T) { // get channel failed mc = NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil) mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{ CollectionSchema: &schemapb.CollectionSchema{Fields: []*schemapb.FieldSchema{ @@ -1659,6 +1670,7 @@ func TestProxy_ImportV2(t *testing.T) { // get partitions failed mc = NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil) mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{ CollectionSchema: &schemapb.CollectionSchema{Fields: []*schemapb.FieldSchema{ @@ -1673,6 +1685,7 @@ func TestProxy_ImportV2(t *testing.T) { // get partitionID failed mc = NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil) mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{ CollectionSchema: &schemapb.CollectionSchema{}, @@ -1685,6 +1698,7 @@ func TestProxy_ImportV2(t *testing.T) { // no file mc = NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil) mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{ CollectionSchema: &schemapb.CollectionSchema{}, @@ -1731,7 +1745,18 @@ func TestProxy_ImportV2(t *testing.T) { assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode()) node.UpdateStateCode(commonpb.StateCode_Healthy) + // no such database + mc := NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, mockErr) + globalMetaCache = mc + rsp, err = node.GetImportProgress(ctx, &internalpb.GetImportProgressRequest{}) + assert.NoError(t, err) + assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode()) + // normal case + mc = NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) + globalMetaCache = mc dataCoord := mocks.NewMockDataCoordClient(t) dataCoord.EXPECT().GetImportProgress(mock.Anything, mock.Anything).Return(nil, nil) node.dataCoord = dataCoord @@ -1749,8 +1774,19 @@ func TestProxy_ImportV2(t *testing.T) { assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode()) node.UpdateStateCode(commonpb.StateCode_Healthy) - // normal case + // no such database mc := NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, mockErr) + globalMetaCache = mc + rsp, err = node.ListImports(ctx, &internalpb.ListImportsRequest{ + CollectionName: "col", + }) + assert.NoError(t, err) + assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode()) + + // normal case + mc = NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil) globalMetaCache = mc dataCoord := mocks.NewMockDataCoordClient(t) diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index c500bff6461f8..f106d6c387c86 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -4570,6 +4570,7 @@ func TestProxy_Import(t *testing.T) { proxy.UpdateStateCode(commonpb.StateCode_Healthy) mc := NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil) mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{ CollectionSchema: &schemapb.CollectionSchema{}, @@ -4610,6 +4611,10 @@ func TestProxy_Import(t *testing.T) { proxy := &Proxy{} proxy.UpdateStateCode(commonpb.StateCode_Healthy) + mc := NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) + globalMetaCache = mc + dataCoord := mocks.NewMockDataCoordClient(t) dataCoord.EXPECT().GetImportProgress(mock.Anything, mock.Anything).Return(&internalpb.GetImportProgressResponse{ Status: merr.Success(), @@ -4635,6 +4640,10 @@ func TestProxy_Import(t *testing.T) { proxy := &Proxy{} proxy.UpdateStateCode(commonpb.StateCode_Healthy) + mc := NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) + globalMetaCache = mc + dataCoord := mocks.NewMockDataCoordClient(t) dataCoord.EXPECT().ListImports(mock.Anything, mock.Anything).Return(&internalpb.ListImportsResponse{ Status: merr.Success(),