Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions go/master/c/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package main

/*
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1

typedef int paddle_master_client;
*/
import "C"

import (
"log"
"sync"
"unsafe"

"github.com/PaddlePaddle/Paddle/go/master"
)

// Package-level state backing the integer handles given out to C callers.
var (
	// nullPtr is the NULL pointer value stored into C out-parameters.
	nullPtr = unsafe.Pointer(uintptr(0))

	// mu guards handleMap and curHandle.
	mu sync.Mutex

	// handleMap resolves a handle to the Go client it stands for.
	handleMap = make(map[C.paddle_master_client]*master.Client)

	// curHandle is the handle value the next call to add will hand out.
	curHandle C.paddle_master_client
)

// add stores c in handleMap under the next unused handle and returns
// that handle. The handle counter is advanced under mu.
func add(c *master.Client) C.paddle_master_client {
	mu.Lock()
	defer mu.Unlock()

	handle := curHandle
	handleMap[handle] = c
	curHandle = handle + 1
	return handle
}

// get looks up the Go client registered under the given handle. It
// returns nil when handleMap has no entry for the handle.
func get(client C.paddle_master_client) *master.Client {
	mu.Lock()
	c := handleMap[client]
	mu.Unlock()
	return c
}

// remove drops the handle from handleMap and returns the client that
// was registered under it, or nil when the handle was not present.
func remove(client C.paddle_master_client) *master.Client {
	mu.Lock()
	defer mu.Unlock()

	c, ok := handleMap[client]
	if ok {
		delete(handleMap, client)
	}
	return c
}

// addresser is a string that reports itself as an address; it is what
// paddle_new_master_client passes to master.NewClient.
type addresser string

// Address returns the underlying string unchanged.
func (a addresser) Address() string {
	addr := string(a)
	return addr
}

// paddle_new_master_client creates a master client for the address
// given as a C string and returns the handle under which the client
// was registered.
//
//export paddle_new_master_client
func paddle_new_master_client(addr *C.char) C.paddle_master_client {
	return add(master.NewClient(addresser(C.GoString(addr))))
}

// paddle_release_master_client unregisters the handle. It only removes
// the handle-to-client mapping; nothing else is done with the client
// here.
//
//export paddle_release_master_client
func paddle_release_master_client(client C.paddle_master_client) {
	_ = remove(client)
}

// paddle_set_dataset sends the dataset paths to the master server via
// the client's SetDataset call.
//
// client is a handle returned by paddle_new_master_client, path points
// to the first element of a C array of NUL-terminated strings, and
// size is the number of elements in that array.
//
// It returns PADDLE_MASTER_OK on success. It returns
// PADDLE_MASTER_ERROR when the handle is not registered, when path is
// NULL while size is positive, or when SetDataset reports an error.
//
//export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
	c := get(client)
	if c == nil {
		// An unknown or already released handle resolves to a nil
		// client; using it would crash the calling process.
		log.Println("paddle_set_dataset: invalid client handle", client)
		return C.PADDLE_MASTER_ERROR
	}

	n := int(size)
	if n > 0 && path == nil {
		log.Println("paddle_set_dataset: path is NULL but size is", n)
		return C.PADDLE_MASTER_ERROR
	}

	var paths []string
	if n > 0 {
		paths = make([]string, 0, n)
	}
	for i := 0; i < n; i++ {
		// Walk the C array by hand: element i lives i pointer widths
		// past the base address.
		ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path)))
		paths = append(paths, C.GoString(*ptr))
	}

	if err := c.SetDataset(paths); err != nil {
		log.Println(err)
		return C.PADDLE_MASTER_ERROR
	}
	return C.PADDLE_MASTER_OK
}

// paddle_next_record fetches the next record from the client and
// copies it into a freshly malloc'ed C buffer.
//
// On success *record points to the buffer and the record length in
// bytes is returned; the caller owns the buffer and must release it
// with mem_free. When the record is empty, *record is set to NULL and
// 0 is returned. When record is NULL or the handle is not registered,
// PADDLE_MASTER_ERROR is returned (and *record, if available, is set
// to NULL).
//
//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
	if record == nil {
		// There is nowhere to store the result; writing through a
		// NULL out-pointer would crash the calling process.
		log.Println("paddle_next_record: record is NULL")
		return C.PADDLE_MASTER_ERROR
	}

	c := get(client)
	if c == nil {
		// An unknown or already released handle resolves to a nil
		// client; calling NextRecord on it would crash.
		log.Println("paddle_next_record: invalid client handle", client)
		*record = (*C.uchar)(nullPtr)
		return C.PADDLE_MASTER_ERROR
	}

	r := c.NextRecord()
	if len(r) == 0 {
		*record = (*C.uchar)(nullPtr)
		return 0
	}

	// Copy into C memory so the buffer stays valid after this call
	// returns, independently of the Go slice.
	size := C.size_t(len(r))
	*record = (*C.uchar)(C.malloc(size))
	C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&r[0]), size)
	return C.int(size)
}

// mem_free releases the memory at p with C.free. It is exported so
// that callers of this library can free buffers allocated with
// C.malloc here, such as the record buffer produced by
// paddle_next_record.
//
//export mem_free
func mem_free(p unsafe.Pointer) {
// "free" may be a better name for this function, but using that
// name makes every call into this library from Python ctypes
// hang.
C.free(p)
}

// main is intentionally empty: package main has to declare it, but
// this file exists to provide the exported functions above to C
// callers. NOTE(review): presumably built as a C library (e.g.
// -buildmode=c-shared) -- confirm against the build scripts.
func main() {}
58 changes: 55 additions & 3 deletions go/master/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package master

import (
"log"
"os"
"time"

"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
)

// Addresser provides the address of the master server.
Expand All @@ -15,16 +17,58 @@ type Addresser interface {
// Client is the client of the master server.
type Client struct {
// conn is the connection on which the RPC calls to the master
// server are made.
conn *connection.Conn
// ch carries records from getRecords to NextRecord.
ch chan []byte
}

// NewClient builds a Client with a fresh connection and an unbuffered
// record channel, then starts two goroutines for it: monitorMaster,
// which is given addr, and getRecords.
func NewClient(addr Addresser) *Client {
	c := &Client{
		conn: connection.New(),
		ch:   make(chan []byte),
	}

	go c.monitorMaster(addr)
	go c.getRecords()

	return c
}

func (c *Client) getRecords() {
for {
t, err := c.getTask()
if err != nil {
// TODO(helin): wait before move on with next
// getTask call.
log.Println(err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we all use glog instead of log? Then we can use glog.Errorf or glog.Infof to separate info logs from error logs, and glog.V(10).Infof for more detailed debug logs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! glog does not work with flag package other than official "flag", can we use https://github.com/sirupsen/logrus instead?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice lib! It also seems able to send logs to "logstash" in JSON format, which is very useful when running jobs on Kubernetes: we can collect and search logs using "EFK".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@typhoonzero Thanks for mentioning logstash and EFK, good to know about them!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

continue
}

for _, chunk := range t.Chunks {
f, err := os.Open(chunk.Path)
if err != nil {
log.Println(err)
continue
}

s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1)
for s.Scan() {
c.ch <- s.Record()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line will block until some reader reads the record from this channel. How do we deal with "batches" like paddle.dataset.batch(some_reader(), 32)? Will adding a channel buffer like c.ch = make(chan []byte, batch_size) increase the efficiency?

Copy link
Contributor Author

@helinwang helinwang Jun 15, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use paddle.reader.buffered for this purpose. Still I think that's a good point. Will add.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}

// Log the scanner's own error. The surrounding err variable still
// holds the nil result of os.Open at this point, so printing it
// would hide the actual scan failure.
if scanErr := s.Err(); scanErr != nil {
	log.Println(scanErr, chunk.Path)
}

err = f.Close()
if err != nil {
log.Println(err)
}
}

// We treat a task as finished whenever the last data
// instance of the task is read. This is not exactly
// correct, but a reasonable approximation.
c.taskFinished(t.ID)
}
}

func (c *Client) monitorMaster(addr Addresser) {
lastMaster := ""
monitor := func() {
Expand Down Expand Up @@ -69,14 +113,22 @@ func (c *Client) SetDataset(globPaths []string) error {
return c.conn.Call("Service.SetDataset", globPaths, nil)
}

// GetTask gets a new task from the master server.
func (c *Client) GetTask() (Task, error) {
// getTask asks the master server for a task through the
// "Service.GetTask" RPC and returns the decoded task together with the
// error of that call.
func (c *Client) getTask() (Task, error) {
	var task Task
	if err := c.conn.Call("Service.GetTask", 0, &task); err != nil {
		return task, err
	}
	return task, nil
}

// TaskFinished tells the master server a task is finished.
func (c *Client) TaskFinished(taskID int) error {
// taskFinished reports the task with the given ID as finished through
// the "Service.TaskFinished" RPC and returns the error of that call.
func (c *Client) taskFinished(taskID int) error {
	err := c.conn.Call("Service.TaskFinished", taskID, nil)
	return err
}

// NextRecord returns the next record in the dataset.
//
// The call blocks until a record can be received from the client's
// record channel. It may be called from multiple goroutines.
func (c *Client) NextRecord() []byte {
	record := <-c.ch
	return record
}
121 changes: 121 additions & 0 deletions go/master/client_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package master

import (
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"strconv"
"strings"
"testing"
"time"

log "github.com/sirupsen/logrus"

"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
)

const (
// totalTask is the number of tasks the test fetches in each pass
// before it expects getTask to fail.
totalTask = 20
// chunkPerTask is the chunks-per-task value passed to NewService;
// the test writes totalTask*chunkPerTask chunks.
chunkPerTask = 10
)

// init sets the logrus log level to ErrorLevel for this test binary.
func init() {
log.SetLevel(log.ErrorLevel)
}

// TestAddresser is a string-backed address provider used by the tests
// to point a client at the in-process master service.
type TestAddresser string

// Address returns the underlying string unchanged.
func (a TestAddresser) Address() string {
	addr := string(a)
	return addr
}

// TestGetFinishTask starts an in-process master service on an
// ephemeral port, writes a RecordIO file with totalTask*chunkPerTask
// chunks, and then checks the getTask/taskFinished protocol over
// several passes: all totalTask tasks can be fetched, one more getTask
// fails, and finishing a task makes another task available.
func TestGetFinishTask(t *testing.T) {
	const path = "/tmp/master_client_test_0"

	l, err := net.Listen("tcp", ":0")
	if err != nil {
		t.Fatal(err)
	}

	// The port is the last colon-separated field of the listener address.
	ss := strings.Split(l.Addr().String(), ":")
	p, err := strconv.Atoi(ss[len(ss)-1])
	if err != nil {
		t.Fatal(err)
	}

	go func(l net.Listener) {
		// This goroutine panics instead of calling t.Fatal because
		// FailNow may only be called from the goroutine running the test.
		s := NewService(chunkPerTask, time.Second, 1)
		server := rpc.NewServer()
		err := server.Register(s)
		if err != nil {
			panic(err)
		}

		mux := http.NewServeMux()
		mux.Handle(rpc.DefaultRPCPath, server)
		err = http.Serve(l, mux)
		if err != nil {
			panic(err)
		}
	}(l)

	f, err := os.Create(path)
	if err != nil {
		t.Fatal(err)
	}
	// Do not leave the fixture behind; deferred calls also run when the
	// test exits through t.Fatal.
	defer func() {
		if err := os.Remove(path); err != nil {
			t.Log(err)
		}
	}()

	for i := 0; i < totalTask*chunkPerTask; i++ {
		w := recordio.NewWriter(f, -1, -1)
		if _, err := w.Write(nil); err != nil {
			t.Fatal(err)
		}
		// call Close to force RecordIO writing a chunk.
		if err := w.Close(); err != nil {
			t.Fatal(err)
		}
	}
	if err := f.Close(); err != nil {
		t.Fatal(err)
	}

	// Manually initialize client to avoid calling c.getRecords()
	c := &Client{}
	c.conn = connection.New()
	go c.monitorMaster(TestAddresser(fmt.Sprintf(":%d", p)))
	if err := c.SetDataset([]string{path}); err != nil {
		t.Fatal(err)
	}

	checkOnePass := func(i int) {
		tasks := make([]Task, 0, totalTask)
		for idx := 0; idx < totalTask; idx++ {
			task, err := c.getTask()
			if err != nil {
				t.Fatalf("Error: %v, pass: %d\n", err, i)
			}
			tasks = append(tasks, task)
		}

		// Every task has been fetched, so one more request is
		// expected to fail.
		if _, err := c.getTask(); err == nil {
			t.Fatalf("Should get error, pass: %d\n", i)
		}

		// Finishing one task is expected to make a task available again.
		if err := c.taskFinished(tasks[0].ID); err != nil {
			t.Fatalf("Error: %v, pass: %d\n", err, i)
		}
		tasks = tasks[1:]
		task, err := c.getTask()
		if err != nil {
			t.Fatal(err)
		}
		tasks = append(tasks, task)

		for _, task := range tasks {
			if err := c.taskFinished(task.ID); err != nil {
				t.Fatalf("Error: %v, pass: %d\n", err, i)
			}
		}
	}

	for i := 0; i < 10; i++ {
		checkOnePass(i)
	}
}
Loading