Skip to content

Commit eff0902

Browse files
authored
Respond with transformed samples in prediction requests, add json print option (#153)
1 parent 8ee32b8 commit eff0902

File tree

4 files changed

+28
-2
lines changed

4 files changed

+28
-2
lines changed

cli/cmd/predict.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,26 +33,31 @@ import (
3333
"github.com/cortexlabs/cortex/pkg/operator/api/resource"
3434
)
3535

36+
var predictPrintJSON bool
37+
3638
func init() {
39+
predictCmd.PersistentFlags().BoolVarP(&predictPrintJSON, "json", "j", false, "print the raw json response")
3740
addAppNameFlag(predictCmd)
3841
addEnvFlag(predictCmd)
3942
}
4043

4144
type PredictResponse struct {
4245
ResourceID string `json:"resource_id"`
43-
ClassificationPredictions []ClassificationPrediction `json:"classification_predictions"`
44-
RegressionPredictions []RegressionPrediction `json:"regression_predictions"`
46+
ClassificationPredictions []ClassificationPrediction `json:"classification_predictions,omitempty"`
47+
RegressionPredictions []RegressionPrediction `json:"regression_predictions,omitempty"`
4548
}
4649

4750
type ClassificationPrediction struct {
4851
PredictedClass int `json:"predicted_class"`
4952
PredictedClassReversed interface{} `json:"predicted_class_reversed"`
5053
Probabilities []float64 `json:"probabilities"`
54+
TransformedSample interface{} `json:"transformed_sample"`
5155
}
5256

5357
type RegressionPrediction struct {
5458
PredictedValue float64 `json:"predicted_value"`
5559
PredictedValueReversed interface{} `json:"predicted_value_reversed"`
60+
TransformedSample interface{} `json:"transformed_sample"`
5661
}
5762

5863
var predictCmd = &cobra.Command{
@@ -87,6 +92,16 @@ var predictCmd = &cobra.Command{
8792
errors.Exit(err)
8893
}
8994

95+
if predictPrintJSON {
96+
prettyResp, err := json.Pretty(predictResponse)
97+
if err != nil {
98+
errors.Exit(err)
99+
}
100+
101+
fmt.Println(prettyResp)
102+
return
103+
}
104+
90105
apiID := predictResponse.ResourceID
91106
api := resourcesRes.APIStatuses[apiID]
92107

docs/operator/cli.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ Flags:
5858
-a, --app string app name
5959
-e, --env string environment (default "dev")
6060
-h, --help help for predict
61+
-j, --json print the raw json response
6162
```
6263

6364
The `predict` command converts samples from a JSON file into prediction requests and outputs the response. This command is useful for quickly testing model output.

pkg/lib/json/json.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,12 @@ func WriteJSON(obj interface{}, outPath string) error {
7373
}
7474
return nil
7575
}
76+
77+
func Pretty(obj interface{}) (string, error) {
78+
b, err := json.MarshalIndent(obj, "", " ")
79+
if err != nil {
80+
return "", err
81+
}
82+
83+
return string(b), nil
84+
}

pkg/workloads/tf_api/api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ def run_predict(sample):
193193
prediction_request = create_prediction_request(transformed_sample)
194194
response_proto = local_cache["stub"].Predict(prediction_request, timeout=10.0)
195195
result = parse_response_proto(response_proto)
196+
result["transformed_sample"] = transformed_sample
196197
util.log_indent("Raw sample:", indent=4)
197198
util.log_pretty(sample, indent=6)
198199
util.log_indent("Transformed sample:", indent=4)

0 commit comments

Comments
 (0)