Skip to content

Commit

Permalink
rbd: read volumeId from source
Browse files Browse the repository at this point in the history
read the volumeID from replication
source if the ID is missing read
it from req VolumeId as a fallback.

Signed-off-by: Madhu Rajanna <madhupr007@gmail.com>
  • Loading branch information
Madhu-1 committed Jun 24, 2024
1 parent 0b65d28 commit 385930d
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 12 deletions.
13 changes: 7 additions & 6 deletions internal/csi-addons/rbd/replication.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"strings"
"time"

csicommon "github.com/ceph/ceph-csi/internal/csi-common"
corerbd "github.com/ceph/ceph-csi/internal/rbd"
"github.com/ceph/ceph-csi/internal/util"
"github.com/ceph/ceph-csi/internal/util/log"
Expand Down Expand Up @@ -247,7 +248,7 @@ func validateSchedulingInterval(interval string) error {
func (rs *ReplicationServer) EnableVolumeReplication(ctx context.Context,
req *replication.EnableVolumeReplicationRequest,
) (*replication.EnableVolumeReplicationResponse, error) {
volumeID := req.GetVolumeId()
volumeID := csicommon.GetIDFromReplication(req)
if volumeID == "" {
return nil, status.Error(codes.InvalidArgument, "empty volume ID in request")
}
Expand Down Expand Up @@ -329,7 +330,7 @@ func (rs *ReplicationServer) EnableVolumeReplication(ctx context.Context,
func (rs *ReplicationServer) DisableVolumeReplication(ctx context.Context,
req *replication.DisableVolumeReplicationRequest,
) (*replication.DisableVolumeReplicationResponse, error) {
volumeID := req.GetVolumeId()
volumeID := csicommon.GetIDFromReplication(req)
if volumeID == "" {
return nil, status.Error(codes.InvalidArgument, "empty volume ID in request")
}
Expand Down Expand Up @@ -404,7 +405,7 @@ func (rs *ReplicationServer) DisableVolumeReplication(ctx context.Context,
func (rs *ReplicationServer) PromoteVolume(ctx context.Context,
req *replication.PromoteVolumeRequest,
) (*replication.PromoteVolumeResponse, error) {
volumeID := req.GetVolumeId()
volumeID := csicommon.GetIDFromReplication(req)
if volumeID == "" {
return nil, status.Error(codes.InvalidArgument, "empty volume ID in request")
}
Expand Down Expand Up @@ -504,7 +505,7 @@ func (rs *ReplicationServer) PromoteVolume(ctx context.Context,
func (rs *ReplicationServer) DemoteVolume(ctx context.Context,
req *replication.DemoteVolumeRequest,
) (*replication.DemoteVolumeResponse, error) {
volumeID := req.GetVolumeId()
volumeID := csicommon.GetIDFromReplication(req)
if volumeID == "" {
return nil, status.Error(codes.InvalidArgument, "empty volume ID in request")
}
Expand Down Expand Up @@ -622,7 +623,7 @@ func checkRemoteSiteStatus(ctx context.Context, mirrorStatus *librbd.GlobalMirro
func (rs *ReplicationServer) ResyncVolume(ctx context.Context,
req *replication.ResyncVolumeRequest,
) (*replication.ResyncVolumeResponse, error) {
volumeID := req.GetVolumeId()
volumeID := csicommon.GetIDFromReplication(req)
if volumeID == "" {
return nil, status.Error(codes.InvalidArgument, "empty volume ID in request")
}
Expand Down Expand Up @@ -836,7 +837,7 @@ func getGRPCError(err error) error {
func (rs *ReplicationServer) GetVolumeReplicationInfo(ctx context.Context,
req *replication.GetVolumeReplicationInfoRequest,
) (*replication.GetVolumeReplicationInfoResponse, error) {
volumeID := req.GetVolumeId()
volumeID := csicommon.GetIDFromReplication(req)
if volumeID == "" {
return nil, status.Error(codes.InvalidArgument, "empty volume ID in request")
}
Expand Down
49 changes: 43 additions & 6 deletions internal/csi-common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,43 @@ func NewMiddlewareServerOption() grpc.ServerOption {
return grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(middleWare...))
}

// GetIDFromReplication returns the volumeID for Replication.
func GetIDFromReplication(req interface{}) string {
getID := func(r interface {
GetVolumeId() string
GetReplicationSource() *replication.ReplicationSource
},
) string {
reqID := ""
src := r.GetReplicationSource()
if src != nil && src.GetVolume() != nil {
reqID = src.GetVolume().GetVolumeId()
}
if reqID == "" {
reqID = r.GetVolumeId() //nolint:nolintlint,staticcheck // req.VolumeId is deprecated
}

return reqID
}

switch r := req.(type) {
case *replication.EnableVolumeReplicationRequest:
return getID(r)
case *replication.DisableVolumeReplicationRequest:
return getID(r)
case *replication.PromoteVolumeRequest:
return getID(r)
case *replication.DemoteVolumeRequest:
return getID(r)
case *replication.ResyncVolumeRequest:
return getID(r)
case *replication.GetVolumeReplicationInfoRequest:
return getID(r)
default:
return ""
}
}

func getReqID(req interface{}) string {
// if req is nil empty string will be returned
reqID := ""
Expand Down Expand Up @@ -156,17 +193,17 @@ func getReqID(req interface{}) string {

// Replication
case *replication.EnableVolumeReplicationRequest:
reqID = r.GetVolumeId()
reqID = GetIDFromReplication(r)
case *replication.DisableVolumeReplicationRequest:
reqID = r.GetVolumeId()
reqID = GetIDFromReplication(r)
case *replication.PromoteVolumeRequest:
reqID = r.GetVolumeId()
reqID = GetIDFromReplication(r)
case *replication.DemoteVolumeRequest:
reqID = r.GetVolumeId()
reqID = GetIDFromReplication(r)
case *replication.ResyncVolumeRequest:
reqID = r.GetVolumeId()
reqID = GetIDFromReplication(r)
case *replication.GetVolumeReplicationInfoRequest:
reqID = r.GetVolumeId()
reqID = GetIDFromReplication(r)
}

return reqID
Expand Down
56 changes: 56 additions & 0 deletions internal/csi-common/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,62 @@ func TestGetReqID(t *testing.T) {
&replication.GetVolumeReplicationInfoRequest{
VolumeId: fakeID,
},

// volumeId is set in ReplicationSource
&replication.EnableVolumeReplicationRequest{
ReplicationSource: &replication.ReplicationSource{
Type: &replication.ReplicationSource_Volume{
Volume: &replication.ReplicationSource_VolumeSource{
VolumeId: fakeID,
},
},
},
},
&replication.DisableVolumeReplicationRequest{
ReplicationSource: &replication.ReplicationSource{
Type: &replication.ReplicationSource_Volume{
Volume: &replication.ReplicationSource_VolumeSource{
VolumeId: fakeID,
},
},
},
},
&replication.PromoteVolumeRequest{
ReplicationSource: &replication.ReplicationSource{
Type: &replication.ReplicationSource_Volume{
Volume: &replication.ReplicationSource_VolumeSource{
VolumeId: fakeID,
},
},
},
},
&replication.DemoteVolumeRequest{
ReplicationSource: &replication.ReplicationSource{
Type: &replication.ReplicationSource_Volume{
Volume: &replication.ReplicationSource_VolumeSource{
VolumeId: fakeID,
},
},
},
},
&replication.ResyncVolumeRequest{
ReplicationSource: &replication.ReplicationSource{
Type: &replication.ReplicationSource_Volume{
Volume: &replication.ReplicationSource_VolumeSource{
VolumeId: fakeID,
},
},
},
},
&replication.GetVolumeReplicationInfoRequest{
ReplicationSource: &replication.ReplicationSource{
Type: &replication.ReplicationSource_Volume{
Volume: &replication.ReplicationSource_VolumeSource{
VolumeId: fakeID,
},
},
},
},
}
for _, r := range req {
if got := getReqID(r); got != fakeID {
Expand Down

0 comments on commit 385930d

Please sign in to comment.