Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ endif(WITH_GPU)

add_subdirectory(proto)
add_subdirectory(paddle)
add_subdirectory(go/master/c)
add_subdirectory(python)
add_subdirectory(go/pserver/cclient)

Expand Down
12 changes: 4 additions & 8 deletions go/cmake/golang.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,23 @@ function(GO_LIBRARY NAME BUILD_TYPE)

# automatically get all dependencies specified in the source code
# for given target.
add_custom_target(goGet env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ${rel}/...)
add_custom_target(${NAME}_goGet env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ${rel}/...)

# make a symlink that references Paddle inside $GOPATH, so go get
# will use the local changes in Paddle rather than checkout Paddle
# in github.
add_custom_target(copyPaddle
add_custom_target(${NAME}_copyPaddle
COMMAND rm -rf ${PADDLE_IN_GOPATH}/Paddle
COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}/Paddle)
add_dependencies(goGet copyPaddle)
add_dependencies(${NAME}_goGet ${NAME}_copyPaddle)

add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
-gcflags=-shared -asmflags=-shared -installsuffix=_shared -a
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
${CMAKE_GO_FLAGS} ${GO_SOURCE}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})

add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN})
add_dependencies(${NAME} goGet)
add_dependencies(${NAME} ${NAME}_goGet)

if(NOT BUILD_TYPE STREQUAL "STATIC")
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME} DESTINATION bin)
endif()
endfunction(GO_LIBRARY)
5 changes: 3 additions & 2 deletions go/connection/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package connection

import (
"errors"
"log"
"net/rpc"
"sync"

log "github.com/sirupsen/logrus"
)

// TODO(helin): add TCP re-connect logic
Expand Down Expand Up @@ -65,7 +66,7 @@ func (c *Conn) Connect(addr string) error {
} else {
err := client.Close()
if err != nil {
log.Println(err)
log.Errorln(err)
}

return errors.New("client already set from a concurrent goroutine")
Expand Down
21 changes: 21 additions & 0 deletions go/master/c/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
cmake_minimum_required(VERSION 3.0)

get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake")

project(cxx_go C Go)

include(golang)
include(flags)

set(MASTER_LIB_NAME "paddle_master")
go_library(${MASTER_LIB_NAME} SHARED)

if(PROJ_ROOT)
add_custom_command(OUTPUT ${PROJ_ROOT}/python/paddle/v2/master/lib${MASTER_LIB_NAME}.so
COMMAND rm ${CMAKE_CURRENT_BINARY_DIR}/lib${MASTER_LIB_NAME}.h
COMMAND cp ${CMAKE_CURRENT_BINARY_DIR}/lib${MASTER_LIB_NAME}.so ${PROJ_ROOT}/python/paddle/v2/master/
DEPENDS ${MASTER_LIB_NAME})
add_custom_target(paddle_master_shared ALL DEPENDS ${PROJ_ROOT}/python/paddle/v2/master/lib${MASTER_LIB_NAME}.so)
endif(PROJ_ROOT)
110 changes: 110 additions & 0 deletions go/master/c/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package main

/*
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1

typedef int paddle_master_client;
*/
import "C"

import (
"sync"
"unsafe"

"github.com/PaddlePaddle/Paddle/go/master"
log "github.com/sirupsen/logrus"
)

var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex
var handleMap = make(map[C.paddle_master_client]*master.Client)
var curHandle C.paddle_master_client

func add(c *master.Client) C.paddle_master_client {
mu.Lock()
defer mu.Unlock()
client := curHandle
curHandle++
handleMap[client] = c
return client
}

func get(client C.paddle_master_client) *master.Client {
mu.Lock()
defer mu.Unlock()
return handleMap[client]
}

func remove(client C.paddle_master_client) *master.Client {
mu.Lock()
defer mu.Unlock()
h := handleMap[client]
delete(handleMap, client)
return h
}

type addresser string

func (a addresser) Address() string {
return string(a)
}

//export paddle_new_master_client
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
a := C.GoString(addr)
c := master.NewClient(addresser(a), bufSize)
return add(c)
}

//export paddle_release_master_client
func paddle_release_master_client(client C.paddle_master_client) {
remove(client)
}

//export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
c := get(client)
var paths []string
for i := 0; i < int(size); i++ {
ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path)))
str := C.GoString(*ptr)
paths = append(paths, str)
}
err := c.SetDataset(paths)
if err != nil {
log.Errorln(err)
return C.PADDLE_MASTER_ERROR
}

return C.PADDLE_MASTER_OK
}

//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client)
r := c.NextRecord()
if len(r) == 0 {
*record = (*C.uchar)(nullPtr)
return 0
}

size := C.size_t(len(r))
*record = (*C.uchar)(C.malloc(size))
C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&r[0]), size)
return C.int(size)
}

//export mem_free
func mem_free(p unsafe.Pointer) {
// "free" may be a better name for this function, but doing so
// will cause calling any function of this library from Python
// ctypes hanging.
C.free(p)
}

func main() {}
69 changes: 62 additions & 7 deletions go/master/client.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package master

import (
"log"
"os"
"time"

"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
log "github.com/sirupsen/logrus"
)

// Addresser provide the address of the master server.
Expand All @@ -15,16 +17,61 @@ type Addresser interface {
// Client is the client of the master server.
type Client struct {
conn *connection.Conn
ch chan []byte
}

// NewClient creates a new Client.
func NewClient(addr Addresser) *Client {
//
// bufSize is the record buffer size. NextRecord will read from this
// buffer.
func NewClient(addr Addresser, bufSize int) *Client {
c := &Client{}
c.conn = connection.New()
c.ch = make(chan []byte, bufSize)
go c.monitorMaster(addr)
go c.getRecords()
return c
}

func (c *Client) getRecords() {
for {
t, err := c.getTask()
if err != nil {
// TODO(helin): wait before move on with next
// getTask call.
log.Errorln(err)
continue
}

for _, chunk := range t.Chunks {
f, err := os.Open(chunk.Path)
if err != nil {
log.Errorln(err)
continue
}

s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1)
for s.Scan() {
c.ch <- s.Record()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line will block util some reader read the record from this channel. How to deal with "batches" like paddle.dataset.batch(some_reader(), 32)? Will adding channel buffers like c.ch = make(chan []byte, batch_size) increase the efficiency?

Copy link
Contributor Author

@helinwang helinwang Jun 15, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use paddle.reader.buffered for this purpose. Still I think that's a good point. Will add.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}

if s.Err() != nil {
log.Errorln(err, chunk.Path)
}

err = f.Close()
if err != nil {
log.Errorln(err)
}
}

// We treat a task as finished whenever the last data
// instance of the task is read. This is not exactly
// correct, but a reasonable approximation.
c.taskFinished(t.ID)
}
}

func (c *Client) monitorMaster(addr Addresser) {
lastMaster := ""
monitor := func() {
Expand All @@ -35,12 +82,12 @@ func (c *Client) monitorMaster(addr Addresser) {
if curMaster == "" {
err := c.conn.Close()
if err != nil {
log.Println(err)
log.Errorln(err)
}
} else {
err := c.conn.Connect(curMaster)
if err != nil {
log.Println(err)
log.Errorln(err)

// connect to addr failed, set
// to last known addr in order
Expand Down Expand Up @@ -69,14 +116,22 @@ func (c *Client) SetDataset(globPaths []string) error {
return c.conn.Call("Service.SetDataset", globPaths, nil)
}

// GetTask gets a new task from the master server.
func (c *Client) GetTask() (Task, error) {
// getTask gets a new task from the master server.
func (c *Client) getTask() (Task, error) {
var t Task
err := c.conn.Call("Service.GetTask", 0, &t)
return t, err
}

// TaskFinished tells the master server a task is finished.
func (c *Client) TaskFinished(taskID int) error {
func (c *Client) taskFinished(taskID int) error {
return c.conn.Call("Service.TaskFinished", taskID, nil)
}

// NextRecord returns next record in the dataset.
//
// NextRecord will block until the next record is available. It is
// thread-safe.
func (c *Client) NextRecord() []byte {
return <-c.ch
}
Loading