Skip to content

Commit

Permalink
Fix FileExist utility and generate sha for modelspec (kubeflow#1318)
Browse files Browse the repository at this point in the history
* Create sha for modelspec

* Encode/Decode model spec

* Add model dir sync test
  • Loading branch information
yuzisun authored Feb 1, 2021
1 parent ec9e8ca commit a0a52f8
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 70 deletions.
30 changes: 21 additions & 9 deletions pkg/agent/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ package agent

import (
"encoding/hex"
"encoding/json"
"fmt"
"github.com/kubeflow/kfserving/pkg/agent/storage"
"github.com/kubeflow/kfserving/pkg/apis/serving/v1alpha1"
"github.com/pkg/errors"
"go.uber.org/zap"
"io/ioutil"
"os"
"path/filepath"
"regexp"
"strings"
Expand All @@ -38,25 +41,34 @@ var SupportedProtocols = []storage.Protocol{storage.S3, storage.GCS}

func (d *Downloader) DownloadModel(modelName string, modelSpec *v1alpha1.ModelSpec) error {
if modelSpec != nil {
modelUri := modelSpec.StorageURI
hashModelUri := hash(modelUri)
hashFramework := hash(modelSpec.Framework)
hashMemory := hash(modelSpec.Memory.String())
sha256 := storage.AsSha256(modelSpec)
successFile := filepath.Join(d.ModelDir, modelName,
fmt.Sprintf("SUCCESS.%s.%s.%s", hashModelUri, hashFramework, hashMemory))
d.Logger.Infof("Downloading %s to model dir %s", modelUri, d.ModelDir)
fmt.Sprintf("SUCCESS.%s", sha256))
d.Logger.Infof("Downloading %s to model dir %s", modelSpec.StorageURI, d.ModelDir)
// Download if the event there is a success file and the event is one which we wish to Download
if !storage.FileExists(successFile) {
if err := d.download(modelName, modelUri); err != nil {
_, err := os.Stat(successFile)
if os.IsNotExist(err) {
if err := d.download(modelName, modelSpec.StorageURI); err != nil {
return errors.Wrapf(err, "failed to download model")
}
file, createErr := storage.Create(successFile)
defer file.Close()
if createErr != nil {
return errors.Wrapf(createErr, "failed to create success file")
}
} else {
encodedJson, err := json.Marshal(modelSpec)
if err != nil {
return errors.Wrapf(createErr, "failed to encode model spec")
}
err = ioutil.WriteFile(successFile, encodedJson, 0644)
if err != nil {
return errors.Wrapf(createErr, "failed to write the success file")
}
d.Logger.Infof("Creating successFile %s", successFile)
} else if err == nil {
d.Logger.Infof("Model successFile exists already for %s", modelName)
} else {
d.Logger.Errorf("Model successFile error %v", err)
}
}
return nil
Expand Down
17 changes: 16 additions & 1 deletion pkg/agent/storage/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package storage
import (
gstorage "cloud.google.com/go/storage"
"context"
"crypto/sha256"
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
Expand All @@ -35,7 +36,21 @@ import (

func FileExists(filename string) bool {
info, err := os.Stat(filename)
return !os.IsNotExist(err) && !info.IsDir()
if os.IsNotExist(err) {
return false
}
if err == nil && !info.IsDir() {
return true
} else {
return false
}
}

func AsSha256(o interface{}) string {
h := sha256.New()
h.Write([]byte(fmt.Sprintf("%v", o)))

return fmt.Sprintf("%x", h.Sum(nil))
}

func Create(fileName string) (*os.File, error) {
Expand Down
90 changes: 33 additions & 57 deletions pkg/agent/syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ limitations under the License.
package agent

import (
"encoding/hex"
"encoding/json"
"fmt"
"github.com/kubeflow/kfserving/pkg/apis/serving/v1alpha1"
"github.com/pkg/errors"
"k8s.io/apimachinery/pkg/api/resource"
"go.uber.org/zap"
"io/ioutil"
"os"
"path/filepath"
"strings"
Expand All @@ -31,34 +32,39 @@ type FileError error

var NoSuccessFile FileError = fmt.Errorf("no success file can be found")

func SyncModelDir(modelDir string) (map[string]modelWrapper, error) {
func SyncModelDir(modelDir string, logger *zap.SugaredLogger) (map[string]modelWrapper, error) {
logger.Infof("Syncing from model dir %s", modelDir)
modelTracker := make(map[string]modelWrapper)
err := filepath.Walk(modelDir, func(path string, info os.FileInfo, err error) error {
if info.IsDir() {
modelName := info.Name()
ierr := filepath.Walk(path, func(path string, f os.FileInfo, _ error) error {
if !f.IsDir() {
base := filepath.Base(path)
baseSplit := strings.SplitN(base, ".", 4)
if baseSplit[0] == "SUCCESS" {
if spec, e := successParse(baseSplit); e != nil {
return errors.Wrapf(err, "error parsing SUCCESS file")
} else {
modelTracker[modelName] = modelWrapper{
Spec: spec,
stale: true,
}
return nil
}
}
if !info.IsDir() {
fileName := info.Name()
if strings.HasPrefix(fileName, "SUCCESS.") {
logger.Infof("Syncing from model success file %v", fileName)
dir := filepath.Dir(path)
dirSplit := strings.Split(dir, "/")
if len(dirSplit) < 2 {
return errors.Wrapf(err, "invalid model path")
}
return NoSuccessFile
})
switch ierr {
case NoSuccessFile:
return nil
default:
return errors.Wrapf(ierr, "failed to parse success file")
modelName := dirSplit[len(dirSplit)-1]

jsonFile, err := os.Open(path)
if err != nil {
return errors.Wrapf(err, "failed to parse success file")
}
byteValue, err := ioutil.ReadAll(jsonFile)
if err != nil {
return errors.Wrapf(err, "failed to read from model spec")
}
modelSpec := &v1alpha1.ModelSpec{}
err = json.Unmarshal(byteValue, &modelSpec)
if err != nil {
return errors.Wrapf(err, "failed to unmarshal model spec")
}
modelTracker[dirSplit[len(dirSplit)-1]] = modelWrapper{
Spec: modelSpec,
stale: true,
}
logger.Infof("recovered model %s with spec %+v", modelName, modelSpec)
}
}
return nil
Expand All @@ -68,33 +74,3 @@ func SyncModelDir(modelDir string) (map[string]modelWrapper, error) {
}
return modelTracker, nil
}

func successParse(baseSplit []string) (*v1alpha1.ModelSpec, error) {
storageURI, err := unhash(baseSplit[1])
errorMessage := "unable to unhash the SUCCESS file, maybe the SUCCESS file has been modified?"
if err != nil {
return nil, errors.Wrapf(err, errorMessage)
}
framework, err := unhash(baseSplit[2])
if err != nil {
return nil, errors.Wrapf(err, errorMessage)
}
memory, err := unhash(baseSplit[3])
if err != nil {
return nil, errors.Wrapf(err, errorMessage)
}
memoryResource := resource.MustParse(memory)
return &v1alpha1.ModelSpec{
StorageURI: storageURI,
Framework: framework,
Memory: memoryResource,
}, nil
}

func unhash(s string) (string, error) {
decoded, err := hex.DecodeString(s)
if err != nil {
return "", nil
}
return string(decoded), nil
}
6 changes: 3 additions & 3 deletions pkg/agent/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ type Watcher struct {
}

func NewWatcher(configDir string, modelDir string, logger *zap.SugaredLogger) Watcher {
modelTracker, err := SyncModelDir(modelDir)
modelTracker, err := SyncModelDir(modelDir, logger)
if err != nil {
logger.Error(err, "Failed to sync model dir")
logger.Errorf("Failed to sync model dir %v", err)
}
watcher := Watcher{
configDir: configDir,
Expand All @@ -50,7 +50,7 @@ func NewWatcher(configDir string, modelDir string, logger *zap.SugaredLogger) Wa
modelConfigFile := fmt.Sprintf("%s/%s", configDir, constants.ModelConfigFileName)
err = watcher.syncModelConfig(modelConfigFile)
if err != nil {
logger.Error(err, "Failed to sync model config file")
logger.Errorf("Failed to sync model config file %v", err)
}
return watcher
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/agent/watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
. "github.com/onsi/gomega"
"go.uber.org/zap"
"io/ioutil"
"k8s.io/apimachinery/pkg/api/resource"
logger "log"
"os"
"path/filepath"
Expand Down Expand Up @@ -65,13 +66,15 @@ var _ = Describe("Watcher", func() {
Spec: v1alpha1.ModelSpec{
StorageURI: "s3://models/model1",
Framework: "sklearn",
Memory: resource.MustParse("100Mi"),
},
},
{
Name: "model2",
Spec: v1alpha1.ModelSpec{
StorageURI: "s3://models/model2",
Framework: "sklearn",
Memory: resource.MustParse("100Mi"),
},
},
}
Expand All @@ -96,6 +99,8 @@ var _ = Describe("Watcher", func() {
Eventually(func() int { return len(puller.channelMap) }).Should(Equal(0))
Eventually(func() int { return puller.opStats["model1"][Add] }).Should(Equal(1))
Eventually(func() int { return puller.opStats["model2"][Add] }).Should(Equal(1))
modelSpecMap, _ := SyncModelDir(modelDir+"/test1", watcher.logger)
Expect(watcher.modelTracker).Should(Equal(modelSpecMap))
})
})
})
Expand Down

0 comments on commit a0a52f8

Please sign in to comment.