-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Implement master client for reading training tasks #2468
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
fa5c3f1
7b9080e
4970484
e730bc5
094106a
6cd1441
8742441
4b6243c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
package main | ||
|
||
/* | ||
#include <stdlib.h> | ||
#include <string.h> | ||
#include <stdio.h> | ||
|
||
#define PADDLE_MASTER_OK 0 | ||
#define PADDLE_MASTER_ERROR -1 | ||
|
||
typedef int paddle_master_client; | ||
*/ | ||
import "C" | ||
|
||
import ( | ||
"log" | ||
"sync" | ||
"unsafe" | ||
|
||
"github.com/PaddlePaddle/Paddle/go/master" | ||
) | ||
|
||
// nullPtr is a nil unsafe.Pointer, converted to typed C pointers when
// a NULL value has to be handed back to the C caller.
var nullPtr = unsafe.Pointer(uintptr(0)) | ||
// mu is locked by add, get and remove around every access to
// handleMap and curHandle.
var mu sync.Mutex | ||
// handleMap maps the integer handles given to C callers to the Go
// clients they stand for.
var handleMap = make(map[C.paddle_master_client]*master.Client) | ||
// curHandle is the next handle value that add will hand out.
var curHandle C.paddle_master_client | ||
|
||
func add(c *master.Client) C.paddle_master_client { | ||
mu.Lock() | ||
defer mu.Unlock() | ||
client := curHandle | ||
curHandle++ | ||
handleMap[client] = c | ||
return client | ||
} | ||
|
||
func get(client C.paddle_master_client) *master.Client { | ||
mu.Lock() | ||
defer mu.Unlock() | ||
return handleMap[client] | ||
} | ||
|
||
func remove(client C.paddle_master_client) *master.Client { | ||
mu.Lock() | ||
defer mu.Unlock() | ||
h := handleMap[client] | ||
delete(handleMap, client) | ||
return h | ||
} | ||
|
||
// addresser is a string holding a fixed master address; it is the
// value handed to master.NewClient by paddle_new_master_client.
type addresser string

// Address returns the address string wrapped by a.
func (a addresser) Address() string {
	addr := string(a)
	return addr
}
|
||
//export paddle_new_master_client | ||
func paddle_new_master_client(addr *C.char) C.paddle_master_client { | ||
a := C.GoString(addr) | ||
c := master.NewClient(addresser(a)) | ||
return add(c) | ||
} | ||
|
||
//export paddle_release_master_client | ||
func paddle_release_master_client(client C.paddle_master_client) { | ||
remove(client) | ||
} | ||
|
||
//export paddle_set_dataset | ||
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int { | ||
c := get(client) | ||
var paths []string | ||
for i := 0; i < int(size); i++ { | ||
ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path))) | ||
str := C.GoString(*ptr) | ||
paths = append(paths, str) | ||
} | ||
err := c.SetDataset(paths) | ||
if err != nil { | ||
log.Println(err) | ||
return C.PADDLE_MASTER_ERROR | ||
} | ||
|
||
return C.PADDLE_MASTER_OK | ||
} | ||
|
||
//export paddle_next_record | ||
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { | ||
c := get(client) | ||
r := c.NextRecord() | ||
if len(r) == 0 { | ||
*record = (*C.uchar)(nullPtr) | ||
return 0 | ||
} | ||
|
||
size := C.size_t(len(r)) | ||
*record = (*C.uchar)(C.malloc(size)) | ||
C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&r[0]), size) | ||
return C.int(size) | ||
} | ||
|
||
// mem_free releases the C-allocated memory at p by calling C.free.
// paddle_next_record hands out such memory through its record output
// parameter.
//
//export mem_free | ||
func mem_free(p unsafe.Pointer) { | ||
// "free" may be a better name for this function, but doing so | ||
// will cause calling any function of this library from Python | ||
// ctypes hanging. | ||
C.free(p) | ||
} | ||
|
||
// main is empty: Go requires package main to declare it, while the
// functionality of this file is reached through the //export'ed
// functions (presumably built as a C library — confirm against the
// build scripts).
func main() {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,9 +2,11 @@ package master | |
|
||
import ( | ||
"log" | ||
"os" | ||
"time" | ||
|
||
"github.com/PaddlePaddle/Paddle/go/connection" | ||
"github.com/PaddlePaddle/recordio" | ||
) | ||
|
||
// Addresser provide the address of the master server. | ||
|
@@ -15,16 +17,58 @@ type Addresser interface { | |
// Client is the client of the master server. | ||
//
// NewClient starts the goroutines that keep the connection pointed at
// the master and that fill ch with records; NextRecord receives from
// ch.
type Client struct { | ||
// conn is the connection on which the RPC calls to the master
// service are made.
conn *connection.Conn | ||
// ch carries records from getRecords to NextRecord.
ch chan []byte | ||
} | ||
|
||
// NewClient creates a new Client. | ||
func NewClient(addr Addresser) *Client { | ||
c := &Client{} | ||
c.conn = connection.New() | ||
c.ch = make(chan []byte) | ||
go c.monitorMaster(addr) | ||
go c.getRecords() | ||
return c | ||
} | ||
|
||
func (c *Client) getRecords() { | ||
for { | ||
t, err := c.getTask() | ||
if err != nil { | ||
// TODO(helin): wait before move on with next | ||
// getTask call. | ||
log.Println(err) | ||
continue | ||
} | ||
|
||
for _, chunk := range t.Chunks { | ||
f, err := os.Open(chunk.Path) | ||
if err != nil { | ||
log.Println(err) | ||
continue | ||
} | ||
|
||
s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1) | ||
for s.Scan() { | ||
c.ch <- s.Record() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line will block util some reader read the record from this channel. How to deal with "batches" like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
} | ||
|
||
if s.Err() != nil { | ||
log.Println(err, chunk.Path) | ||
} | ||
|
||
err = f.Close() | ||
if err != nil { | ||
log.Println(err) | ||
} | ||
} | ||
|
||
// We treat a task as finished whenever the last data | ||
// instance of the task is read. This is not exactly | ||
// correct, but a reasonable approximation. | ||
c.taskFinished(t.ID) | ||
} | ||
} | ||
|
||
func (c *Client) monitorMaster(addr Addresser) { | ||
lastMaster := "" | ||
monitor := func() { | ||
|
@@ -69,14 +113,22 @@ func (c *Client) SetDataset(globPaths []string) error { | |
return c.conn.Call("Service.SetDataset", globPaths, nil) | ||
} | ||
|
||
// GetTask gets a new task from the master server. | ||
func (c *Client) GetTask() (Task, error) { | ||
// getTask gets a new task from the master server. | ||
func (c *Client) getTask() (Task, error) { | ||
var t Task | ||
err := c.conn.Call("Service.GetTask", 0, &t) | ||
return t, err | ||
} | ||
|
||
// TaskFinished tells the master server a task is finished. | ||
func (c *Client) TaskFinished(taskID int) error { | ||
func (c *Client) taskFinished(taskID int) error { | ||
return c.conn.Call("Service.TaskFinished", taskID, nil) | ||
} | ||
|
||
// NextRecord returns next record in the dataset. | ||
// | ||
// NextRecord will block until the next record is available. It is | ||
// thread-safe. | ||
func (c *Client) NextRecord() []byte { | ||
return <-c.ch | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
package master | ||
|
||
import ( | ||
"fmt" | ||
"net" | ||
"net/http" | ||
"net/rpc" | ||
"os" | ||
"strconv" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
log "github.com/sirupsen/logrus" | ||
|
||
"github.com/PaddlePaddle/Paddle/go/connection" | ||
"github.com/PaddlePaddle/recordio" | ||
) | ||
|
||
const ( | ||
// totalTask is the number of tasks TestGetFinishTask fetches per
// pass before it expects getTask to fail.
totalTask = 20 | ||
// chunkPerTask is passed to NewService; the test writes
// totalTask*chunkPerTask chunks to its dataset file.
chunkPerTask = 10 | ||
) | ||
|
||
// init sets the logrus level to ErrorLevel for this test binary.
func init() { | ||
log.SetLevel(log.ErrorLevel) | ||
} | ||
|
||
// TestAddresser is a string holding a fixed address; the test passes
// it to monitorMaster.
type TestAddresser string

// Address returns the address string wrapped by a.
func (a TestAddresser) Address() string {
	addr := string(a)
	return addr
}
|
||
func TestGetFinishTask(t *testing.T) { | ||
const path = "/tmp/master_client_test_0" | ||
|
||
l, err := net.Listen("tcp", ":0") | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
ss := strings.Split(l.Addr().String(), ":") | ||
p, err := strconv.Atoi(ss[len(ss)-1]) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
go func(l net.Listener) { | ||
s := NewService(chunkPerTask, time.Second, 1) | ||
server := rpc.NewServer() | ||
err := server.Register(s) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
mux := http.NewServeMux() | ||
mux.Handle(rpc.DefaultRPCPath, server) | ||
err = http.Serve(l, mux) | ||
if err != nil { | ||
panic(err) | ||
} | ||
}(l) | ||
|
||
f, err := os.Create(path) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
for i := 0; i < totalTask*chunkPerTask; i++ { | ||
w := recordio.NewWriter(f, -1, -1) | ||
w.Write(nil) | ||
// call Close to force RecordIO writing a chunk. | ||
w.Close() | ||
} | ||
f.Close() | ||
|
||
// Manually intialize client to avoid calling c.getRecords() | ||
c := &Client{} | ||
c.conn = connection.New() | ||
go c.monitorMaster(TestAddresser(fmt.Sprintf(":%d", p))) | ||
c.SetDataset([]string{path}) | ||
|
||
checkOnePass := func(i int) { | ||
var tasks []Task | ||
for idx := 0; idx < totalTask; idx++ { | ||
task, err := c.getTask() | ||
if err != nil { | ||
t.Fatalf("Error: %v, pass: %d\n", err, i) | ||
} | ||
tasks = append(tasks, task) | ||
} | ||
|
||
_, err = c.getTask() | ||
if err == nil { | ||
t.Fatalf("Should get error, pass: %d\n", i) | ||
} | ||
|
||
err = c.taskFinished(tasks[0].ID) | ||
if err != nil { | ||
t.Fatalf("Error: %v, pass: %d\n", err, i) | ||
} | ||
tasks = tasks[1:] | ||
task, err := c.getTask() | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
tasks = append(tasks, task) | ||
|
||
for _, task := range tasks { | ||
err = c.taskFinished(task.ID) | ||
if err != nil { | ||
t.Fatalf("Error: %v, pass: %d\n", err, i) | ||
} | ||
} | ||
} | ||
|
||
for i := 0; i < 10; i++ { | ||
checkOnePass(i) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we all use `glog` instead of `log`, so we can do `glog.Errorf` or `glog.Infof` to separate info logs from error logs? Use `glog.V(10).Infof` to log more detailed debug logs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea! glog does not work with flag package other than official "flag", can we use https://github.com/sirupsen/logrus instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice lib! It also seems able to send logs to "logstash" in JSON format, which is very useful when running jobs on Kubernetes: we can collect and search the logs using "EFK".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@typhoonzero Thanks for mentioning logstash and EFK, good to know about them!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.