Skip to content

Commit

Permalink
fix race condition in oas
Browse files Browse the repository at this point in the history
  • Loading branch information
Keithwachira committed Apr 23, 2024
1 parent 3e76542 commit f1acfa8
Showing 1 changed file with 18 additions and 29 deletions.
47 changes: 18 additions & 29 deletions gateway/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ const (
oAuthClientTokensKeyPattern = "oauth-data.*oauth-client-tokens.*"
)

var (
ErrRequestMalformed = errors.New("request malformed")
)
var ErrRequestMalformed = errors.New("request malformed")

// apiModifyKeySuccess represents when a Key modification was successful
//
Expand Down Expand Up @@ -143,14 +141,12 @@ func doJSONWrite(w http.ResponseWriter, code int, obj interface{}) {
}

func doJSONExport(w http.ResponseWriter, code int, obj interface{}, fileName string) {

if code != http.StatusOK {
doJSONWrite(w, code, obj)
return
}

stream, err := json.MarshalIndent(obj, "", " ")

if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand All @@ -165,7 +161,6 @@ func doJSONExport(w http.ResponseWriter, code int, obj interface{}, fileName str
job := instrument.NewJob("SystemAPIError")
job.Event(err.Error())
}

}

type MethodNotAllowedHandler struct{}
Expand Down Expand Up @@ -508,7 +503,7 @@ func (gw *Gateway) handleAddOrUpdate(keyName string, r *http.Request, isHashed b
keyName = gw.generateToken(newSession.OrgID, keyName)
}

//set the original expiry if the content in payload is a past time
// set the original expiry if the content in payload is a past time
if time.Now().After(time.Unix(newSession.Expires, 0)) && newSession.Expires > 1 {
newSession.Expires = originalKey.Expires
}
Expand Down Expand Up @@ -839,12 +834,10 @@ func (gw *Gateway) handleDeleteHashedKeyWithLogs(keyName, orgID, apiID string, r
}

func (gw *Gateway) handleDeleteHashedKey(keyName, orgID, apiID string, resetQuota bool) (interface{}, int) {

session, ok := gw.GlobalSessionManager.SessionDetail(orgID, keyName, true)
keyName = session.KeyID
if !ok {
return apiError("There is no such key found"), http.StatusNotFound

}

if apiID == "-1" {
Expand Down Expand Up @@ -939,7 +932,7 @@ func (gw *Gateway) handleAddOrUpdatePolicy(polID string, r *http.Request) (inter
return apiError("Marshalling failed"), http.StatusInternalServerError
}

if err := ioutil.WriteFile(polFilePath, asByte, 0644); err != nil {
if err := ioutil.WriteFile(polFilePath, asByte, 0o644); err != nil {

Check failure

Code scanning / CodeQL

Uncontrolled data used in path expression High

This path depends on a
user-provided value
.

Check failure

Code scanning / CodeQL

Uncontrolled data used in path expression High

This path depends on a
user-provided value
.
log.Error("Failed to create file! - ", err)
return apiError("Failed to create file!"), http.StatusInternalServerError
}
Expand Down Expand Up @@ -1039,10 +1032,11 @@ func (gw *Gateway) handleGetAPIOAS(apiID string, modePublic bool) (interface{},

obj, code := gw.handleGetAPI(apiID, true)
if apiOAS, ok := obj.(*oas.OAS); ok && modePublic {
apiOAS.RemoveTykExtension()
oasCopy := *apiOAS
oasCopy.RemoveTykExtension()
return oasCopy, code
}
return obj, code

}

func (gw *Gateway) handleAddApi(r *http.Request, fs afero.Fs, oasEndpoint bool) (interface{}, int) {
Expand Down Expand Up @@ -1266,7 +1260,7 @@ func (gw *Gateway) writeToFile(fs afero.Fs, newDef interface{}, filename string)
return errors.New("marshalling failed"), http.StatusInternalServerError
}

if err := ioutil.WriteFile(defFilePath, asByte, 0644); err != nil {
if err := ioutil.WriteFile(defFilePath, asByte, 0o644); err != nil {

Check failure

Code scanning / CodeQL

Uncontrolled data used in path expression High

This path depends on a
user-provided value
.
This path depends on a
user-provided value
.

Check failure

Code scanning / CodeQL

Uncontrolled data used in path expression High

This path depends on a
user-provided value
.
This path depends on a
user-provided value
.
This path depends on a
user-provided value
.
This path depends on a
user-provided value
.
This path depends on a
user-provided value
.
log.Infof("EL file path: %v", defFilePath)
log.Error("Failed to create file! - ", err)
return errors.New("file object creation failed, write error"), http.StatusInternalServerError
Expand Down Expand Up @@ -1503,7 +1497,6 @@ func (gw *Gateway) apiOASPatchHandler(w http.ResponseWriter, r *http.Request) {
}

reqBodyInBytes, oasObj, err := extractOASObjFromReq(r.Body)

if err != nil {
doJSONWrite(w, http.StatusBadRequest, apiError(err.Error()))
return
Expand Down Expand Up @@ -2049,7 +2042,6 @@ func (gw *Gateway) createKeyHandler(w http.ResponseWriter, r *http.Request) {
doJSONWrite(w, http.StatusBadRequest, apiError("Failed to create key, keys must have at least one Access Rights record set."))
return
}

}

obj := apiModifyKeySuccess{
Expand Down Expand Up @@ -2484,7 +2476,6 @@ func (gw *Gateway) invalidateOauthRefresh(w http.ResponseWriter, r *http.Request
}

func (gw *Gateway) rotateOauthClientHandler(w http.ResponseWriter, r *http.Request) {

apiID := mux.Vars(r)["apiID"]
keyName := mux.Vars(r)["keyName"]

Expand All @@ -2498,7 +2489,7 @@ func (gw *Gateway) getApisForOauthApp(w http.ResponseWriter, r *http.Request) {
appID := mux.Vars(r)["appID"]
orgID := r.FormValue("orgID")

//get all organization apis
// get all organization apis
apisIds := gw.getApisIdsForOrg(orgID)

for index := range apisIds {
Expand Down Expand Up @@ -2717,12 +2708,14 @@ func (gw *Gateway) handleDeleteOAuthClient(keyName, apiID string) (interface{},
return apiError("OAuth Client ID not found"), http.StatusNotFound
}

const oAuthNotPropagatedErr = "OAuth client list isn't available or hasn't been propagated yet."
const oAuthClientNotFound = "OAuth client not found"
const oauthClientIdEmpty = "client_id is required"
const oauthClientSecretEmpty = "client_secret is required"
const oauthClientSecretWrong = "client secret is wrong"
const oauthTokenEmpty = "token is required"
const (
oAuthNotPropagatedErr = "OAuth client list isn't available or hasn't been propagated yet."
oAuthClientNotFound = "OAuth client not found"
oauthClientIdEmpty = "client_id is required"
oauthClientSecretEmpty = "client_secret is required"
oauthClientSecretWrong = "client secret is wrong"
oauthTokenEmpty = "token is required"
)

func (gw *Gateway) getApiClients(apiID string) ([]ExtendedOsinClientInterface, apiStatusMessage, int) {
var err error
Expand Down Expand Up @@ -2759,7 +2752,6 @@ func (gw *Gateway) getApiClients(apiID string) ([]ExtendedOsinClientInterface, a

// List Clients
func (gw *Gateway) getOauthClients(apiID string) (interface{}, int) {

clientData, _, apiStatusCode := gw.getApiClients(apiID)

if apiStatusCode != 200 {
Expand Down Expand Up @@ -2872,7 +2864,6 @@ func (gw *Gateway) invalidateCacheHandler(w http.ResponseWriter, r *http.Request

func (gw *Gateway) RevokeTokenHandler(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()

if err != nil {
doJSONWrite(w, http.StatusBadRequest, apiError("cannot parse form. Form malformed"))
return
Expand Down Expand Up @@ -2937,7 +2928,6 @@ func (gw *Gateway) GetStorageForApi(apiID string) (ExtendedOsinStorageInterface,

func (gw *Gateway) RevokeAllTokensHandler(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()

if err != nil {
doJSONWrite(w, http.StatusBadRequest, apiError("cannot parse form. Form malformed"))
return
Expand All @@ -2959,7 +2949,7 @@ func (gw *Gateway) RevokeAllTokensHandler(w http.ResponseWriter, r *http.Request

apis := gw.getApisForOauthClientId(clientId, orgId)
if len(apis) == 0 {
//if api is 0 is because the client wasn't found
// if api is 0 is because the client wasn't found
doJSONWrite(w, http.StatusNotFound, apiError("oauth client doesn't exist"))
return
}
Expand All @@ -2986,7 +2976,6 @@ func (gw *Gateway) RevokeAllTokensHandler(w http.ResponseWriter, r *http.Request
func (gw *Gateway) validateOAS(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
reqBodyInBytes, oasObj, err := extractOASObjFromReq(r.Body)

if err != nil {
doJSONWrite(w, http.StatusBadRequest, apiError(err.Error()))
return
Expand Down Expand Up @@ -3070,6 +3059,7 @@ func setContext(r *http.Request, ctx context.Context) {
r2 := r.WithContext(ctx)
*r = *r2
}

func setCtxValue(r *http.Request, key, val interface{}) {
setContext(r, context.WithValue(r.Context(), key, val))
}
Expand Down Expand Up @@ -3420,7 +3410,6 @@ var createOauthClientSecret = func() string {

// invalidate tokens if we had a new policy
func invalidateTokens(prevClient ExtendedOsinClientInterface, updatedClient OAuthClient, oauthManager *OAuthManager) {

if prevPolicy := prevClient.GetPolicyID(); prevPolicy != "" && prevPolicy != updatedClient.PolicyID {
tokenList, err := oauthManager.OsinServer.Storage.GetClientTokens(updatedClient.ClientID)
if err != nil {
Expand Down

0 comments on commit f1acfa8

Please sign in to comment.