Skip to content

Commit c3ef14b

Browse files
Add check and set function
1 parent e2ecd2d commit c3ef14b

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

pkg/client/counter/counter.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package counter
1616

1717
import (
1818
"context"
19+
1920
api "github.com/atomix/api/proto/atomix/counter"
2021
"github.com/atomix/api/proto/atomix/headers"
2122
"github.com/atomix/go-client/pkg/client/primitive"
@@ -47,6 +48,9 @@ type Counter interface {
4748

4849
// Decrement decrements the counter by the given delta
4950
Decrement(ctx context.Context, delta int64) (int64, error)
51+
52+
// CAS check the counter value and then updates its current value
53+
CAS(ctx context.Context, expect int64, update int64) (bool, error)
5054
}
5155

5256
// New creates a new counter for the given partitions
@@ -149,6 +153,26 @@ func (c *counter) Decrement(ctx context.Context, delta int64) (int64, error) {
149153
return response.(*api.DecrementResponse).NextValue, nil
150154
}
151155

156+
func (c *counter) CAS(ctx context.Context, expect int64, update int64) (bool, error) {
157+
response, err := c.instance.DoCommand(ctx, func(ctx context.Context, conn *grpc.ClientConn, header *headers.RequestHeader) (*headers.ResponseHeader, interface{}, error) {
158+
client := api.NewCounterServiceClient(conn)
159+
request := &api.CheckAndSetRequest{
160+
Header: header,
161+
Expect: expect,
162+
Update: update,
163+
}
164+
response, err := client.CheckAndSet(ctx, request)
165+
if err != nil {
166+
return nil, nil, err
167+
}
168+
return response.Header, response, nil
169+
})
170+
if err != nil {
171+
return false, err
172+
}
173+
return response.(*api.CheckAndSetResponse).Succeeded, nil
174+
}
175+
152176
func (c *counter) Close(ctx context.Context) error {
153177
return c.instance.Close(ctx)
154178
}

pkg/client/counter/counter_test.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@ package counter
1616

1717
import (
1818
"context"
19+
"testing"
20+
1921
"github.com/atomix/go-client/pkg/client/primitive"
2022
"github.com/atomix/go-client/pkg/client/test"
2123
"github.com/stretchr/testify/assert"
22-
"testing"
2324
)
2425

2526
func TestCounterOperations(t *testing.T) {
@@ -69,6 +70,18 @@ func TestCounterOperations(t *testing.T) {
6970
assert.NoError(t, err)
7071
assert.Equal(t, int64(10), value)
7172

73+
casValue, err := counter.CAS(context.TODO(), 15, 25)
74+
assert.NoError(t, err)
75+
assert.Equal(t, false, casValue)
76+
77+
casValue, err = counter.CAS(context.TODO(), 10, 20)
78+
assert.NoError(t, err)
79+
assert.Equal(t, true, casValue)
80+
81+
value, err = counter.Get(context.TODO())
82+
assert.NoError(t, err)
83+
assert.Equal(t, int64(20), value)
84+
7285
err = counter.Close(context.Background())
7386
assert.NoError(t, err)
7487

@@ -80,7 +93,7 @@ func TestCounterOperations(t *testing.T) {
8093

8194
value, err = counter1.Get(context.TODO())
8295
assert.NoError(t, err)
83-
assert.Equal(t, int64(10), value)
96+
assert.Equal(t, int64(20), value)
8497

8598
err = counter1.Close(context.Background())
8699
assert.NoError(t, err)

0 commit comments

Comments
 (0)