Skip to content

Improve classification stats reliability #360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Aug 23, 2019
Merged
15 changes: 15 additions & 0 deletions pkg/lib/aws/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,21 @@ func (c *Client) ReadBytesFromS3(key string) ([]byte, error) {
return buf.Bytes(), nil
}

// ListPrefix returns up to maxResults objects from the client's bucket whose
// keys start with prefix. The error, if any, is wrapped with the prefix for
// context.
func (c *Client) ListPrefix(prefix string, maxResults int64) ([]*s3.Object, error) {
	input := &s3.ListObjectsV2Input{
		Bucket:  aws.String(c.Bucket),
		Prefix:  aws.String(prefix),
		MaxKeys: aws.Int64(maxResults),
	}

	output, err := c.s3Client.ListObjectsV2(input)
	if err != nil {
		return nil, errors.Wrap(err, prefix)
	}

	return output.Contents, nil
}

func (c *Client) DeleteFromS3ByPrefix(prefix string, continueIfFailure bool) error {
listObjectsInput := &s3.ListObjectsV2Input{
Bucket: aws.String(c.Bucket),
Expand Down
1 change: 1 addition & 0 deletions pkg/operator/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ func New(
ctx.PythonPackages = pythonPackages

apis, err := getAPIs(userconf, ctx.DeploymentVersion, files, pythonPackages)

if err != nil {
return nil, err
}
Expand Down
73 changes: 30 additions & 43 deletions pkg/operator/workloads/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ limitations under the License.
package workloads

import (
"encoding/base64"
"fmt"
"path/filepath"
"strings"
"time"

Expand Down Expand Up @@ -62,7 +64,7 @@ func GetMetrics(appName, apiName string) (*schema.APIMetrics, error) {

requestList := []func() error{}
if realTimeStart.Before(realTimeEnd) {
requestList = append(requestList, getAPIMetricsFunc(appName, api, 1, &realTimeStart, &realTimeEnd, &realTimeMetrics))
requestList = append(requestList, getAPIMetricsFunc(ctx, api, 1, &realTimeStart, &realTimeEnd, &realTimeMetrics))
}

if apiStartTime.Before(realTimeStart) {
Expand All @@ -75,7 +77,7 @@ func GetMetrics(appName, apiName string) (*schema.APIMetrics, error) {
} else {
batchStart = twoWeeksAgo
}
requestList = append(requestList, getAPIMetricsFunc(appName, api, 60*60, &batchStart, &batchEnd, &batchMetrics))
requestList = append(requestList, getAPIMetricsFunc(ctx, api, 60*60, &batchStart, &batchEnd, &batchMetrics))
}

if len(requestList) != 0 {
Expand All @@ -89,9 +91,9 @@ func GetMetrics(appName, apiName string) (*schema.APIMetrics, error) {
return &mergedMetrics, nil
}

func getAPIMetricsFunc(appName string, api *context.API, period int64, startTime *time.Time, endTime *time.Time, apiMetrics *schema.APIMetrics) func() error {
func getAPIMetricsFunc(ctx *context.Context, api *context.API, period int64, startTime *time.Time, endTime *time.Time, apiMetrics *schema.APIMetrics) func() error {
return func() error {
metricDataResults, err := queryMetrics(appName, api, period, startTime, endTime)
metricDataResults, err := queryMetrics(ctx, api, period, startTime, endTime)
if err != nil {
return err
}
Expand All @@ -116,20 +118,20 @@ func getAPIMetricsFunc(appName string, api *context.API, period int64, startTime
}
}

func queryMetrics(appName string, api *context.API, period int64, startTime *time.Time, endTime *time.Time) ([]*cloudwatch.MetricDataResult, error) {
networkDataQueries := getNetworkStatsDef(appName, api, period)
func queryMetrics(ctx *context.Context, api *context.API, period int64, startTime *time.Time, endTime *time.Time) ([]*cloudwatch.MetricDataResult, error) {
networkDataQueries := getNetworkStatsDef(ctx.App.Name, api, period)
latencyMetrics := getLatencyMetricsDef(api.Path, period)
allMetrics := append(latencyMetrics, networkDataQueries...)

if api.Tracker != nil {
if api.Tracker.ModelType == userconfig.ClassificationModelType {
classMetrics, err := getClassesMetricDef(appName, api, period)
classMetrics, err := getClassesMetricDef(ctx, api, period)
if err != nil {
return nil, err
}
allMetrics = append(allMetrics, classMetrics...)
} else {
regressionMetrics := getRegressionMetricDef(appName, api, period)
regressionMetrics := getRegressionMetricDef(ctx.App.Name, api, period)
allMetrics = append(allMetrics, regressionMetrics...)
}
}
Expand Down Expand Up @@ -397,64 +399,49 @@ func getNetworkStatsDef(appName string, api *context.API, period int64) []*cloud
return networkDataQueries
}

func getClassesMetricDef(appName string, api *context.API, period int64) ([]*cloudwatch.MetricDataQuery, error) {
listMetricsInput := &cloudwatch.ListMetricsInput{
Namespace: aws.String(config.Cortex.LogGroup),
MetricName: aws.String("Prediction"),
Dimensions: []*cloudwatch.DimensionFilter{
{
Name: aws.String("AppName"),
Value: aws.String(appName),
},
{
Name: aws.String("APIName"),
Value: aws.String(api.Name),
},
{
Name: aws.String("APIID"),
Value: aws.String(api.ID),
},
},
}

listMetricsOutput, err := config.AWS.CloudWatchMetrics.ListMetrics(listMetricsInput)
func getClassesMetricDef(ctx *context.Context, api *context.API, period int64) ([]*cloudwatch.MetricDataQuery, error) {
prefix := filepath.Join(ctx.MetadataRoot, api.ID, "classes") + "/"
classes, err := config.AWS.ListPrefix(prefix, int64(consts.MaxClassesPerRequest))
if err != nil {
return nil, err
}

if listMetricsOutput.Metrics == nil {
if len(classes) == 0 {
return nil, nil
}

classMetricQueries := []*cloudwatch.MetricDataQuery{}

classCount := 0
for i, metric := range listMetricsOutput.Metrics {
if classCount >= consts.MaxClassesPerRequest {
break
}

var className string
for _, dim := range metric.Dimensions {
if *dim.Name == "Class" {
className = *dim.Value
}
for i, classObj := range classes {
classKey := *classObj.Key
urlSplit := strings.Split(classKey, "/")
encodedClassName := urlSplit[len(urlSplit)-1]
decodedBytes, err := base64.URLEncoding.DecodeString(encodedClassName)
if err != nil {
return nil, errors.Wrap(err, "encoded class name", encodedClassName)
}

className := string(decodedBytes)
if len(className) == 0 {
continue
}

classMetricQueries = append(classMetricQueries, &cloudwatch.MetricDataQuery{
Id: aws.String(fmt.Sprintf("id_%d", i)),
MetricStat: &cloudwatch.MetricStat{
Metric: metric,
Metric: &cloudwatch.Metric{
Namespace: aws.String(config.Cortex.LogGroup),
MetricName: aws.String("Prediction"),
Dimensions: append(getAPIDimensions(ctx.App.Name, api), &cloudwatch.Dimension{
Name: aws.String("Class"),
Value: aws.String(className),
}),
},
Stat: aws.String("Sum"),
Period: aws.Int64(period),
},
Label: aws.String("class_" + className),
})
classCount++
}
return classMetricQueries, nil
}
129 changes: 85 additions & 44 deletions pkg/workloads/cortex/lib/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,38 @@
# limitations under the License.


import os
import base64

from cortex.lib.exceptions import UserException, CortexException
from cortex.lib.log import get_logger

logger = get_logger()


def get_classes(ctx, api_name):
    """Return the set of class names recorded for an API.

    Class names are stored as base64 (URL-safe) encoded S3 keys under the
    API's ``classes`` metadata prefix; each key's last path segment is
    decoded back to the original class name.
    """
    api = ctx.apis[api_name]
    prefix = os.path.join(ctx.metadata_root, api["id"], "classes")
    return {
        base64.urlsafe_b64decode(path.split("/")[-1].encode()).decode()
        for path in ctx.storage.search(prefix=prefix)
    }


def upload_class(ctx, api_name, class_name):
    """Persist a marker object for class_name under the API's classes prefix.

    The class name is encoded as ASCII (CloudWatch dimensions only support
    ASCII) and then URL-safe base64 so it is a valid storage key segment.

    Raises:
        ValueError: if encoding or the storage write fails.
    """
    api = ctx.apis[api_name]

    try:
        # cloudwatch only supports ascii
        encoded = base64.urlsafe_b64encode(class_name.encode("ascii")).decode()
        key = os.path.join(ctx.metadata_root, api["id"], "classes", encoded)
        ctx.storage.put_json("", key)
    except Exception as e:
        raise ValueError("unable to store class {}".format(class_name)) from e


def api_metric_dimensions(ctx, api_name):
api = ctx.apis[api_name]
return [
Expand All @@ -42,72 +68,87 @@ def predictions_per_request_metric(dimensions, prediction_count):
]


def prediction_metrics(dimensions, api, predictions):
metric_list = []
def extract_predicted_values(api, predictions):
predicted_values = []

tracker = api.get("tracker")
for prediction in predictions:
predicted_value = prediction.get(tracker["key"])
if predicted_value is None:
logger.warn(
raise ValueError(
"failed to track key '{}': not found in response payload".format(tracker["key"])
)
return []

if tracker["model_type"] == "classification":
if type(predicted_value) == str or type(predicted_value) == int:
dimensions_with_class = dimensions + [
{"Name": "Class", "Value": str(predicted_value)}
]
metric = {
"MetricName": "Prediction",
"Dimensions": dimensions_with_class,
"Unit": "Count",
"Value": 1,
}

metric_list.append(metric)
else:
logger.warn(
if type(predicted_value) != str and type(predicted_value) != int:
raise ValueError(
"failed to track key '{}': expected type 'str' or 'int' but encountered '{}'".format(
tracker["key"], type(predicted_value)
)
)
return []
else:
if type(predicted_value) == float or type(predicted_value) == int: # allow ints
metric = {
"MetricName": "Prediction",
"Dimensions": dimensions,
"Value": float(predicted_value),
}
metric_list.append(metric)
else:
logger.warn(
if type(predicted_value) != float and type(predicted_value) != int: # allow ints
raise ValueError(
"failed to track key '{}': expected type 'float' or 'int' but encountered '{}'".format(
tracker["key"], type(predicted_value)
)
)
return []
predicted_values.append(predicted_value)

return predicted_values


def prediction_metrics(dimensions, api, predicted_values):
    """Build one CloudWatch "Prediction" metric per predicted value.

    For classification trackers, each metric is a count of 1 tagged with a
    "Class" dimension; otherwise the predicted value itself is reported as a
    float with the base dimensions.
    """
    tracker = api.get("tracker")
    metrics = []
    for value in predicted_values:
        if tracker["model_type"] == "classification":
            metrics.append(
                {
                    "MetricName": "Prediction",
                    "Dimensions": dimensions + [{"Name": "Class", "Value": str(value)}],
                    "Unit": "Count",
                    "Value": 1,
                }
            )
        else:
            metrics.append(
                {
                    "MetricName": "Prediction",
                    "Dimensions": dimensions,
                    "Value": float(value),
                }
            )
    return metrics


def post_request_metrics(ctx, api, response, predictions):
try:
api_name = api["name"]
def cache_classes(ctx, api, predicted_values, class_set):
    """Upload any class names not yet seen and record them in class_set.

    class_set is mutated in place so repeated calls skip already-uploaded
    classes.
    """
    for value in predicted_values:
        if value in class_set:
            continue
        upload_class(ctx, api["name"], value)
        class_set.add(value)

api_dimensions = api_metric_dimensions(ctx, api_name)
metrics_list = []
metrics_list += status_code_metric(api_dimensions, response.status_code)

if predictions is not None:
metrics_list += predictions_per_request_metric(api_dimensions, len(predictions))
def post_request_metrics(ctx, api, response, predictions, class_set):
api_name = api["name"]
api_dimensions = api_metric_dimensions(ctx, api_name)
metrics_list = []
metrics_list += status_code_metric(api_dimensions, response.status_code)

if api.get("tracker") is not None:
metrics_list += prediction_metrics(api_dimensions, api, predictions)
ctx.publish_metrics(metrics_list)
if predictions is not None:
metrics_list += predictions_per_request_metric(api_dimensions, len(predictions))

except CortexException as e:
e.wrap("error")
logger.warn(str(e), exc_info=True)
if api.get("tracker") is not None:
try:
predicted_values = extract_predicted_values(api, predictions)

if api["tracker"]["model_type"] == "classification":
cache_classes(ctx, api, predicted_values, class_set)

metrics_list += prediction_metrics(api_dimensions, api, predicted_values)
except Exception as e:
logger.warn(str(e), exc_info=True)

try:
ctx.publish_metrics(metrics_list)
except Exception as e:
logger.warn(str(e), exc_info=True)
22 changes: 1 addition & 21 deletions pkg/workloads/cortex/lib/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self, **kwargs):

self.id = self.ctx["id"]
self.key = self.ctx["key"]
self.metadata_root = self.ctx["metadata_root"]
self.cortex_config = self.ctx["cortex_config"]
self.deployment_version = self.ctx["deployment_version"]
self.root = self.ctx["root"]
Expand Down Expand Up @@ -83,9 +84,6 @@ def __init__(self, **kwargs):
)
)

# Internal caches
self._metadatas = {}

# This affects Tensorflow S3 access
os.environ["AWS_REGION"] = self.cortex_config.get("region", "")

Expand Down Expand Up @@ -192,24 +190,6 @@ def upload_resource_status_end(self, exit_code, *resources):
def resource_status_key(self, resource):
return os.path.join(self.status_prefix, resource["id"], resource["workload_id"])

def get_metadata_url(self, resource_id):
return os.path.join(self.ctx["metadata_root"], resource_id + ".json")

def write_metadata(self, resource_id, metadata):
if resource_id in self._metadatas and self._metadatas[resource_id] == metadata:
return

self._metadatas[resource_id] = metadata
self.storage.put_json(metadata, self.get_metadata_url(resource_id))

def get_metadata(self, resource_id, use_cache=True):
if use_cache and resource_id in self._metadatas:
return self._metadatas[resource_id]

metadata = self.storage.get_json(self.get_metadata_url(resource_id), allow_missing=True)
self._metadatas[resource_id] = metadata
return metadata

def publish_metrics(self, metrics):
if self.monitoring is None:
raise CortexException("monitoring client not initialized") # unexpected
Expand Down
Loading