Skip to content

Commit

Permalink
Initial inference graph API implementation (kubeflow#1910)
Browse files Browse the repository at this point in the history
* Add inference graph API

Signed-off-by: Dan Sun <dsun20@bloomberg.net>

* Implement inference graph API

Signed-off-by: Dan Sun <dsun20@bloomberg.net>

* add validation and splitter

Signed-off-by: Dan Sun <dsun20@bloomberg.net>

* Fix inference graph status

Signed-off-by: Dan Sun <dsun20@bloomberg.net>

* Refactor knative reconciler

Signed-off-by: Dan Sun <dsun20@bloomberg.net>

* Resolve service urls

Signed-off-by: Dan Sun <dsun20@bloomberg.net>

* Improve logging

Signed-off-by: Dan Sun <dsun20@bloomberg.net>

* add switch and root node configuration.

Signed-off-by: Dan Sun <dsun20@bloomberg.net>

* Update docs/samples/graph/README.md

Signed-off-by: Dan Sun <dsun20@bloomberg.net>

* Add inference router dockerfile

Signed-off-by: Dan Sun <dsun20@bloomberg.net>

* Fix comments and error handling

Signed-off-by: Dan Sun <dsun20@bloomberg.net>

* Updating single routerType to sequence for more flexible composition

Signed-off-by: Dan Sun <dsun20@bloomberg.net>

* Some InferenceGraph code streamlining

- Avoid unnecessary goroutine use
- Avoid most unnecessary unmarshal+remarshals
- Deduplicate logic between router types
- Handle async errors in ensemble case
- Don't require stepName to be set
- Return http error response when applicable
- Simplify pickupRoute func
- Fix incorrect log parameter syntax
- Add webhook validation for unique step names and inference targets

Signed-off-by: Nick Hill <nickhill@us.ibm.com>
Signed-off-by: Dan Sun <dsun20@bloomberg.net>

* Fix inference graph spec

Signed-off-by: Dan Sun <dsun20@bloomberg.net>

* Add condition tests

Signed-off-by: Dan Sun <dsun20@bloomberg.net>

* Update graph doc

Signed-off-by: Dan Sun <dsun20@bloomberg.net>

* rename service to serviceName

Signed-off-by: Dan Sun <dsun20@bloomberg.net>

Co-authored-by: iamlovingit <bitfrog@163.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
  • Loading branch information
3 people authored Jun 3, 2022
1 parent 9b720ae commit 6154ffd
Show file tree
Hide file tree
Showing 36 changed files with 2,484 additions and 7 deletions.
24 changes: 20 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ HAS_LINT := $(shell command -v golint;)
# Image URL to use all building/pushing image targets
IMG ?= kserve-controller:latest
AGENT_IMG ?= agent:latest
ROUTER_IMG ?= router:latest
SKLEARN_IMG ?= sklearnserver
XGB_IMG ?= xgbserver
LGB_IMG ?= lgbserver
PYTORCH_IMG ?= pytorchserver
PMML_IMG ?= pmmlserver
PADDLE_IMG ?= paddleserver
ALIBI_IMG ?= alibi-explainer
AIX_IMG ?= aix-explainer
AIX_IMG ?= aix-explainer
STORAGE_INIT_IMG ?= storage-initializer
CRD_OPTIONS ?= "crd:maxDescLen=0"
KSERVE_ENABLE_SELF_SIGNED_CA ?= false
Expand All @@ -23,7 +24,7 @@ KSERVE_CONTROLLER_MEMORY_LIMIT ?= 300Mi
$(shell perl -pi -e 's/cpu:.*/cpu: $(KSERVE_CONTROLLER_CPU_LIMIT)/' config/default/manager_resources_patch.yaml)
$(shell perl -pi -e 's/memory:.*/memory: $(KSERVE_CONTROLLER_MEMORY_LIMIT)/' config/default/manager_resources_patch.yaml)

all: test manager agent
all: test manager agent router

# Run tests
test: fmt vet manifests envtest
Expand All @@ -37,6 +38,10 @@ manager: generate fmt vet lint
agent: fmt vet
go build -o bin/agent ./cmd/agent

# Build router binary
router: fmt vet
go build -o bin/router ./cmd/router

# Run against the configured Kubernetes cluster in ~/.kube/config
run: generate fmt vet lint
go run ./cmd/manager/main.go
Expand Down Expand Up @@ -90,8 +95,8 @@ deploy-dev-alibi: docker-push-alibi
./hack/alibi_patch_dev.sh ${KO_DOCKER_REPO}/${ALIBI_IMG}
kustomize build config/overlays/dev-image-config | kubectl apply -f -

deploy-dev-aix: docker-push-aix
./hack/aix_patch_dev.sh ${KO_DOCKER_REPO}/${AIX_IMG}
deploy-dev-aix: docker-push-aix
./hack/aix_patch_dev.sh ${KO_DOCKER_REPO}/${AIX_IMG}
kustomize build config/overlays/dev-image-config | kubectl apply -f -

deploy-dev-storageInitializer: docker-push-storageInitializer
Expand All @@ -108,12 +113,14 @@ undeploy:
kustomize build config/default | kubectl delete -f -
kubectl delete validatingwebhookconfigurations.admissionregistration.k8s.io inferenceservice.serving.kserve.io
kubectl delete validatingwebhookconfigurations.admissionregistration.k8s.io trainedmodel.serving.kserve.io
kubectl delete validatingwebhookconfigurations.admissionregistration.k8s.io inferencegraph.serving.kserve.io
kubectl delete mutatingwebhookconfigurations.admissionregistration.k8s.io inferenceservice.serving.kserve.io

undeploy-dev:
kustomize build config/overlays/development | kubectl delete -f -
kubectl delete validatingwebhookconfigurations.admissionregistration.k8s.io inferenceservice.serving.kserve.io
kubectl delete validatingwebhookconfigurations.admissionregistration.k8s.io trainedmodel.serving.kserve.io
kubectl delete validatingwebhookconfigurations.admissionregistration.k8s.io inferencegraph.serving.kserve.io
kubectl delete mutatingwebhookconfigurations.admissionregistration.k8s.io inferenceservice.serving.kserve.io

# Generate manifests e.g. CRD, RBAC etc.
Expand All @@ -130,6 +137,9 @@ manifests: controller-gen
perl -pi -e 's/storedVersions: null/storedVersions: []/g' config/crd/serving.kserve.io_trainedmodels.yaml
perl -pi -e 's/conditions: null/conditions: []/g' config/crd/serving.kserve.io_trainedmodels.yaml
perl -pi -e 's/Any/string/g' config/crd/serving.kserve.io_trainedmodels.yaml
perl -pi -e 's/storedVersions: null/storedVersions: []/g' config/crd/serving.kserve.io_inferencegraphs.yaml
perl -pi -e 's/conditions: null/conditions: []/g' config/crd/serving.kserve.io_inferencegraphs.yaml
perl -pi -e 's/Any/string/g' config/crd/serving.kserve.io_inferencegraphs.yaml
#remove the required property on framework as name field needs to be optional
yq d -i config/crd/serving.kserve.io_inferenceservices.yaml 'spec.versions[0].schema.openAPIV3Schema.properties.spec.properties.*.properties.*.required'
#remove ephemeralContainers properties for compress crd size https://github.com/kubeflow/kfserving/pull/1141#issuecomment-714170602
Expand Down Expand Up @@ -188,9 +198,15 @@ docker-push:
docker-build-agent:
docker build -f agent.Dockerfile . -t ${KO_DOCKER_REPO}/${AGENT_IMG}

docker-build-router:
docker build -f router.Dockerfile . -t ${KO_DOCKER_REPO}/${ROUTER_IMG}

docker-push-agent:
docker push ${KO_DOCKER_REPO}/${AGENT_IMG}

docker-push-router:
docker push ${KO_DOCKER_REPO}/${ROUTER_IMG}

docker-build-sklearn:
cd python && docker build -t ${KO_DOCKER_REPO}/${SKLEARN_IMG} -f sklearn.Dockerfile .

Expand Down
25 changes: 24 additions & 1 deletion cmd/manager/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/kserve/kserve/pkg/apis/serving/v1alpha1"
"github.com/kserve/kserve/pkg/apis/serving/v1beta1"
"github.com/kserve/kserve/pkg/constants"
graphcontroller "github.com/kserve/kserve/pkg/controller/v1alpha1/inferencegraph"
trainedmodelcontroller "github.com/kserve/kserve/pkg/controller/v1alpha1/trainedmodel"
"github.com/kserve/kserve/pkg/controller/v1alpha1/trainedmodel/reconcilers/modelconfig"
v1beta1controller "github.com/kserve/kserve/pkg/controller/v1beta1/inferenceservice"
Expand All @@ -36,7 +37,7 @@ import (
"k8s.io/client-go/tools/record"
knservingv1 "knative.dev/serving/pkg/apis/serving/v1"
ctrl "sigs.k8s.io/controller-runtime"
client "sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/config"
logf "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
Expand Down Expand Up @@ -158,6 +159,20 @@ func main() {
os.Exit(1)
}

//Setup Inference graph controller
inferenceGraphEventBroadcaster := record.NewBroadcaster()
setupLog.Info("Setting up InferenceGraph controller")
inferenceGraphEventBroadcaster.StartRecordingToSink(&typedcorev1.EventSinkImpl{Interface: clientSet.CoreV1().Events("")})
if err = (&graphcontroller.InferenceGraphReconciler{
Client: mgr.GetClient(),
Log: ctrl.Log.WithName("v1alpha1Controllers").WithName("InferenceGraph"),
Scheme: mgr.GetScheme(),
Recorder: eventBroadcaster.NewRecorder(mgr.GetScheme(), v1.EventSource{Component: "InferenceGraphController"}),
}).SetupWithManager(mgr); err != nil {
setupLog.Error(err, "unable to create controller", "v1alpha1Controllers", "InferenceGraph")
os.Exit(1)
}

log.Info("setting up webhook server")
hookServer := mgr.GetWebhookServer()

Expand All @@ -170,6 +185,14 @@ func main() {
setupLog.Error(err, "unable to create webhook", "webhook", "v1alpha1")
os.Exit(1)
}

if err = ctrl.NewWebhookManagedBy(mgr).
For(&v1alpha1.InferenceGraph{}).
Complete(); err != nil {
setupLog.Error(err, "unable to create webhook", "webhook", "v1alpha1")
os.Exit(1)
}

if err = ctrl.NewWebhookManagedBy(mgr).
For(&v1beta1.InferenceService{}).
Complete(); err != nil {
Expand Down
206 changes: 206 additions & 0 deletions cmd/router/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
/*
Copyright 2022 The KServe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package main

import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"strconv"
"time"

"github.com/tidwall/gjson"
logf "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/log/zap"

"math/rand"

"github.com/kserve/kserve/pkg/apis/serving/v1alpha1"
flag "github.com/spf13/pflag"
)

var log = logf.Log.WithName("InferenceGraphRouter")

// callService POSTs the given JSON payload to the target service URL and
// returns the raw response body. A transport-level failure is logged and
// returned; a body-read failure is logged and returned alongside whatever
// bytes were read. Note: the HTTP status code is not inspected here.
func callService(serviceUrl string, input []byte) ([]byte, error) {
	resp, err := http.Post(serviceUrl, "application/json", bytes.NewReader(input))
	if err != nil {
		log.Error(err, "An error has occurred from service", "service", serviceUrl)
		return nil, err
	}
	defer resp.Body.Close()

	body, readErr := ioutil.ReadAll(resp.Body)
	if readErr != nil {
		log.Error(readErr, "error while reading the response")
	}
	return body, readErr
}

// pickupRoute selects one step from a Splitter node by weighted random
// choice. Weights are expected to sum to 100; a route is picked when the
// random point falls inside its cumulative-weight interval. Returns nil if
// the weights do not cover the drawn point (e.g. they sum to less than 100).
func pickupRoute(routes []v1alpha1.InferenceStep) *v1alpha1.InferenceStep {
	r := rand.New(rand.NewSource(time.Now().UnixNano()))
	// generate num in [0,100); Intn(100) is required here — the previous
	// Intn(99) only produced [0,99), biasing the distribution against the
	// final route's last weight point.
	point := r.Intn(100)
	end := 0
	for i := range routes {
		end += int(*routes[i].Weight)
		if point < end {
			// Return a pointer into the slice rather than to a loop copy.
			return &routes[i]
		}
	}
	return nil
}

// pickupRouteByCondition selects the first step of a Switch node whose
// gjson condition matches the request payload. Returns nil when the payload
// is not valid JSON or when no condition matches.
func pickupRouteByCondition(input []byte, routes []v1alpha1.InferenceStep) *v1alpha1.InferenceStep {
	if !gjson.ValidBytes(input) {
		// Conditions can only be evaluated against well-formed JSON.
		return nil
	}
	for i := range routes {
		if gjson.GetBytes(input, routes[i].Condition).Exists() {
			return &routes[i]
		}
	}
	return nil
}

// timeTrack logs how long a graph node took to process; intended to be
// invoked via defer with the node's start time.
func timeTrack(start time.Time, name string) {
	log.Info("elapsed time", "node", name, "time", time.Since(start))
}

// routeStep processes the request payload through the named graph node,
// dispatching on the node's router type:
//   - Splitter: forwards to one weighted-random step.
//   - Switch:   forwards to the first step whose condition matches the input
//     (the input is passed through unchanged when none match).
//   - Ensemble: fans the input out to all steps in parallel and merges the
//     JSON responses into one object keyed by step name (or index).
//   - Sequence: pipes through the steps in order; a step may consume the
//     previous response ($response) and may be gated on a condition.
//
// Returns the resulting payload or the first error encountered.
func routeStep(nodeName string, graph v1alpha1.InferenceGraphSpec, input []byte) ([]byte, error) {
	log.Info("current step", "nodeName", nodeName)
	defer timeTrack(time.Now(), nodeName)
	currentNode := graph.Nodes[nodeName]

	if currentNode.RouterType == v1alpha1.Splitter {
		return executeStep(pickupRoute(currentNode.Steps), graph, input)
	}
	if currentNode.RouterType == v1alpha1.Switch {
		route := pickupRouteByCondition(input, currentNode.Steps)
		if route == nil {
			return input, nil //TODO maybe should fail in this case?
		}
		return executeStep(route, graph, input)
	}
	if currentNode.RouterType == v1alpha1.Ensemble {
		ensembleRes := make([]chan map[string]interface{}, len(currentNode.Steps))
		// Buffered so every goroutine can complete its send even after this
		// function has returned early on another step's error; unbuffered
		// channels here would leak the remaining goroutines forever.
		errChan := make(chan error, len(currentNode.Steps))
		for i := range currentNode.Steps {
			step := &currentNode.Steps[i]
			resultChan := make(chan map[string]interface{}, 1)
			ensembleRes[i] = resultChan
			go func() {
				output, err := executeStep(step, graph, input)
				if err == nil {
					var res map[string]interface{}
					if err = json.Unmarshal(output, &res); err == nil {
						resultChan <- res
						return
					}
				}
				errChan <- err
			}()
		}
		// merge responses from parallel steps; fail fast on the first error
		response := map[string]interface{}{}
		for i, resultChan := range ensembleRes {
			key := currentNode.Steps[i].StepName
			if key == "" {
				key = strconv.Itoa(i) // Use index if no step name
			}
			select {
			case response[key] = <-resultChan:
			case err := <-errChan:
				return nil, err
			}
		}
		return json.Marshal(response)
	}
	if currentNode.RouterType == v1alpha1.Sequence {
		var responseBytes []byte
		var err error
		for i := range currentNode.Steps {
			step := &currentNode.Steps[i]
			request := input
			if step.Data == "$response" && i > 0 {
				request = responseBytes
			}

			if step.Condition != "" {
				// NOTE(review): on the first step responseBytes is nil, so a
				// conditioned first step always errors here — confirm intended.
				if !gjson.ValidBytes(responseBytes) {
					return nil, fmt.Errorf("invalid response")
				}
				// if the condition does not match for the step in the sequence we stop and return the response
				if !gjson.GetBytes(responseBytes, step.Condition).Exists() {
					return responseBytes, nil
				}
			}
			if responseBytes, err = executeStep(step, graph, request); err != nil {
				return nil, err
			}
		}
		return responseBytes, nil
	}
	log.Error(nil, "invalid route type", "type", currentNode.RouterType)
	return nil, fmt.Errorf("invalid route type: %v", currentNode.RouterType)
}

// executeStep runs a single inference step: a step that names another graph
// node recurses into that node via routeStep; otherwise it is a leaf step
// that calls its target service directly.
func executeStep(step *v1alpha1.InferenceStep, graph v1alpha1.InferenceGraphSpec, input []byte) ([]byte, error) {
	if step.NodeName == "" {
		return callService(step.ServiceURL, input)
	}
	// when nodeName is specified make a recursive call for routing to next step
	return routeStep(step.NodeName, graph, input)
}

var inferenceGraph *v1alpha1.InferenceGraphSpec

func graphHandler(w http.ResponseWriter, req *http.Request) {
inputBytes, _ := ioutil.ReadAll(req.Body)
if response, err := routeStep(v1alpha1.GraphRootNodeName, *inferenceGraph, inputBytes); err != nil {
log.Error(err, "failed to process request")
w.WriteHeader(500) //TODO status code tbd
w.Write([]byte(fmt.Sprintf("Failed to process request: %v", err)))
} else {
w.Write(response)
}
}

var (
jsonGraph = flag.String("graph-json", "", "serialized json graph def")
)

// main parses the serialized graph definition from the --graph-json flag,
// installs the graph handler at "/", and serves on port 8080. Exits nonzero
// when the graph JSON is invalid or the listener fails.
func main() {
	flag.Parse()
	logf.SetLogger(zap.New())
	inferenceGraph = &v1alpha1.InferenceGraphSpec{}
	err := json.Unmarshal([]byte(*jsonGraph), inferenceGraph)
	if err != nil {
		// (typo fix: was "unmarshall")
		log.Error(err, "failed to unmarshal inference graph json")
		os.Exit(1)
	}

	http.HandleFunc("/", graphHandler)

	// NOTE(review): no read/write timeouts are configured on the server;
	// confirm whether long-running graph requests make that deliberate.
	err = http.ListenAndServe(":8080", nil)
	if err != nil {
		log.Error(err, "failed to listen on 8080")
		os.Exit(1)
	}
}
Loading

0 comments on commit 6154ffd

Please sign in to comment.