forked from kubeflow/pipelines
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial inference graph API implementation (kubeflow#1910)
* Add inference graph API Signed-off-by: Dan Sun <dsun20@bloomberg.net> * Implement inference graph API Signed-off-by: Dan Sun <dsun20@bloomberg.net> * add validation and splitter Signed-off-by: Dan Sun <dsun20@bloomberg.net> * Fix inference graph status Signed-off-by: Dan Sun <dsun20@bloomberg.net> * Refactor knative reconciler Signed-off-by: Dan Sun <dsun20@bloomberg.net> * Resolve service urls Signed-off-by: Dan Sun <dsun20@bloomberg.net> * Improve logging Signed-off-by: Dan Sun <dsun20@bloomberg.net> * add switch and root node configuration. Signed-off-by: Dan Sun <dsun20@bloomberg.net> * Update docs/samples/graph/README.md Signed-off-by: Dan Sun <dsun20@bloomberg.net> * Add inference router dockerfile Signed-off-by: Dan Sun <dsun20@bloomberg.net> * Fix comments and error handling Signed-off-by: Dan Sun <dsun20@bloomberg.net> * Updating single routerType to sequence for more flexisble composition Signed-off-by: Dan Sun <dsun20@bloomberg.net> * Some InferenceGraph code streamlining - Avoid unnecessary goroutine use - Avoid most unnecessary unmarshal+remarshals - Deduplicate logic between router types - Handle async errors in ensemble case - Don't require stepName to be set - Return http error response when applicable - Simplify pickupRoute func - Fix incorrect log parameter syntax - Add webhook validation for unique step names and inference targets Signed-off-by: Nick Hill <nickhill@us.ibm.com> Signed-off-by: Dan Sun <dsun20@bloomberg.net> * Fix inference graph spec Signed-off-by: Dan Sun <dsun20@bloomberg.net> * Add condition tests Signed-off-by: Dan Sun <dsun20@bloomberg.net> * Update graph doc Signed-off-by: Dan Sun <dsun20@bloomberg.net> * rename service to serviceName Signed-off-by: Dan Sun <dsun20@bloomberg.net> Co-authored-by: iamlovingit <bitfrog@163.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com>
- Loading branch information
1 parent
9b720ae
commit 6154ffd
Showing
36 changed files
with
2,484 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
/* | ||
Copyright 2022 The KServe Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package main | ||
|
||
import ( | ||
"bytes" | ||
"encoding/json" | ||
"fmt" | ||
"io/ioutil" | ||
"net/http" | ||
"os" | ||
"strconv" | ||
"time" | ||
|
||
"github.com/tidwall/gjson" | ||
logf "sigs.k8s.io/controller-runtime/pkg/log" | ||
"sigs.k8s.io/controller-runtime/pkg/log/zap" | ||
|
||
"math/rand" | ||
|
||
"github.com/kserve/kserve/pkg/apis/serving/v1alpha1" | ||
flag "github.com/spf13/pflag" | ||
) | ||
|
||
var log = logf.Log.WithName("InferenceGraphRouter") | ||
|
||
func callService(serviceUrl string, input []byte) ([]byte, error) { | ||
resp, err := http.Post(serviceUrl, "application/json", bytes.NewBuffer(input)) | ||
if err != nil { | ||
log.Error(err, "An error has occurred from service", "service", serviceUrl) | ||
return nil, err | ||
} | ||
defer resp.Body.Close() | ||
body, err := ioutil.ReadAll(resp.Body) | ||
if err != nil { | ||
log.Error(err, "error while reading the response") | ||
} | ||
return body, err | ||
} | ||
|
||
func pickupRoute(routes []v1alpha1.InferenceStep) *v1alpha1.InferenceStep { | ||
r := rand.New(rand.NewSource(time.Now().UnixNano())) | ||
//generate num [0,100) | ||
point := r.Intn(99) | ||
end := 0 | ||
for _, route := range routes { | ||
end += int(*route.Weight) | ||
if point < end { | ||
return &route | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
func pickupRouteByCondition(input []byte, routes []v1alpha1.InferenceStep) *v1alpha1.InferenceStep { | ||
if !gjson.ValidBytes(input) { | ||
return nil | ||
} | ||
for _, route := range routes { | ||
if gjson.GetBytes(input, route.Condition).Exists() { | ||
return &route | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
func timeTrack(start time.Time, name string) { | ||
elapsed := time.Since(start) | ||
log.Info("elapsed time", "node", name, "time", elapsed) | ||
} | ||
|
||
func routeStep(nodeName string, graph v1alpha1.InferenceGraphSpec, input []byte) ([]byte, error) { | ||
log.Info("current step", "nodeName", nodeName) | ||
defer timeTrack(time.Now(), nodeName) | ||
currentNode := graph.Nodes[nodeName] | ||
|
||
if currentNode.RouterType == v1alpha1.Splitter { | ||
return executeStep(pickupRoute(currentNode.Steps), graph, input) | ||
} | ||
if currentNode.RouterType == v1alpha1.Switch { | ||
route := pickupRouteByCondition(input, currentNode.Steps) | ||
if route == nil { | ||
return input, nil //TODO maybe should fail in this case? | ||
} | ||
return executeStep(route, graph, input) | ||
} | ||
if currentNode.RouterType == v1alpha1.Ensemble { | ||
ensembleRes := make([]chan map[string]interface{}, len(currentNode.Steps)) | ||
errChan := make(chan error) | ||
for i := range currentNode.Steps { | ||
step := ¤tNode.Steps[i] | ||
resultChan := make(chan map[string]interface{}) | ||
ensembleRes[i] = resultChan | ||
go func() { | ||
output, err := executeStep(step, graph, input) | ||
if err == nil { | ||
var res map[string]interface{} | ||
if err = json.Unmarshal(output, &res); err == nil { | ||
resultChan <- res | ||
return | ||
} | ||
} | ||
errChan <- err | ||
}() | ||
} | ||
// merge responses from parallel steps | ||
response := map[string]interface{}{} | ||
for i, resultChan := range ensembleRes { | ||
key := currentNode.Steps[i].StepName | ||
if key == "" { | ||
key = strconv.Itoa(i) // Use index if no step name | ||
} | ||
select { | ||
case response[key] = <-resultChan: | ||
case err := <-errChan: | ||
return nil, err | ||
} | ||
} | ||
return json.Marshal(response) | ||
} | ||
if currentNode.RouterType == v1alpha1.Sequence { | ||
var responseBytes []byte | ||
var err error | ||
for i := range currentNode.Steps { | ||
step := ¤tNode.Steps[i] | ||
request := input | ||
if step.Data == "$response" && i > 0 { | ||
request = responseBytes | ||
} | ||
|
||
if step.Condition != "" { | ||
if !gjson.ValidBytes(responseBytes) { | ||
return nil, fmt.Errorf("invalid response") | ||
} | ||
// if the condition does not match for the step in the sequence we stop and return the response | ||
if !gjson.GetBytes(responseBytes, step.Condition).Exists() { | ||
return responseBytes, nil | ||
} | ||
} | ||
if responseBytes, err = executeStep(step, graph, request); err != nil { | ||
return nil, err | ||
} | ||
} | ||
return responseBytes, nil | ||
} | ||
log.Error(nil, "invalid route type", "type", currentNode.RouterType) | ||
return nil, fmt.Errorf("invalid route type: %v", currentNode.RouterType) | ||
} | ||
|
||
func executeStep(step *v1alpha1.InferenceStep, graph v1alpha1.InferenceGraphSpec, input []byte) ([]byte, error) { | ||
if step.NodeName != "" { | ||
// when nodeName is specified make a recursive call for routing to next step | ||
return routeStep(step.NodeName, graph, input) | ||
} | ||
return callService(step.ServiceURL, input) | ||
} | ||
|
||
var inferenceGraph *v1alpha1.InferenceGraphSpec | ||
|
||
func graphHandler(w http.ResponseWriter, req *http.Request) { | ||
inputBytes, _ := ioutil.ReadAll(req.Body) | ||
if response, err := routeStep(v1alpha1.GraphRootNodeName, *inferenceGraph, inputBytes); err != nil { | ||
log.Error(err, "failed to process request") | ||
w.WriteHeader(500) //TODO status code tbd | ||
w.Write([]byte(fmt.Sprintf("Failed to process request: %v", err))) | ||
} else { | ||
w.Write(response) | ||
} | ||
} | ||
|
||
var ( | ||
jsonGraph = flag.String("graph-json", "", "serialized json graph def") | ||
) | ||
|
||
func main() { | ||
flag.Parse() | ||
logf.SetLogger(zap.New()) | ||
inferenceGraph = &v1alpha1.InferenceGraphSpec{} | ||
err := json.Unmarshal([]byte(*jsonGraph), inferenceGraph) | ||
if err != nil { | ||
log.Error(err, "failed to unmarshall inference graph json") | ||
os.Exit(1) | ||
} | ||
|
||
http.HandleFunc("/", graphHandler) | ||
|
||
err = http.ListenAndServe(":8080", nil) | ||
if err != nil { | ||
log.Error(err, "failed to listen on 8080") | ||
os.Exit(1) | ||
} | ||
} |
Oops, something went wrong.