Skip to content

Commit 1cd5c22

Browse files
authored
Add support for multi model caching & live reloading (#1428)
1 parent 8a563f6 commit 1cd5c22

File tree

116 files changed

+10267
-1920
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

116 files changed

+10267
-1920
lines changed

cli/cmd/const.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
Copyright 2020 Cortex Labs, Inc.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package cmd
18+
19+
// _timeFormat is the time.Format reference layout used when the CLI prints
// timestamps: day, abbreviated month, two-digit year, 24-hour clock and the
// time zone abbreviation.
const _timeFormat = "02 Jan 06 15:04:05 MST"

cli/cmd/lib_batch_apis.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ const (
3838
_titleBatchAPI = "batch api"
3939
_titleJobCount = "running jobs"
4040
_titleLatestJobID = "latest job id"
41-
_timeFormat = "02 Jan 2006 15:04:05 MST"
4241
)
4342

4443
func batchAPIsTable(batchAPIs []schema.APIResponse, envNames []string) table.Table {

cli/cmd/lib_realtime_apis.go

Lines changed: 215 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ import (
2222
"io/ioutil"
2323
"net/http"
2424
"sort"
25+
"strconv"
2526
"strings"
2627
"time"
2728

2829
"github.com/cortexlabs/cortex/cli/types/cliconfig"
2930
"github.com/cortexlabs/cortex/pkg/consts"
30-
"github.com/cortexlabs/cortex/pkg/lib/cast"
3131
"github.com/cortexlabs/cortex/pkg/lib/console"
3232
"github.com/cortexlabs/cortex/pkg/lib/errors"
3333
"github.com/cortexlabs/cortex/pkg/lib/json"
@@ -70,8 +70,8 @@ func realtimeAPITable(realtimeAPI schema.APIResponse, env cliconfig.Environment)
7070

7171
out += fmt.Sprintf("\n%s curl %s -X POST -H \"Content-Type: application/json\" -d @sample.json\n", console.Bold("example curl:"), realtimeAPI.Endpoint)
7272

73-
if realtimeAPI.Spec.Predictor.Type == userconfig.TensorFlowPredictorType || realtimeAPI.Spec.Predictor.Type == userconfig.ONNXPredictorType {
74-
out += "\n" + describeModelInput(realtimeAPI.Status, realtimeAPI.Endpoint)
73+
if !(realtimeAPI.Spec.Predictor.Type == userconfig.PythonPredictorType && realtimeAPI.Spec.Predictor.ModelPath == nil && realtimeAPI.Spec.Predictor.Models == nil) {
74+
out += "\n" + describeModelInput(realtimeAPI.Status, realtimeAPI.Spec.Predictor, realtimeAPI.Endpoint)
7575
}
7676

7777
out += titleStr("configuration") + strings.TrimSpace(realtimeAPI.Spec.UserStr(env.Provider))
@@ -232,67 +232,40 @@ func classificationMetricsStr(metrics *metrics.Metrics) string {
232232
return out
233233
}
234234

235-
func describeModelInput(status *status.Status, apiEndpoint string) string {
235+
func describeModelInput(status *status.Status, predictor *userconfig.Predictor, apiEndpoint string) string {
236236
if status.Updated.Ready+status.Stale.Ready == 0 {
237-
return "the model's input schema will be available when the api is live\n"
237+
return "the models' metadata schema will be available when the api is live\n"
238238
}
239239

240-
apiSummary, err := getAPISummary(apiEndpoint)
241-
if err != nil {
242-
return "error retrieving the model's input schema: " + errors.Message(err) + "\n"
243-
}
244-
245-
numRows := 0
246-
for _, inputSignatures := range apiSummary.ModelSignatures {
247-
numRows += len(inputSignatures)
248-
}
249-
250-
usesDefaultModel := false
251-
rows := make([][]interface{}, numRows)
252-
rowNum := 0
253-
for modelName, inputSignatures := range apiSummary.ModelSignatures {
254-
for inputName, inputSignature := range inputSignatures {
255-
shapeStr := make([]string, len(inputSignature.Shape))
256-
for idx, dim := range inputSignature.Shape {
257-
shapeStr[idx] = s.ObjFlatNoQuotes(dim)
258-
}
259-
260-
shapeRowEntry := ""
261-
if len(shapeStr) == 1 && shapeStr[0] == "scalar" {
262-
shapeRowEntry = "scalar"
263-
} else if len(shapeStr) == 1 && shapeStr[0] == "unknown" {
264-
shapeRowEntry = "unknown"
265-
} else {
266-
shapeRowEntry = "(" + strings.Join(shapeStr, ", ") + ")"
267-
}
268-
rows[rowNum] = []interface{}{
269-
modelName,
270-
inputName,
271-
inputSignature.Type,
272-
shapeRowEntry,
273-
}
274-
rowNum++
240+
cachingEnabled := predictor.Models != nil && predictor.Models.CacheSize != nil && predictor.Models.DiskCacheSize != nil
241+
if predictor.Type == userconfig.TensorFlowPredictorType && !cachingEnabled {
242+
apiTFLiveReloadingSummary, err := getAPITFLiveReloadingSummary(apiEndpoint)
243+
if err != nil {
244+
return "error retrieving the models' metadata schema: " + errors.Message(err) + "\n"
275245
}
276-
if modelName == consts.SingleModelName {
277-
usesDefaultModel = true
246+
t, err := parseAPITFLiveReloadingSummary(apiTFLiveReloadingSummary)
247+
if err != nil {
248+
return "error retrieving the model's input schema: " + errors.Message(err) + "\n"
278249
}
250+
return t
279251
}
280252

281-
inputTitle := "input"
282-
if usesDefaultModel {
283-
inputTitle = "model input"
253+
apiModelSummary, err := getAPIModelSummary(apiEndpoint)
254+
if err != nil {
255+
return "error retrieving the models' metadata schema: " + errors.Message(err) + "\n"
284256
}
285-
t := table.Table{
286-
Headers: []table.Header{
287-
{Title: "model name", MaxWidth: 32, Hidden: usesDefaultModel},
288-
{Title: inputTitle, MaxWidth: 32},
289-
{Title: "type", MaxWidth: 10},
290-
{Title: "shape", MaxWidth: 20},
291-
},
292-
Rows: rows,
257+
t, err := parseAPIModelSummary(apiModelSummary)
258+
if err != nil {
259+
return "error retrieving the models' metadata schema: " + errors.Message(err) + "\n"
293260
}
261+
return t
262+
}
294263

295-
return t.MustFormat()
264+
// getModelFromModelID splits a model ID of the form "<name>-<version>" into
// its name and numeric version. The split happens at the last "-", so the
// model name itself may contain dashes.
//
// An error is returned when the ID contains no "-" at all (slicing with the
// -1 returned by strings.LastIndex would otherwise panic) or when the text
// after the last "-" is not a base-10 integer that fits in an int64. On error
// the name and version results are zero values.
func getModelFromModelID(modelID string) (modelName string, modelVersion int64, err error) {
	splitIndex := strings.LastIndex(modelID, "-")
	if splitIndex < 0 {
		return "", 0, fmt.Errorf("invalid model id %q: expected <model name>-<model version>", modelID)
	}

	modelVersion, err = strconv.ParseInt(modelID[splitIndex+1:], 10, 64)
	if err != nil {
		return "", 0, fmt.Errorf("invalid model id %q: %w", modelID, err)
	}

	return modelID[:splitIndex], modelVersion, nil
}
297270

298271
func makeRequest(request *http.Request) (http.Header, []byte, error) {
@@ -324,7 +297,26 @@ func makeRequest(request *http.Request) (http.Header, []byte, error) {
324297
return response.Header, bodyBytes, nil
325298
}
326299

327-
func getAPISummary(apiEndpoint string) (*schema.APISummary, error) {
300+
func getAPIModelSummary(apiEndpoint string) (*schema.APIModelSummary, error) {
301+
req, err := http.NewRequest("GET", apiEndpoint, nil)
302+
if err != nil {
303+
return nil, errors.Wrap(err, "unable to request api summary")
304+
}
305+
req.Header.Set("Content-Type", "application/json")
306+
_, response, err := makeRequest(req)
307+
if err != nil {
308+
return nil, err
309+
}
310+
311+
var apiModelSummary schema.APIModelSummary
312+
err = json.DecodeWithNumber(response, &apiModelSummary)
313+
if err != nil {
314+
return nil, errors.Wrap(err, "unable to parse api summary response")
315+
}
316+
return &apiModelSummary, nil
317+
}
318+
319+
func getAPITFLiveReloadingSummary(apiEndpoint string) (*schema.APITFLiveReloadingSummary, error) {
328320
req, err := http.NewRequest("GET", apiEndpoint, nil)
329321
if err != nil {
330322
return nil, errors.Wrap(err, "unable to request api summary")
@@ -335,17 +327,179 @@ func getAPISummary(apiEndpoint string) (*schema.APISummary, error) {
335327
return nil, err
336328
}
337329

338-
var apiSummary schema.APISummary
339-
err = json.DecodeWithNumber(response, &apiSummary)
330+
var apiTFLiveReloadingSummary schema.APITFLiveReloadingSummary
331+
err = json.DecodeWithNumber(response, &apiTFLiveReloadingSummary)
340332
if err != nil {
341333
return nil, errors.Wrap(err, "unable to parse api summary response")
342334
}
335+
return &apiTFLiveReloadingSummary, nil
336+
}
343337

344-
for _, inputSignatures := range apiSummary.ModelSignatures {
345-
for _, inputSignature := range inputSignatures {
346-
inputSignature.Shape = cast.JSONNumbers(inputSignature.Shape)
338+
func parseAPIModelSummary(summary *schema.APIModelSummary) (string, error) {
339+
rows := make([][]interface{}, 0)
340+
341+
for modelName, modelMetadata := range summary.ModelMetadata {
342+
latestVersion := int64(0)
343+
for _, version := range modelMetadata.Versions {
344+
v, err := strconv.ParseInt(version, 10, 64)
345+
if err != nil {
346+
return "", err
347+
}
348+
if v > latestVersion {
349+
latestVersion = v
350+
}
347351
}
352+
latestStrVersion := strconv.FormatInt(latestVersion, 10)
353+
354+
for idx, version := range modelMetadata.Versions {
355+
var latestTag string
356+
if latestStrVersion == version {
357+
latestTag = " (latest)"
358+
}
359+
360+
timestamp := modelMetadata.Timestamps[idx]
361+
date := time.Unix(timestamp, 0)
362+
363+
rows = append(rows, []interface{}{
364+
modelName,
365+
version + latestTag,
366+
date.Format(_timeFormat),
367+
})
368+
}
369+
}
370+
371+
_, usesCortexDefaultModelName := summary.ModelMetadata[consts.SingleModelName]
372+
373+
t := table.Table{
374+
Headers: []table.Header{
375+
{
376+
Title: "model name",
377+
MaxWidth: 32,
378+
Hidden: usesCortexDefaultModelName,
379+
},
380+
{
381+
Title: "model version",
382+
MaxWidth: 25,
383+
},
384+
{
385+
Title: "edit time",
386+
MaxWidth: 32,
387+
},
388+
},
389+
Rows: rows,
390+
}
391+
392+
return t.MustFormat(), nil
393+
}
394+
395+
func parseAPITFLiveReloadingSummary(summary *schema.APITFLiveReloadingSummary) (string, error) {
396+
latestVersions := make(map[string]int64)
397+
398+
numRows := 0
399+
models := make(map[string]schema.GenericModelMetadata, 0)
400+
for modelID, modelMetadata := range summary.ModelMetadata {
401+
timestamp := modelMetadata.Timestamp
402+
modelName, modelVersion, err := getModelFromModelID(modelID)
403+
if err != nil {
404+
return "", err
405+
}
406+
if _, ok := models[modelName]; !ok {
407+
models[modelName] = schema.GenericModelMetadata{
408+
Versions: []string{strconv.FormatInt(modelVersion, 10)},
409+
Timestamps: []int64{timestamp},
410+
}
411+
} else {
412+
model := models[modelName]
413+
model.Versions = append(model.Versions, strconv.FormatInt(modelVersion, 10))
414+
model.Timestamps = append(model.Timestamps, timestamp)
415+
models[modelName] = model
416+
}
417+
if _, ok := latestVersions[modelName]; !ok {
418+
latestVersions[modelName] = modelVersion
419+
} else if modelVersion > latestVersions[modelName] {
420+
latestVersions[modelName] = modelVersion
421+
}
422+
numRows += len(modelMetadata.InputSignatures)
423+
}
424+
425+
rows := make([][]interface{}, 0, numRows)
426+
for modelName, model := range models {
427+
latestVersion := latestVersions[modelName]
428+
429+
for _, modelVersion := range model.Versions {
430+
modelID := fmt.Sprintf("%s-%s", modelName, modelVersion)
431+
432+
inputSignatures := summary.ModelMetadata[modelID].InputSignatures
433+
timestamp := summary.ModelMetadata[modelID].Timestamp
434+
versionInt, err := strconv.ParseInt(modelVersion, 10, 64)
435+
if err != nil {
436+
return "", err
437+
}
438+
439+
var applicableTags string
440+
if versionInt == latestVersion {
441+
applicableTags = " (latest)"
442+
}
443+
444+
date := time.Unix(timestamp, 0)
445+
446+
for inputName, inputSignature := range inputSignatures {
447+
shapeStr := make([]string, len(inputSignature.Shape))
448+
for idx, dim := range inputSignature.Shape {
449+
shapeStr[idx] = s.ObjFlatNoQuotes(dim)
450+
}
451+
shapeRowEntry := ""
452+
if len(shapeStr) == 1 && shapeStr[0] == "scalar" {
453+
shapeRowEntry = "scalar"
454+
} else if len(shapeStr) == 1 && shapeStr[0] == "unknown" {
455+
shapeRowEntry = "unknown"
456+
} else {
457+
shapeRowEntry = "(" + strings.Join(shapeStr, ", ") + ")"
458+
}
459+
rows = append(rows, []interface{}{
460+
modelName,
461+
modelVersion + applicableTags,
462+
inputName,
463+
inputSignature.Type,
464+
shapeRowEntry,
465+
date.Format(_timeFormat),
466+
})
467+
}
468+
}
469+
}
470+
471+
_, usesCortexDefaultModelName := summary.ModelMetadata[consts.SingleModelName]
472+
473+
t := table.Table{
474+
Headers: []table.Header{
475+
{
476+
Title: "model name",
477+
MaxWidth: 32,
478+
Hidden: usesCortexDefaultModelName,
479+
},
480+
{
481+
Title: "model version",
482+
MaxWidth: 25,
483+
},
484+
{
485+
Title: "model input",
486+
MaxWidth: 32,
487+
},
488+
{
489+
Title: "type",
490+
MaxWidth: 10,
491+
},
492+
{
493+
Title: "shape",
494+
MaxWidth: 20,
495+
},
496+
{
497+
Title: "edit time",
498+
MaxWidth: 32,
499+
},
500+
},
501+
Rows: rows,
348502
}
349503

350-
return &apiSummary, nil
504+
return t.MustFormat(), nil
351505
}

0 commit comments

Comments
 (0)