Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move pipeline runner service account to backend #1988

Merged
merged 7 commits into from
Aug 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/src/apiserver/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ go_library(
name = "go_default_library",
srcs = [
"client_manager.go",
"config.go",
"interceptor.go",
"main.go",
],
Expand All @@ -13,6 +12,7 @@ go_library(
deps = [
"//backend/api:go_default_library",
"//backend/src/apiserver/client:go_default_library",
"//backend/src/apiserver/common:go_default_library",
"//backend/src/apiserver/model:go_default_library",
"//backend/src/apiserver/resource:go_default_library",
"//backend/src/apiserver/server:go_default_library",
Expand Down
43 changes: 23 additions & 20 deletions backend/src/apiserver/client_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package main
import (
"database/sql"
"fmt"
"github.com/kubeflow/pipelines/backend/src/apiserver/common"
v1 "k8s.io/client-go/kubernetes/typed/core/v1"
"os"
"time"
Expand All @@ -38,17 +39,18 @@ import (
const (
minioServiceHost = "MINIO_SERVICE_SERVICE_HOST"
minioServicePort = "MINIO_SERVICE_SERVICE_PORT"

mysqlServiceHost = "MYSQL_SERVICE_HOST"
mysqlServicePort = "MYSQL_SERVICE_PORT"
mysqlUser = "DBConfig.User"
mysqlPassword = "DBConfig.Password"
mysqlDBName = "DBConfig.DBName"

podNamespace = "POD_NAMESPACE"
initConnectionTimeout = "InitConnectionTimeout"

visualizationServiceHost = "ML_PIPELINE_VISUALIZATIONSERVER_SERVICE_HOST"
visualizationServicePort = "ML_PIPELINE_VISUALIZATIONSERVER_SERVICE_PORT"

podNamespace = "POD_NAMESPACE"
initConnectionTimeout = "InitConnectionTimeout"
)

// Container for all service clients
Expand All @@ -67,6 +69,7 @@ type ClientManager struct {
podClient v1.PodInterface
time util.TimeInterface
uuid util.UUIDGeneratorInterface

}

func (c *ClientManager) ExperimentStore() storage.ExperimentStoreInterface {
Expand Down Expand Up @@ -124,7 +127,7 @@ func (c *ClientManager) UUID() util.UUIDGeneratorInterface {
func (c *ClientManager) init() {
glog.Infof("Initializing client manager")

db := initDBClient(getDurationConfig(initConnectionTimeout))
db := initDBClient(common.GetDurationConfig(initConnectionTimeout))

// time
c.time = util.NewRealTime()
Expand All @@ -139,16 +142,16 @@ func (c *ClientManager) init() {
c.resourceReferenceStore = storage.NewResourceReferenceStore(db)
c.dBStatusStore = storage.NewDBStatusStore(db)
c.defaultExperimentStore = storage.NewDefaultExperimentStore(db)
c.objectStore = initMinioClient(getDurationConfig(initConnectionTimeout))
c.objectStore = initMinioClient(common.GetDurationConfig(initConnectionTimeout))

c.wfClient = client.CreateWorkflowClientOrFatal(
getStringConfig(podNamespace), getDurationConfig(initConnectionTimeout))
common.GetStringConfig(podNamespace), common.GetDurationConfig(initConnectionTimeout))

c.swfClient = client.CreateScheduledWorkflowClientOrFatal(
getStringConfig(podNamespace), getDurationConfig(initConnectionTimeout))
common.GetStringConfig(podNamespace), common.GetDurationConfig(initConnectionTimeout))

c.podClient = client.CreatePodClientOrFatal(
getStringConfig(podNamespace), getDurationConfig(initConnectionTimeout))
common.GetStringConfig(podNamespace), common.GetDurationConfig(initConnectionTimeout))

runStore := storage.NewRunStore(db, c.time)
c.runStore = runStore
Expand All @@ -161,7 +164,7 @@ func (c *ClientManager) Close() {
}

func initDBClient(initConnectionTimeout time.Duration) *storage.DB {
driverName := getStringConfig("DBConfig.DriverName")
driverName := common.GetStringConfig("DBConfig.DriverName")
var arg string

switch driverName {
Expand Down Expand Up @@ -208,10 +211,10 @@ func initDBClient(initConnectionTimeout time.Duration) *storage.DB {
// Format would be something like root@tcp(ip:port)/dbname?charset=utf8&loc=Local&parseTime=True
func initMysql(driverName string, initConnectionTimeout time.Duration) string {
mysqlConfig := client.CreateMySQLConfig(
getStringConfigWithDefault(mysqlUser, "root"),
getStringConfigWithDefault(mysqlPassword, ""),
getStringConfig(mysqlServiceHost),
getStringConfig(mysqlServicePort),
common.GetStringConfigWithDefault(mysqlUser, "root"),
common.GetStringConfigWithDefault(mysqlPassword, ""),
common.GetStringConfig(mysqlServiceHost),
common.GetStringConfig(mysqlServicePort),
"")

var db *sql.DB
Expand All @@ -231,7 +234,7 @@ func initMysql(driverName string, initConnectionTimeout time.Duration) string {
util.TerminateIfError(err)

// Create database if not exist
dbName := getStringConfig(mysqlDBName)
dbName := common.GetStringConfig(mysqlDBName)
operation = func() error {
_, err = db.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", dbName))
if err != nil {
Expand All @@ -250,14 +253,14 @@ func initMysql(driverName string, initConnectionTimeout time.Duration) string {

func initMinioClient(initConnectionTimeout time.Duration) storage.ObjectStoreInterface {
// Create minio client.
minioServiceHost := getStringConfigWithDefault(
minioServiceHost := common.GetStringConfigWithDefault(
"ObjectStoreConfig.Host", os.Getenv(minioServiceHost))
minioServicePort := getStringConfigWithDefault(
minioServicePort := common.GetStringConfigWithDefault(
"ObjectStoreConfig.Port", os.Getenv(minioServicePort))
accessKey := getStringConfig("ObjectStoreConfig.AccessKey")
secretKey := getStringConfig("ObjectStoreConfig.SecretAccessKey")
bucketName := getStringConfig("ObjectStoreConfig.BucketName")
disableMultipart := getBoolConfigWithDefault("ObjectStoreConfig.Multipart.Disable", true)
accessKey := common.GetStringConfig("ObjectStoreConfig.AccessKey")
secretKey := common.GetStringConfig("ObjectStoreConfig.SecretAccessKey")
bucketName := common.GetStringConfig("ObjectStoreConfig.BucketName")
disableMultipart := common.GetBoolConfigWithDefault("ObjectStoreConfig.Multipart.Disable", true)

minioClient := client.CreateMinioClientOrFatal(minioServiceHost, minioServicePort, accessKey,
secretKey, initConnectionTimeout)
Expand Down
3 changes: 3 additions & 0 deletions backend/src/apiserver/common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = [
"config.go",
"const.go",
"filter_context.go",
"pagination_context.go",
Expand All @@ -13,5 +14,7 @@ go_library(
deps = [
"//backend/api:go_default_library",
"//backend/src/common/util:go_default_library",
"@com_github_golang_glog//:go_default_library",
"@com_github_spf13_viper//:go_default_library",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,55 +12,32 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package main
package common

import (
"strconv"
"strings"
"time"

"github.com/fsnotify/fsnotify"
"github.com/golang/glog"
"github.com/spf13/viper"
)

func initConfig() {
// Import environment variable, support nested vars e.g. OBJECTSTORECONFIG_ACCESSKEY
replacer := strings.NewReplacer(".", "_")
viper.SetEnvKeyReplacer(replacer)
viper.AutomaticEnv()

// Set configuration file name. The format is auto detected in this case.
viper.SetConfigName("config")
viper.AddConfigPath(*configPath)
err := viper.ReadInConfig()
if err != nil {
glog.Fatalf("Fatal error config file: %s", err)
}

// Watch for configuration change
viper.WatchConfig()
viper.OnConfigChange(func(e fsnotify.Event) {
// Read in config again
viper.ReadInConfig()
})
}

func getStringConfig(configName string) string {
func GetStringConfig(configName string) string {
if !viper.IsSet(configName) {
glog.Fatalf("Please specify flag %s", configName)
}
return viper.GetString(configName)
}

func getStringConfigWithDefault(configName, value string) string {
func GetStringConfigWithDefault(configName, value string) string {
if !viper.IsSet(configName) {
return value
}
return viper.GetString(configName)
}

func getBoolConfigWithDefault(configName string, value bool) bool {
func GetBoolConfigWithDefault(configName string, value bool) bool {
if !viper.IsSet(configName) {
return value
}
Expand All @@ -71,7 +48,7 @@ func getBoolConfigWithDefault(configName string, value bool) bool {
return value
}

func getDurationConfig(configName string) time.Duration {
func GetDurationConfig(configName string) time.Duration {
if !viper.IsSet(configName) {
glog.Fatalf("Please specify flag %s", configName)
}
Expand Down
3 changes: 2 additions & 1 deletion backend/src/apiserver/config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@
"SecretAccessKey": "minio123",
"BucketName": "mlpipeline"
},
"InitConnectionTimeout": "6m"
"InitConnectionTimeout": "6m",
"DefaultPipelineRunnerServiceAccount": "pipeline-runner"
hongye-sun marked this conversation as resolved.
Show resolved Hide resolved
}
34 changes: 30 additions & 4 deletions backend/src/apiserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@ import (
"context"
"encoding/json"
"flag"
"github.com/fsnotify/fsnotify"
"github.com/kubeflow/pipelines/backend/src/apiserver/common"
"github.com/spf13/viper"
"io"
"io/ioutil"
"net"
"net/http"
"strings"
"time"

"fmt"
Expand Down Expand Up @@ -83,9 +87,9 @@ func startRpcServer(resourceManager *resource.ResourceManager) {
s,
server.NewVisualizationServer(
resourceManager,
getStringConfig(visualizationServiceHost),
getStringConfig(visualizationServicePort),
getDurationConfig(initConnectionTimeout),
common.GetStringConfig(visualizationServiceHost),
common.GetStringConfig(visualizationServicePort),
common.GetDurationConfig(initConnectionTimeout),
))

// Register reflection service on gRPC server.
Expand Down Expand Up @@ -121,7 +125,7 @@ func startHttpProxy(resourceManager *resource.ResourceManager) {
pipelineUploadServer := server.NewPipelineUploadServer(resourceManager)
topMux.HandleFunc("/apis/v1beta1/pipelines/upload", pipelineUploadServer.UploadPipeline)
topMux.HandleFunc("/apis/v1beta1/healthz", func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, `{"commit_sha":"`+getStringConfig("COMMIT_SHA")+`"}`)
io.WriteString(w, `{"commit_sha":"`+common.GetStringConfig("COMMIT_SHA")+`"}`)
})

topMux.Handle("/apis/", mux)
Expand Down Expand Up @@ -195,3 +199,25 @@ func loadSamples(resourceManager *resource.ResourceManager) error {
glog.Info("All samples are loaded.")
return nil
}

func initConfig() {
// Import environment variable, support nested vars e.g. OBJECTSTORECONFIG_ACCESSKEY
replacer := strings.NewReplacer(".", "_")
viper.SetEnvKeyReplacer(replacer)
viper.AutomaticEnv()

// Set configuration file name. The format is auto detected in this case.
viper.SetConfigName("config")
viper.AddConfigPath(*configPath)
err := viper.ReadInConfig()
if err != nil {
glog.Fatalf("Fatal error config file: %s", err)
}

// Watch for configuration change
viper.WatchConfig()
viper.OnConfigChange(func(e fsnotify.Event) {
// Read in config again
viper.ReadInConfig()
})
}
18 changes: 16 additions & 2 deletions backend/src/apiserver/resource/resource_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ import (
"k8s.io/apimachinery/pkg/types"
)

const (
defaultPipelineRunnerServiceAccountEnvVar = "DefaultPipelineRunnerServiceAccount"
defaultPipelineRunnerServiceAccount = "pipeline-runner"
)

type ClientManagerInterface interface {
ExperimentStore() storage.ExperimentStoreInterface
PipelineStore() storage.PipelineStoreInterface
Expand Down Expand Up @@ -219,6 +224,8 @@ func (r *ResourceManager) CreateRun(apiRun *api.Run) (*model.RunDetail, error) {
if err = workflow.VerifyParameters(parameters); err != nil {
return nil, util.Wrap(err, "Failed to verify parameters.")
}

workflow.SetServiceAccount(r.getDefaultSA())
// Append provided parameter
workflow.OverrideParameters(parameters)
// Add label to the workflow so it can be persisted by persistent agent later.
Expand Down Expand Up @@ -428,6 +435,9 @@ func (r *ResourceManager) CreateJob(apiJob *api.Job) (*model.Job, error) {
return nil, util.Wrap(err, "Create job failed")
}

// Set workflow to be run using default pipeline runner service account.
workflow.SetServiceAccount(r.getDefaultSA())

scheduledWorkflow := &scheduledworkflow.ScheduledWorkflow{
ObjectMeta: v1.ObjectMeta{GenerateName: swfGeneratedName},
Spec: scheduledworkflow.ScheduledWorkflowSpec{
Expand Down Expand Up @@ -673,7 +683,7 @@ func (r *ResourceManager) CreateDefaultExperiment() (string, error) {
}
// If default experiment ID is already present, don't fail, simply return.
if defaultExperimentId != "" {
glog.Info("Default experiment already exists! ID: %v", defaultExperimentId)
glog.Infof("Default experiment already exists! ID: %v", defaultExperimentId)
return "", nil
}

Expand All @@ -693,7 +703,7 @@ func (r *ResourceManager) CreateDefaultExperiment() (string, error) {
return "", fmt.Errorf("Failed to set default experiment ID. Err: %v", err)
}

glog.Info("Default experiment is set. ID is: %v", experiment.UUID)
glog.Infof("Default experiment is set. ID is: %v", experiment.UUID)
return experiment.UUID, nil
}

Expand Down Expand Up @@ -772,3 +782,7 @@ func (r *ResourceManager) HaveSamplesLoaded() (bool, error) {
func (r *ResourceManager) MarkSampleLoaded() error {
return r.dBStatusStore.MarkSampleLoaded()
}

func (r *ResourceManager) getDefaultSA() string{
return common.GetStringConfigWithDefault(defaultPipelineRunnerServiceAccountEnvVar, defaultPipelineRunnerServiceAccount)
}
2 changes: 2 additions & 0 deletions backend/src/apiserver/resource/resource_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ func TestCreateRun_ThroughPipelineID(t *testing.T) {
expectedRuntimeWorkflow.Spec.Arguments.Parameters = []v1alpha1.Parameter{
{Name: "param1", Value: util.StringPointer("world")}}
expectedRuntimeWorkflow.Labels = map[string]string{util.LabelKeyWorkflowRunId: "123e4567-e89b-12d3-a456-426655440000"}
expectedRuntimeWorkflow.Spec.ServiceAccountName = defaultPipelineRunnerServiceAccount

expectedRunDetail := &model.RunDetail{
Run: model.Run{
Expand Down Expand Up @@ -326,6 +327,7 @@ func TestCreateRun_ThroughWorkflowSpec(t *testing.T) {
expectedRuntimeWorkflow.Spec.Arguments.Parameters = []v1alpha1.Parameter{
{Name: "param1", Value: util.StringPointer("world")}}
expectedRuntimeWorkflow.Labels = map[string]string{util.LabelKeyWorkflowRunId: "123e4567-e89b-12d3-a456-426655440000"}
expectedRuntimeWorkflow.Spec.ServiceAccountName = defaultPipelineRunnerServiceAccount
expectedRunDetail := &model.RunDetail{
Run: model.Run{
UUID: "123e4567-e89b-12d3-a456-426655440000",
Expand Down
1 change: 1 addition & 0 deletions backend/src/apiserver/server/run_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func TestCreateRun(t *testing.T) {
expectedRuntimeWorkflow.Spec.Arguments.Parameters = []v1alpha1.Parameter{
{Name: "param1", Value: util.StringPointer("world")}}
expectedRuntimeWorkflow.Labels = map[string]string{util.LabelKeyWorkflowRunId: "123e4567-e89b-12d3-a456-426655440000"}
expectedRuntimeWorkflow.Spec.ServiceAccountName = "pipeline-runner"
expectedRunDetail := api.RunDetail{
Run: &api.Run{
Id: "123e4567-e89b-12d3-a456-426655440000",
Expand Down
Loading