Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e843858
init etcd cclient
jacquesqiao Jun 21, 2017
22c621b
add etcd
jacquesqiao Jun 22, 2017
4e03120
add etcd.go
jacquesqiao Jun 22, 2017
ca4626b
fix compile problem
jacquesqiao Jun 22, 2017
a535e01
move code to etcd.go
jacquesqiao Jun 23, 2017
17bdd60
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jun 25, 2017
9f1d0bf
add etcd_lister.go for pserver client
jacquesqiao Jun 26, 2017
73b9d98
add etcd_client_test.go
jacquesqiao Jun 26, 2017
cffaa4d
merge etcd_client_test and client_test
jacquesqiao Jun 26, 2017
5385980
refine client_test.go
jacquesqiao Jun 26, 2017
2db99f1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jun 29, 2017
05d7e01
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jun 30, 2017
6131e70
refine code
jacquesqiao Jun 30, 2017
90f5aec
format code
jacquesqiao Jun 30, 2017
78d6a02
add TODO and use interface instead of struct
jacquesqiao Jun 30, 2017
b897acc
fix typo of initDesiredPservers
jacquesqiao Jun 30, 2017
5635409
optimize dir structure of go/pserver/client
jacquesqiao Jul 1, 2017
8bcd01d
add a flag to config index for pserver
jacquesqiao Jul 1, 2017
475005f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jul 1, 2017
5d6dc69
follow comment
jacquesqiao Jul 3, 2017
e767db8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jul 4, 2017
c1ceef6
fix path
jacquesqiao Jul 4, 2017
ec61fdb
optimize code
jacquesqiao Jul 4, 2017
41c299c
remove err in pserver NewEtcd
jacquesqiao Jul 4, 2017
9e71bae
restore comment about /ps_desired
jacquesqiao Jul 4, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ include(coveralls) # set code coverage
include_directories("${PROJ_ROOT}")
include_directories("${PROJ_ROOT}/paddle/cuda/include")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/cclient")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/client/c")
include_directories(${Boost_INCLUDE_DIRS})

set(EXTERNAL_LIBS
Expand Down
2 changes: 1 addition & 1 deletion go/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
#

add_subdirectory(pserver/cclient)
add_subdirectory(pserver/client/c)
add_subdirectory(cmd/pserver)
add_subdirectory(cmd/master)
add_subdirectory(master/c)
16 changes: 11 additions & 5 deletions go/cmd/pserver/pserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

func main() {
port := flag.Int("port", 0, "port of the pserver")
index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
Expand All @@ -29,11 +30,16 @@ func main() {
}
log.SetLevel(level)

timeout := time.Second * time.Duration((*etcdTimeout))
e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err := e.Register()
if err != nil {
panic(err)
var idx int
if *index >= 0 {
idx = *index
} else {
timeout := time.Second * time.Duration((*etcdTimeout))
e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err = e.Register()
if err != nil {
panic(err)
}
}

s, err := pserver.NewService(idx)
Expand Down
4 changes: 2 additions & 2 deletions go/master/etcd_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
lock := concurrency.NewMutex(sess, lockPath)
// It's fine for the lock to get stuck, in this case we have
// multiple master servers running (only configured to have
// one master running, but split-brain problem may cuase
// one master running, but split-brain problem may cause
// multiple master servers running), and the cluster management
// software will kill one of them.
log.Debugf("Trying to acquire lock at %s.", lockPath)
Expand Down Expand Up @@ -98,7 +98,7 @@ func (e *EtcdClient) Save(state []byte) error {
// We lost the master lock and can not acquire
// it back, it means some other master is
// already started. We don't want cluster
// managment system to kill the master server
// management system to kill the master server
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have so many typos :p. Thanks!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:)

// who is holding the lock and running
// correctly. So the most feasible solution is
// to kill current master server. The current
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
go_library(paddle_pserver_cclient STATIC)
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
if(WITH_TESTING)
add_subdirectory(test)
endif()
26 changes: 15 additions & 11 deletions go/pserver/cclient/cclient.go → go/pserver/client/c/cclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,16 @@ import (
"unsafe"

"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
log "github.com/sirupsen/logrus"
)

var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex
var handleMap = make(map[C.paddle_pserver_client]*pserver.Client)
var handleMap = make(map[C.paddle_pserver_client]*client.Client)
var curHandle C.paddle_pserver_client

func add(c *pserver.Client) C.paddle_pserver_client {
func add(c *client.Client) C.paddle_pserver_client {
mu.Lock()
defer mu.Unlock()
client := curHandle
Expand All @@ -47,13 +48,13 @@ func add(c *pserver.Client) C.paddle_pserver_client {
return client
}

func get(client C.paddle_pserver_client) *pserver.Client {
func get(client C.paddle_pserver_client) *client.Client {
mu.Lock()
defer mu.Unlock()
return handleMap[client]
}

func remove(client C.paddle_pserver_client) *pserver.Client {
func remove(client C.paddle_pserver_client) *client.Client {
mu.Lock()
defer mu.Unlock()
h := handleMap[client]
Expand All @@ -80,29 +81,32 @@ func (s selector) Select() bool {
return bool(s)
}

type lister []pserver.Server
type lister []client.Server

func (l lister) List() []pserver.Server {
func (l lister) List() []client.Server {
return l
}

//export paddle_new_pserver_client
func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client {
a := C.GoString(addrs)
as := strings.Split(a, ",")
servers := make([]pserver.Server, len(as))
servers := make([]client.Server, len(as))
for i := range as {
servers[i].Index = i
servers[i].Addr = as[i]
}
c := pserver.NewClient(lister(servers), len(as), selector(selected != 0))
c := client.NewClient(lister(servers), len(as), selector(selected != 0))
return add(c)
}

//export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.paddle_pserver_client {
// TODO(helin): fault tolerant pserver client using etcd.
panic("not implemented.")
func paddle_new_etcd_pserver_client(etcd_endpoints *C.char, selected int) C.paddle_pserver_client {
// TODO(Longfei): use an etcd lock to decide which trainer initializes the parameters.
addr := C.GoString(etcd_endpoints)
etcd_client := client.NewEtcd(addr)
c := client.NewClient(etcd_client, etcd_client.Desired(), selector(selected != 0))
return add(c)
}

//export paddle_pserver_client_release
Expand Down
17 changes: 9 additions & 8 deletions go/pserver/client.go → go/pserver/client/client.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package pserver
package client

import (
"errors"
Expand All @@ -7,6 +7,7 @@ import (
"time"

"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -105,7 +106,7 @@ func (c *Client) BeginInitParams() bool {
}

// InitParam initializes the parameter on parameter servers.
func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error {
func (c *Client) InitParam(paramWithConfigs pserver.ParameterWithConfig) error {
return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil)
}

Expand All @@ -123,13 +124,13 @@ func (c *Client) FinishInitParams() error {

// SendGrads sends gradients to parameter servers for updating
// parameters.
func (c *Client) SendGrads(grads []Gradient) error {
func (c *Client) SendGrads(grads []pserver.Gradient) error {
if len(grads) == 0 {
return errors.New("no gradient received")
}
errCh := make(chan error, len(grads))
for _, g := range grads {
go func(g Gradient) {
go func(g pserver.Gradient) {
err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil)
errCh <- err
}(g)
Expand All @@ -151,7 +152,7 @@ func (c *Client) SendGrads(grads []Gradient) error {

type result struct {
idx int
param Parameter
param pserver.Parameter
err error
}

Expand All @@ -170,12 +171,12 @@ func (r results) Swap(i int, j int) {
}

// GetParams gets parameters from parameter servers.
func (c *Client) GetParams(names []string) ([]Parameter, error) {
func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
rCh := make(chan result, len(names))

for idx, name := range names {
go func(name string, idx int) {
var parameter Parameter
var parameter pserver.Parameter
err := c.pservers[c.partition(name)].Call("Service.GetParam", name, &parameter)
rCh <- result{idx: idx, param: parameter, err: err}
}(name, idx)
Expand All @@ -196,7 +197,7 @@ func (c *Client) GetParams(names []string) ([]Parameter, error) {
}
sort.Sort(rs)

ps := make([]Parameter, len(rs))
ps := make([]pserver.Parameter, len(rs))
for i := range rs {
ps[i] = rs[i].param
}
Expand Down
77 changes: 63 additions & 14 deletions go/pserver/client_test.go → go/pserver/client/client_test.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
package pserver_test
package client_test

import (
"context"
"io/ioutil"
"net"
"net/http"
"net/rpc"
"strconv"
"strings"
"testing"
"time"

"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
)

const numPserver = 10
const (
numPserver = 10
etcdEndpoints = "127.0.0.1:2379"
timeout = 2 * time.Second
)

var port [numPserver]int
var pserverClientPorts [numPserver]int

func init() {
// initClient starts numPserver pserver services on free local ports and returns those ports in an array.
func initClient() [numPserver]int {
var ports [numPserver]int
for i := 0; i < numPserver; i++ {
l, err := net.Listen("tcp", ":0")
if err != nil {
Expand All @@ -28,7 +39,7 @@ func init() {
if err != nil {
panic(err)
}
port[i] = p
ports[i] = p

go func(l net.Listener) {
s, err := pserver.NewService(0)
Expand All @@ -49,6 +60,31 @@ func init() {
}
}(l)
}
return ports
}

// initNativeClient stores the ports returned by initClient in the
// package-level pserverClientPorts array, which TestNativeClient reads to
// build its static server list.
func initNativeClient() {
pserverClientPorts = initClient()
}

// initEtcdClient prepares etcd for the etcd-backed client test. It removes
// earlier registrations, publishes numPserver under pserver.PsDesired, starts
// the local pservers through initClient, and writes each pserver's address
// under pserver.PsPath+<index>.
//
// All etcd calls share one context bounded by timeout. Any etcd failure
// panics (after logging) instead of being ignored, matching how initClient
// treats setup errors; continuing with a nil client or a half-populated
// registry would only produce a confusing failure later in the test.
func initEtcdClient() {
	cli, err := clientv3.New(clientv3.Config{
		Endpoints:   []string{etcdEndpoints},
		DialTimeout: time.Second,
	})
	if err != nil {
		log.Panicf("connect to etcd at %s: %v", etcdEndpoints, err)
	}
	defer func() {
		if cerr := cli.Close(); cerr != nil {
			log.Errorf("close etcd client: %v", cerr)
		}
	}()

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// must aborts the setup when an etcd operation on key fails.
	must := func(op, key string, err error) {
		if err != nil {
			log.Panicf("etcd %s %s: %v", op, key, err)
		}
	}

	_, err = cli.Delete(ctx, pserver.PsDesired)
	must("delete", pserver.PsDesired, err)

	// The registrations are written below as PsPath+<index>, so a plain
	// Delete of PsPath would leave stale entries behind; delete by prefix.
	// NOTE(review): assumes PsPath is not a prefix of unrelated keys — confirm.
	_, err = cli.Delete(ctx, pserver.PsPath, clientv3.WithPrefix())
	must("delete prefix", pserver.PsPath, err)

	_, err = cli.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver))
	must("put", pserver.PsDesired, err)

	ports := initClient()
	for i, p := range ports {
		key := pserver.PsPath + strconv.Itoa(i)
		_, err = cli.Put(ctx, key, ":"+strconv.Itoa(p))
		must("put", key, err)
	}
}

type selector bool
Expand All @@ -57,25 +93,20 @@ func (s selector) Select() bool {
return bool(s)
}

type lister []pserver.Server
type lister []client.Server

func (l lister) List() []pserver.Server {
func (l lister) List() []client.Server {
return l
}

func TestClientFull(t *testing.T) {
servers := make([]pserver.Server, numPserver)
for i := 0; i < numPserver; i++ {
servers[i] = pserver.Server{Index: i, Addr: ":" + strconv.Itoa(port[i])}
}
c := pserver.NewClient(lister(servers), len(servers), selector(true))
func ClientTest(t *testing.T, c *client.Client) {
selected := c.BeginInitParams()
if !selected {
t.Fatal("should be selected.")
}

const numParameter = 100
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb")
config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
Expand Down Expand Up @@ -129,3 +160,21 @@ func TestClientFull(t *testing.T) {
}
}
}

// TestNativeClient runs the shared ClientTest scenario against a client that
// is given a fixed list of the locally started pservers.
func TestNativeClient(t *testing.T) {
	initNativeClient()

	// One entry per started pserver, addressed by its local port.
	servers := make([]client.Server, 0, len(pserverClientPorts))
	for idx, p := range pserverClientPorts {
		servers = append(servers, client.Server{
			Index: idx,
			Addr:  ":" + strconv.Itoa(p),
		})
	}

	nativeClient := client.NewClient(lister(servers), len(servers), selector(true))
	ClientTest(t, nativeClient)
}

// EtcdClient runs the shared ClientTest scenario against a client that
// discovers its pservers through etcd.
//
// TODO: temporarily disabled because it depends on a running etcd. The name
// lacks the Test prefix, so `go test` does not pick it up.
func EtcdClient(t *testing.T) {
	initEtcdClient()

	etcdLister := client.NewEtcd(etcdEndpoints)
	desired := etcdLister.Desired()
	etcdBacked := client.NewClient(etcdLister, desired, selector(true))
	ClientTest(t, etcdBacked)
}
Loading