doc/design/api.md (87 additions, 0 deletions)
# Design Doc: New Paddle API

To write a Paddle program using the current API, we have to write two Python source files: one defines the data provider, and the other defines the network topology and runs the training loop. This doesn't work well with Notebooks, so we decided to redesign the API. This document describes the basic design concerns.

## Basic Concepts

The API design builds on a few basic concepts of deep learning.

### Model

For deep learning, a model consists of two parts: the topology and the parameters. Currently, the concept *model* in Paddle contains only the topology, while the parameters live in another concept, *gradient machine*. This differs from intuition and makes it difficult to save/load models. In this design, a *model* should keep both the topology and the parameters.
**@helinwang (Contributor), Jan 6, 2017:**

Here is how TensorFlow does it, just for reference.

**@wangkuiyi (Author), Jan 8, 2017:**

I don't quite understand this point: why should the graph (topology) and the weights be kept separate? Given the notion that "a model is the graph plus its parameters (weights)", it seems more natural to put both in a single `class Model`.


Algorithms like GAN require the flexibility to temporarily compose multiple models into one, while keeping each of them usable on its own. We will show later that we don't need a model-composition API; instead, we can compose gradient machines.
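As an illustration of this *model* concept (not part of the proposed API), here is a minimal sketch of a model object that owns both the topology and the parameters; the class name `Model` and its methods are assumptions made for this sketch.

```python
import pickle

# A minimal sketch, assuming a hypothetical Model class that keeps the
# topology and the parameters together; names here are illustrative only.
class Model:
    def __init__(self, topology, parameters=None):
        self.topology = topology              # the network structure
        self.parameters = parameters or {}    # parameter name -> value array

    def save(self, path):
        # Saving topology and parameters together means a model can be
        # reloaded later without reconstructing a separate gradient machine.
        with open(path, "wb") as f:
            pickle.dump({"topology": self.topology,
                         "parameters": self.parameters}, f)

    @staticmethod
    def load(path):
        with open(path, "rb") as f:
            blob = pickle.load(f)
        return Model(blob["topology"], blob["parameters"])
```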

### Gradient Machine and Updater

In order to learn the model, we need to run the error backpropagation algorithm iteratively. In each iteration, we run the forward algorithm on a minibatch of data, which updates the outputs of the layers; then we run the backward algorithm, which computes the gradient of every parameter. The outputs and gradients are not part of the model; they are side effects of the training process and should be maintained by the trainer.

After the invocation of the backward algorithm, we should update the model parameters using the gradients and hyperparameters such as the learning rate. This step might involve communication with the parameter servers in the case of distributed training. This complexity motivates us to separate the trainer into two concepts:

1. *gradient machine*, which computes and maintains layer outputs and gradients, and

1. *updater*, which encapsulates the updating algorithm.

It seems that the *cost function* is a property of the *gradient machine*?
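To make the split concrete, here is a minimal sketch of the two concepts, assuming hypothetical `GradientMachine` and `SimpleUpdater` classes; only the idea that outputs and gradients live in the gradient machine while the updater applies them follows from the text above, the rest is illustrative.

```python
# A minimal sketch of the gradient machine / updater split; the class
# skeletons and attribute names are assumptions for illustration.
class GradientMachine:
    def __init__(self, model):
        self.model = model
        self.outputs = {}      # layer outputs, refreshed on every forward pass
        self.gradients = {}    # parameter gradients, refreshed on every backward pass

    def forward(self, minibatch):
        ...  # compute layer outputs for this minibatch and cache them

    def backward(self):
        ...  # compute parameter gradients and cache them


class SimpleUpdater:
    def __init__(self, learning_rate=0.01):
        self.learning_rate = learning_rate

    def update(self, gm):
        # Apply the gradients held by the gradient machine to its model.
        # A distributed updater would instead talk to the parameter servers here.
        for name, grad in gm.gradients.items():
            gm.model.parameters[name] -= self.learning_rate * grad
```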

### Data

Models are trained with data sets. We hope to provide a set of utility data sets encapsulated in Python packages like `paddle.data.amazon.product_review` and `paddle.data.wikipedia.articles`. A reasonable idea might be that each of these packages provides a `new` function that returns a reader object or a Python iterator, and the *reader* has a `read` method which, once called, returns a minibatch and a flag indicating whether it has reached the end of the data set. For online learning, the flag would always be false. In this way, we don't need two concepts, *pass* and *iteration*; we need only the latter.
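A minimal sketch of such a reader, matching the `for mb, pass_end in rd.read():` loops used in the examples below; the in-memory sample list and the `minibatch_size` constructor argument are assumptions for illustration (a reply further down also suggests `minibatch_size` as a constructor argument).

```python
# Sketch only: a reader whose read() yields (minibatch, pass_end) pairs.
class Reader:
    def __init__(self, samples, minibatch_size=100):
        self.samples = samples
        self.minibatch_size = minibatch_size

    def read(self):
        n = len(self.samples)
        for start in range(0, n, self.minibatch_size):
            mb = self.samples[start:start + self.minibatch_size]
            pass_end = start + self.minibatch_size >= n
            yield mb, pass_end  # pass_end is True only for the last minibatch
```

An online data source would simply never set `pass_end` to `True`, which is why the *pass* concept can be dropped.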

## Examples

### A Simple Network

```python
gm = paddle.gradient_machine.new(model)  # gm uses the default cost function.
```
**Collaborator:**

There is one thing we didn't discuss yet: how does the user describe a neural network?

The most convenient way to implement this is to use a Python function for a network. For example:

```python
@network(input={'x': dense_vector(123), 'label': integer_value(10)})
def sample_network(x, label):
    hidden = fc_layer(input=x, size=100)
    prediction = fc_layer(input=hidden, size=label.size, act=SoftmaxActivation())
    return classification_cost(input=prediction, label=label)

# Create the model.
model = sample_network()
model.randomParams()

model.save/load() ...
```

Another way to define the network topology is to pass the final value to a model creator. For example:

```python
x = data_layer(type=dense_vector(123))
hidden = fc_layer(input=x, size=100)
...
loss = classification_cost(input=prediction, label=label)

model = paddle.ModelCreator(loss)
...
```

It is hard to implement that in Paddle.

**@wangkuiyi (Author), Jan 9, 2017:**

I didn't emphasize this in the main text, but I did in my replies to Teacher Xu and Helin:

1. There is no *model* concept any more; it is equivalent to *network*. So there should be neither a model nor a model-creator concept.
2. Each network is referred to by its output layer. For details, see Add design/api.md #1088 (comment).

Also, I think it's best not to describe the API design in Python, because it is easy to end up depending on Python-specific syntax, such as decorators like `@network`. The API should be describable clearly in "C with objects" as well, so that it can support various client languages.

**Contributor:**

Agreed; decorators cause a great deal of confusion for users.

```python
ud = paddle.updater.new_simple()  # A simple updater doesn't work with parameter servers.
rd = paddle.data.amazon.product_review.new()
mt = paddle.metric.new_acc()

for mb, pass_end in rd.read():
```
**Contributor:**

Maybe we need a way to specify the minibatch size, either in `reader.new(int batch_size)` or in `reader.read(int batch_size)`?

**Author:**

Indeed. The minibatch size could be specified through the `read` function, or when creating the reader, for example:

```python
rd = paddle.data.amazon.product_review.new(minibatch_size=100)
```

```python
    gm.feed(mb)
    ud.update(gm)
```
**@jacquesqiao (Member), Jan 6, 2017:**

It seems we need to add here:

```python
gm.forward()
gm.backward()
gm.update()
```

**@wangkuiyi (Author), Jan 6, 2017:**

Oh, I forgot to make this clear here: `gm.feed` calls `forward` and `backward`. I have added a note about this.

What does `gm.update` mean? Does it update the model parameters? I thought `ud.update` is the one that does that.

**Member:**

Right, it should indeed be `ud.update(gm)` rather than `gm.update`; my mistake.

```python
    mt.append(gm, mb)  # compute and record the accuracy of this minibatch.
    if pass_end:
        log(mt.flash())  # output the aggregated accuracy of this pass and reset mt.
```

In this example, `GradientMachine.feed` is a convenience method that calls `GradientMachine.forward` and then `GradientMachine.backward`.
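A sketch of that convenience, reusing the hypothetical `GradientMachine` skeleton sketched earlier (not an existing Paddle class):

```python
# Sketch only: feed() as a thin wrapper over forward() and backward().
class GradientMachine:
    # ... __init__, forward(), backward() as sketched earlier ...

    def feed(self, minibatch):
        self.forward(minibatch)   # refresh layer outputs
        self.backward()           # refresh parameter gradients
```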

### A GAN Example

```python
input_gen = paddle.layer.input(...)
hidden_gen = paddle.layer.fc(input_gen, ...)
output_gen = paddle.layer.fc(hidden_gen, ...)

# gm_gen maintains layer outputs and gradients of model gen.
gm_gen = paddle.gradient_machine.new(output_gen)

input_dis = paddle.layer.input(...)
hidden_dis = paddle.layer.fc(input_dis, ...)
output_dis = paddle.layer.softmax(hidden_dis, ...)
```

**Collaborator:**

For GAN, `gm_dis` and `gm_gen` update different subsets of the parameters. We need to have a mechanism to specify this. In the current GAN example, this is achieved by setting `is_static` according to `is_discriminator_training`, which is possible because the configs of `gm_gen` and `gm_dis` are actually generated differently.

**@wangkuiyi (Author), Jan 8, 2017:**

I understand the need. This design handles it as follows (the main text may not have made this clear):

1. Each "part" is described as a "network".

2. Each "network" is referred to by its output layer.

   For example, the code has `output_dis` and `output_gen`. This can be implemented by having every layer record all of its input layers, so that given an output layer we can trace every layer of the whole network.

3. The update information of each "network" is maintained in a gradient machine.

   In the example there are two networks but three gradient machines: `gm_dis` corresponds to `output_dis`, `gm_gen` corresponds to `output_gen`, and `gm_comp` corresponds to the composition of `output_gen` and `output_dis`.

4. Each network is updated through an updater.

   The updater's input is a gradient machine, because the layer outputs (activations) and gradients computed by each forward/backward call are stored in the gradient machine, and through the gradient machine we can also look up its corresponding network. In the example below,

   ```python
   ud.update(gm_dis)
   ```

   uses `gm_dis` to update its corresponding `output_dis` network; since the part to be updated is not explicitly specified, the whole network corresponding to `gm_dis` is updated, whereas

   ```python
   ud.update(gm_comp, output_gen)
   ```

   uses `gm_comp` to update the `output_gen` sub-network of its corresponding network, with `output_gen` explicitly specifying the part to be updated.

```python
# gm_dis maintains layer outputs and gradients of model dis.
gm_dis = paddle.gradient_machine.new(output_dis)

# gm_comp maintains layer outputs and gradients of both gen and dis.
gm_comp = paddle.gradient_machine.compose(output_gen, output_dis)

ud = paddle.updater.new_simple()

rd = paddle.data.mnist.new()

for mb, pass_end in rd.read():
    fake = gm_gen.forward(mb)
    fake_label = paddle.input.new(False, ...)
    real_label = paddle.input.new(True, ...)
    gm_dis.feed([fake, fake_label])
    gm_dis.feed([mb, real_label])
```
**Contributor:**

Maybe this question is naive: will the gradients computed by the second `feed` override the gradients from the first `feed`?

**Author:**

This design assumes so: the layer outputs produced by the second `feed` call overwrite those produced by the previous `feed` call, and the gradients produced by the second call likewise overwrite those of the previous call.

```python
    ud.update(gm_dis)

    gm_comp.feed([mb, real_label])
```
**Contributor:**

Since the updater only needs to update the gradients of `output_gen`'s predecessors, one possible optimization is to let the gradient machine know this, so it does not need to save any unrelated activations and gradients. E.g.,

```python
gm_comp.feed([mb, real_label], output_gen)
ud.update(gm_comp, output_gen)
```

Maybe I don't have enough context, but at first glance I feel it might be easier for the gradient machine to be stateless (the current design saves gradients inside the gradient machine). E.g.,

```
// pseudo code
type gradientMachine
type gradient
type updater

var gm_comp gradientMachine
var ud updater
var gradient g = gm_comp.feed([mb, real_label], output_gen)
ud.update(g)
```

**@wangkuiyi (Author), Jan 8, 2017:**

I probably didn't explain this clearly, hence the misunderstanding. The gradient machine `gm_comp` corresponds to the composition of the two networks gen and dis. The input of this composition is the input of gen, so the input can only be `mb`.

Also, the gradient machine is intentionally designed to be stateful, because in the GAN example we have two "networks", dis and gen, yet we need to keep three sets of activations and gradients: dis, gen, and comp.

```python
    ud.update(gm_comp, output_gen)  # updates only the model whose output layer is output_gen.
```

A key point here is that we use the output layer to indicate a model. We can achieve this as long as each layer knows its predecessors, so that we can trace from the output layer upward to the input layers. Please be aware that we didn't compose two models in the example code above; instead, we only created a gradient machine that covers both `model_gen` and `model_dis`.
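The tracing idea can be sketched as follows; the `Layer` structure, the `predecessors` field, and the helper functions are assumptions for illustration, not part of the proposed API.

```python
# Sketch only: trace every layer of the network indicated by an output layer
# by following predecessor links back to the input layers.
class Layer:
    def __init__(self, name, predecessors=(), parameters=()):
        self.name = name
        self.predecessors = list(predecessors)  # the input layers of this layer
        self.parameters = list(parameters)      # names of parameters owned by this layer

def trace_layers(output_layer):
    """Return all layers reachable from output_layer via predecessor links."""
    seen, stack = [], [output_layer]
    while stack:
        layer = stack.pop()
        if layer not in seen:
            seen.append(layer)
            stack.extend(layer.predecessors)
    return seen

def parameters_of(output_layer):
    """The parameter names that an update restricted to output_layer would touch."""
    return {p for layer in trace_layers(output_layer) for p in layer.parameters}
```

An updater given such a parameter set could skip every other gradient, which is one way `ud.update(gm_comp, output_gen)` could update only the generator sub-network.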
**@helinwang (Contributor), Jan 6, 2017:**

Just for reference: if I understand correctly, TensorFlow does it by stating which sub-graph needs updating:

```python
output_gen_min = tf.train.AdamOptimizer(1e-2).minimize(output_gen)
output_gen_min.run(feed_dict={x: input, y_: label})
```

Another use case is that we only want to update some weights inside a subgraph. E.g., we want to update the fc-layer weights of `output_gen` but not the weights of `hidden_gen` (a predecessor of `output_gen`); I think people call this fine-tuning. TensorFlow allows explicitly stating the weights that need updating:

```python
fine_tune_step = tf.train.AdamOptimizer(1e-2).minimize(cross_entropy, var_list=[weights_of_output_gen])
fine_tune_step.run(feed_dict={x: input, y_: label})
```

**Author:**

Understood. This design has a similar notion:

```python
ud.update(gm_dis)
```

uses `gm_dis` to update its corresponding `output_dis` network; since the part to be updated is not explicitly specified, the whole network corresponding to `gm_dis` is updated. In contrast,

```python
ud.update(gm_comp, output_gen)
```

uses `gm_comp` to update the `output_gen` sub-network of its corresponding network, with `output_gen` explicitly specifying the part to be updated.