Skip to content

Commit 31bcc57

Browse files
feat: support importing assertions as part of a store (#446)
* refactor: use fgaclient from input param in CreateStoreWithModel * feat: import assertions when importing store * refactor and test: import assertions when importing store * fix linting errors * test: move clientConfig to be defined per test to avoid false positive race detection --------- Co-authored-by: Ewan Harris <ewan.harris@okta.com>
1 parent c160fcf commit 31bcc57

File tree

3 files changed

+344
-16
lines changed

3 files changed

+344
-16
lines changed

cmd/store/create.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ import (
2727
"github.com/openfga/cli/cmd/model"
2828
"github.com/openfga/cli/internal/authorizationmodel"
2929
"github.com/openfga/cli/internal/cmdutils"
30-
"github.com/openfga/cli/internal/fga"
3130
"github.com/openfga/cli/internal/output"
3231
)
3332

@@ -48,16 +47,11 @@ func create(fgaClient client.SdkClient, storeName string) (*client.ClientCreateS
4847
}
4948

5049
func CreateStoreWithModel(
51-
clientConfig fga.ClientConfig,
50+
fgaClient client.SdkClient,
5251
storeName string,
5352
inputModel string,
5453
inputFormat authorizationmodel.ModelFormat,
5554
) (*CreateStoreAndModelResponse, error) {
56-
fgaClient, err := clientConfig.GetFgaClient()
57-
if err != nil {
58-
return nil, fmt.Errorf("failed to initialize FGA Client due to %w", err)
59-
}
60-
6155
response := CreateStoreAndModelResponse{}
6256

6357
if storeName == "" {
@@ -73,7 +67,7 @@ func CreateStoreWithModel(
7367

7468
err = fgaClient.SetStoreId(response.Store.Id)
7569
if err != nil {
76-
return nil, err //nolint:wrapcheck
70+
return nil, fmt.Errorf("failed to set store ID: %w", err)
7771
}
7872

7973
if inputModel != "" {
@@ -109,6 +103,10 @@ export FGA_STORE_ID=$(fga store create --model Model.fga | jq -r .store.id)
109103
RunE: func(cmd *cobra.Command, args []string) error {
110104
clientConfig := cmdutils.GetClientConfig(cmd)
111105
storeName, _ := cmd.Flags().GetString("name")
106+
fgaClient, err := clientConfig.GetFgaClient()
107+
if err != nil {
108+
return fmt.Errorf("failed to initialize FGA Client: %w", err)
109+
}
112110

113111
var inputModel string
114112
if err := authorizationmodel.ReadFromInputFileOrArg(
@@ -122,7 +120,7 @@ export FGA_STORE_ID=$(fga store create --model Model.fga | jq -r .store.id)
122120
return err //nolint:wrapcheck
123121
}
124122

125-
response, err := CreateStoreWithModel(clientConfig, storeName, inputModel, createModelInputFormat)
123+
response, err := CreateStoreWithModel(fgaClient, storeName, inputModel, createModelInputFormat)
126124
if err != nil {
127125
return err
128126
}

cmd/store/import.go

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ const (
4949
// createStore creates a new store with the given client configuration and store data.
5050
func createStore(
5151
clientConfig *fga.ClientConfig,
52+
fgaClient client.SdkClient,
5253
storeData *storetest.StoreData,
5354
format authorizationmodel.ModelFormat,
5455
fileName string,
@@ -58,7 +59,7 @@ func createStore(
5859
storeDataName = strings.TrimSuffix(path.Base(fileName), ".fga.yaml")
5960
}
6061

61-
createStoreAndModelResponse, err := CreateStoreWithModel(*clientConfig, storeDataName, storeData.Model, format)
62+
createStoreAndModelResponse, err := CreateStoreWithModel(fgaClient, storeDataName, storeData.Model, format)
6263
if err != nil {
6364
return nil, err
6465
}
@@ -122,13 +123,18 @@ func importStore(
122123
return nil, err
123124
}
124125

125-
if len(storeData.Tuples) == 0 {
126-
return response, nil
126+
if len(storeData.Tuples) != 0 {
127+
err = importTuples(fgaClient, storeData.Tuples, maxTuplesPerWrite, maxParallelRequests)
128+
if err != nil {
129+
return nil, err
130+
}
127131
}
128132

129-
err = importTuples(fgaClient, storeData.Tuples, maxTuplesPerWrite, maxParallelRequests)
130-
if err != nil {
131-
return nil, err
133+
if len(storeData.Tests) != 0 && response.Model != nil {
134+
err = importAssertions(fgaClient, storeData.Tests, response.Store.Id, response.Model.AuthorizationModelId)
135+
if err != nil {
136+
return nil, err
137+
}
132138
}
133139

134140
return response, nil
@@ -143,7 +149,7 @@ func createOrUpdateStore(
143149
fileName string,
144150
) (*CreateStoreAndModelResponse, error) {
145151
if storeID == "" {
146-
return createStore(clientConfig, storeData, format, fileName)
152+
return createStore(clientConfig, fgaClient, storeData, format, fileName)
147153
}
148154

149155
return updateStore(clientConfig, fgaClient, storeData, format, storeID)
@@ -183,6 +189,53 @@ func importTuples(
183189
return nil
184190
}
185191

192+
func importAssertions(
193+
fgaClient client.SdkClient,
194+
modelTests []storetest.ModelTest,
195+
storeID string,
196+
modelID string,
197+
) error {
198+
var assertions []client.ClientAssertion
199+
200+
for _, modelTest := range modelTests {
201+
if len(modelTest.Check) > 0 {
202+
checkAssertions := getCheckAssertions(modelTest.Check)
203+
assertions = append(assertions, checkAssertions...)
204+
}
205+
}
206+
207+
if len(assertions) > 0 {
208+
writeOptions := client.ClientWriteAssertionsOptions{
209+
AuthorizationModelId: &modelID,
210+
StoreId: &storeID,
211+
}
212+
213+
_, err := fgaClient.WriteAssertions(context.Background()).Body(assertions).Options(writeOptions).Execute()
214+
if err != nil {
215+
return fmt.Errorf("failed to import assertions: %w", err)
216+
}
217+
}
218+
219+
return nil
220+
}
221+
222+
func getCheckAssertions(checkTests []storetest.ModelTestCheck) []client.ClientAssertion {
223+
var assertions []client.ClientAssertion
224+
225+
for _, checkTest := range checkTests {
226+
for relation, expectation := range checkTest.Assertions {
227+
assertions = append(assertions, client.ClientAssertion{
228+
User: checkTest.User,
229+
Relation: relation,
230+
Object: checkTest.Object,
231+
Expectation: expectation,
232+
})
233+
}
234+
}
235+
236+
return assertions
237+
}
238+
186239
func createProgressBar(total int) *progressbar.ProgressBar {
187240
return progressbar.NewOptions(total,
188241
progressbar.OptionSetWriter(os.Stderr),

0 commit comments

Comments
 (0)