diff --git a/internal/provider/client.go b/internal/provider/client.go index a770ed3..84cc9e8 100644 --- a/internal/provider/client.go +++ b/internal/provider/client.go @@ -6,8 +6,6 @@ import ( "fmt" "io" "net/http" - - "github.com/hashicorp/terraform-plugin-framework/diag" ) type authedTransport struct { @@ -27,11 +25,21 @@ func (t *authedTransport) RoundTrip(req *http.Request) (*http.Response, error) { return t.wrapped.RoundTrip(req) } -func do[O interface{}](client *http.Client, req *http.Request, output *O) (err error) { +func delete(client *http.Client, url string) ([]byte, error) { + req, err := http.NewRequest(http.MethodDelete, url, nil) + + return do(client, req, err) +} + +func do(client *http.Client, req *http.Request, e error) ([]byte, error) { + if e != nil { + return nil, fmt.Errorf("unable to form request, got error: %s", e) + } + res, err := client.Do(req) if err != nil { - return err + return nil, err } defer res.Body.Close() @@ -39,52 +47,40 @@ func do[O interface{}](client *http.Client, req *http.Request, output *O) (err e responseBody, err := io.ReadAll(res.Body) if err != nil { - return err + return nil, err } if res.StatusCode >= 400 { - return fmt.Errorf(string(responseBody)) + return nil, fmt.Errorf(string(responseBody)) } - return json.Unmarshal(responseBody, output) + return responseBody, nil } -func get[O interface{}](client *http.Client, diagnostics diag.Diagnostics, url string, output *O) (err error) { - req, err := http.NewRequest(http.MethodGet, url, nil) +func doOut[O interface{}](client *http.Client, req *http.Request, e error, output *O) error { + body, err := do(client, req, e) if err != nil { - diagnostics.AddError("Provider Error", fmt.Sprintf("Unable to form request, got error: %s", err)) - return nil + return err } - return do(client, req, output) + return json.Unmarshal(body, output) } -func call[I interface{}, O interface{}](client *http.Client, diagnostics diag.Diagnostics, method string, url string, input I, output *O) (err error) { - requestBody, err := json.Marshal(input) - - if err != nil { - diagnostics.AddError("Provider Error", fmt.Sprintf("Unable to marshal JSON, got error: %s", err)) - return nil - } - - req, err := http.NewRequest(method, url, bytes.NewBuffer(requestBody)) - - if err != nil { - diagnostics.AddError("Provider Error", fmt.Sprintf("Unable to form request, got error: %s", err)) - return nil - } +func get[O interface{}](client *http.Client, url string, output *O) error { + req, err := http.NewRequest(http.MethodGet, url, nil) - return do(client, req, output) + return doOut(client, req, err, output) } -func delete(client *http.Client, diagnostics diag.Diagnostics, url string) (res *http.Response, err error) { - req, err := http.NewRequest(http.MethodDelete, url, nil) +func call[I interface{}, O interface{}](client *http.Client, method string, url string, input I, output *O) error { + requestBody, err := json.Marshal(input) if err != nil { - diagnostics.AddError("Provider Error", fmt.Sprintf("Unable to form request, got error: %s", err)) - return nil, nil + return fmt.Errorf("unable to marshal JSON, got error: %s", err) } - return client.Do(req) + req, err := http.NewRequest(method, url, bytes.NewBuffer(requestBody)) + + return doOut(client, req, err, output) } diff --git a/internal/provider/client_api.go b/internal/provider/client_api.go index 2e20901..004adb9 100644 --- a/internal/provider/client_api.go +++ b/internal/provider/client_api.go @@ -5,51 +5,123 @@ import ( "net/http" "time" - "github.com/hashicorp/terraform-plugin-framework/diag" + "golang.org/x/exp/slices" ) -func branchList(client *http.Client, diagnostics diag.Diagnostics, projectId string) (BranchListOutput, error) { +func projectWait(client *http.Client, projectId string) error { + var operations OperationListOutput + + for { + err := get(client, fmt.Sprintf("/projects/%s/operations?limit=1", projectId), &operations) + + if err != nil { + return err + } + + if operations.Operations[0].Status == "finished" { + return nil + } + + time.Sleep(5 * time.Second) + } +} + +func branchList(client *http.Client, projectId string) (BranchListOutput, error) { var branches BranchListOutput - err := get(client, diagnostics, fmt.Sprintf("/projects/%s/branches", projectId), &branches) + err := get(client, fmt.Sprintf("/projects/%s/branches", projectId), &branches) return branches, err } -func branchUpdate(client *http.Client, diagnostics diag.Diagnostics, projectId string, branchId string, input BranchUpdateInput) (BranchOutput, error) { +func branchEndpoint(client *http.Client, projectId string, branchId string) (Endpoint, error) { + endpoints, err := endpointList(client, projectId) + + var endpoint Endpoint + + if err != nil { + return endpoint, err + } + + endpointIdx := slices.IndexFunc(endpoints.Endpoints, func(endpoint Endpoint) bool { + return endpoint.BranchId == branchId + }) + + return endpoints.Endpoints[endpointIdx], nil +} + +func branchGet(client *http.Client, projectId string, branchId string) (BranchOutput, error) { + var branch BranchOutput + + err := projectWait(client, projectId) + + if err != nil { + return branch, err + } + + err = get(client, fmt.Sprintf("/projects/%s/branches/%s", projectId, branchId), &branch) + + if err != nil { + return branch, err + } + + if branch.Branch.ProjectId != projectId { + return branch, fmt.Errorf("branch %s does not belong to project %s", branchId, projectId) + } + + return branch, nil +} + +func branchUpdate(client *http.Client, projectId string, branchId string, input BranchUpdateInput) (BranchOutput, error) { var branch BranchOutput - err := call(client, diagnostics, http.MethodPatch, fmt.Sprintf("/projects/%s/branches/%s", projectId, branchId), input, &branch) + err := call(client, http.MethodPatch, fmt.Sprintf("/projects/%s/branches/%s", projectId, branchId), input, &branch) return branch, err } -func endpointList(client *http.Client, diagnostics diag.Diagnostics, projectId string) (EndpointListOutput, error) { +func endpointList(client *http.Client, projectId string) (EndpointListOutput, error) { var endpoints EndpointListOutput - err := get(client, diagnostics, fmt.Sprintf("/projects/%s/endpoints", projectId), &endpoints) + err := get(client, fmt.Sprintf("/projects/%s/endpoints", projectId), &endpoints) return endpoints, err } -func endpointUpdate(client *http.Client, diagnostics diag.Diagnostics, projectId string, endpointId string, input EndpointUpdateInput) (EndpointOutput, error) { +func endpointUpdate(client *http.Client, projectId string, endpointId string, input EndpointUpdateInput) (EndpointOutput, error) { var endpoint EndpointOutput - for { - err := get(client, diagnostics, fmt.Sprintf("/projects/%s/endpoints/%s", projectId, endpointId), &endpoint) + err := projectWait(client, projectId) - if err != nil { - return endpoint, err - } + if err != nil { + return endpoint, err + } - if endpoint.Endpoint.CurrentState != "init" { - break - } + err = call(client, http.MethodPatch, fmt.Sprintf("/projects/%s/endpoints/%s", projectId, endpointId), input, &endpoint) - time.Sleep(time.Second) + return endpoint, err +} + +func databaseDelete(client *http.Client, projectId string, branchId string, name string) error { + err := projectWait(client, projectId) + + if err != nil { + return err } - err := call(client, diagnostics, http.MethodPatch, fmt.Sprintf("/projects/%s/endpoints/%s", projectId, endpointId), input, &endpoint) + _, err = delete(client, fmt.Sprintf("/projects/%s/branches/%s/databases/%s", projectId, branchId, name)) - return endpoint, err + return err +} + +func roleDelete(client *http.Client, projectId string, branchId string, name string) error { + err := projectWait(client, projectId) + + if err != nil { + return err + } + + _, err = delete(client, fmt.Sprintf("/projects/%s/branches/%s/roles/%s", projectId, branchId, name)) + + return err } diff --git a/internal/provider/client_schema.go b/internal/provider/client_schema.go index 9d39760..a602130 100644 --- a/internal/provider/client_schema.go +++ b/internal/provider/client_schema.go @@ -41,6 +41,15 @@ type Endpoint struct { CurrentState string `json:"current_state"` } +type Operation struct { + Id string `json:"id"` + Action string `json:"action"` + Status string `json:"status"` + EndpointId string `json:"endpoint_id"` + BranchId string `json:"branch_id"` + ProjectId string `json:"project_id"` +} + type ProjectOutput struct { Project Project `json:"project"` } @@ -111,3 +120,7 @@ type EndpointUpdateInputEndpoint struct { type EndpointUpdateInput struct { Endpoint EndpointUpdateInputEndpoint `json:"endpoint"` } + +type OperationListOutput struct { + Operations []Operation `json:"operations"` +} diff --git a/internal/provider/resource_project.go b/internal/provider/resource_project.go index eef88b0..96b7265 100644 --- a/internal/provider/resource_project.go +++ b/internal/provider/resource_project.go @@ -262,7 +262,7 @@ func (r *ProjectResource) Create(ctx context.Context, req resource.CreateRequest var project ProjectCreateOutput - err := call(r.client, resp.Diagnostics, http.MethodPost, "/projects", input, &project) + err := call(r.client, http.MethodPost, "/projects", input, &project) if err != nil { resp.Diagnostics.AddError("Client Error", fmt.Sprintf("Unable to create project, got error: %s", err)) @@ -272,8 +272,7 @@ func (r *ProjectResource) Create(ctx context.Context, req resource.CreateRequest tflog.Trace(ctx, "created a project") // Delete the default database. - url := fmt.Sprintf("/projects/%s/branches/%s/databases/%s", project.Project.Id, project.Branch.Id, project.Databases[0].Name) - _, err = delete(r.client, resp.Diagnostics, url) + err = databaseDelete(r.client, project.Project.Id, project.Branch.Id, project.Databases[0].Name) if err != nil { resp.Diagnostics.AddError("Client Error", fmt.Sprintf("Unable to delete default database, got error: %s", err)) @@ -281,8 +280,7 @@ func (r *ProjectResource) Create(ctx context.Context, req resource.CreateRequest } // Delete the default role. - url = fmt.Sprintf("/projects/%s/branches/%s/roles/%s", project.Project.Id, project.Branch.Id, project.Roles[0].Name) - _, err = delete(r.client, resp.Diagnostics, url) + err = roleDelete(r.client, project.Project.Id, project.Branch.Id, project.Roles[0].Name) if err != nil { resp.Diagnostics.AddError("Client Error", fmt.Sprintf("Unable to delete default role, got error: %s", err)) @@ -326,7 +324,7 @@ func (r *ProjectResource) Read(ctx context.Context, req resource.ReadRequest, re var project ProjectOutput - err := get(r.client, resp.Diagnostics, fmt.Sprintf("/projects/%s", data.Id.ValueString()), &project) + err := get(r.client, fmt.Sprintf("/projects/%s", data.Id.ValueString()), &project) if err != nil { resp.Diagnostics.AddError("Client Error", fmt.Sprintf("Unable to read project, got error: %s", err)) @@ -334,7 +332,7 @@ func (r *ProjectResource) Read(ctx context.Context, req resource.ReadRequest, re } // Read all branches of the project - branches, err := branchList(r.client, resp.Diagnostics, project.Project.Id) + branches, err := branchList(r.client, project.Project.Id) if err != nil { resp.Diagnostics.AddError("Client Error", fmt.Sprintf("Unable to read branches of the project, got error: %s", err)) @@ -347,20 +345,14 @@ func (r *ProjectResource) Read(ctx context.Context, req resource.ReadRequest, re }) branch := branches.Branches[branchIdx] - // Read all endpoints of the project - endpoints, err := endpointList(r.client, resp.Diagnostics, project.Project.Id) + // Get the endpoint for the primary branch + endpoint, err := branchEndpoint(r.client, project.Project.Id, branch.Id) if err != nil { - resp.Diagnostics.AddError("Client Error", fmt.Sprintf("Unable to read endpoints of the project, got error: %s", err)) + resp.Diagnostics.AddError("Client Error", fmt.Sprintf("Unable to read endpoint of the project, got error: %s", err)) return } - // Get the endpoint for the primary branch - endpointIdx := slices.IndexFunc(endpoints.Endpoints, func(endpoint Endpoint) bool { - return endpoint.BranchId == branch.Id - }) - endpoint := endpoints.Endpoints[endpointIdx] - data.Id = types.StringValue(project.Project.Id) data.Name = types.StringValue(project.Project.Name) data.PlatformId = types.StringValue(project.Project.PlatformId) @@ -406,7 +398,7 @@ func (r *ProjectResource) Update(ctx context.Context, req resource.UpdateRequest var project ProjectOutput - err := call(r.client, resp.Diagnostics, http.MethodPatch, fmt.Sprintf("/projects/%s", data.Id.ValueString()), input, &project) + err := call(r.client, http.MethodPatch, fmt.Sprintf("/projects/%s", data.Id.ValueString()), input, &project) if err != nil { resp.Diagnostics.AddError("Client Error", fmt.Sprintf("Unable to update project, got error: %s", err)) @@ -427,7 +419,7 @@ func (r *ProjectResource) Update(ctx context.Context, req resource.UpdateRequest }, } - branch, err := branchUpdate(r.client, resp.Diagnostics, data.Id.ValueString(), branchData.Id.ValueString(), branchInput) + branch, err := branchUpdate(r.client, data.Id.ValueString(), branchData.Id.ValueString(), branchInput) if err != nil { resp.Diagnostics.AddError("Client Error", fmt.Sprintf("Unable to update branch, got error: %s", err)) @@ -449,7 +441,7 @@ func (r *ProjectResource) Update(ctx context.Context, req resource.UpdateRequest }, } - endpoint, err := endpointUpdate(r.client, resp.Diagnostics, data.Id.ValueString(), branchEndpointData.Id.ValueString(), endpointInput) + endpoint, err := endpointUpdate(r.client, data.Id.ValueString(), branchEndpointData.Id.ValueString(), endpointInput) if err != nil { resp.Diagnostics.AddError("Client Error", fmt.Sprintf("Unable to update endpoint, got error: %s", err)) @@ -493,7 +485,7 @@ func (r *ProjectResource) Delete(ctx context.Context, req resource.DeleteRequest return } - _, err := delete(r.client, resp.Diagnostics, fmt.Sprintf("/projects/%s", data.Id.ValueString())) + _, err := delete(r.client, fmt.Sprintf("/projects/%s", data.Id.ValueString())) if err != nil { resp.Diagnostics.AddError("Client Error", fmt.Sprintf("Unable to delete project, got error: %s", err)) diff --git a/internal/provider/resource_project_test.go b/internal/provider/resource_project_test.go index 225db82..f21dce9 100644 --- a/internal/provider/resource_project_test.go +++ b/internal/provider/resource_project_test.go @@ -34,7 +34,7 @@ func TestAccProjectResourceDefault(t *testing.T) { resource.TestCheckResourceAttr("neon_project.test", "branch.endpoint.max_cu", "0.25"), ), }, - // // ImportState testing + // ImportState testing { ResourceName: "neon_project.test", ImportState: true, @@ -57,7 +57,7 @@ func TestAccProjectResourceDefault(t *testing.T) { resource.TestCheckResourceAttr("neon_project.test", "branch.endpoint.max_cu", "0.25"), ), }, - // // Update and Read testing + // Update and Read testing { Config: testAccProjectResourceConfigDefaultUpdate("nue-todo-app"), Check: resource.ComposeAggregateTestCheckFunc(