Paddle Python API Design Document (First Draft) #1069

Closed
@jacquesqiao

A typical training process

gradient_machine.startPass()
updater.startPass()
for each_batch in data:
    gradient_machine.startBatch()
    updater.startBatch()

    gradient_machine.train()

    updater.finishBatch()
    gradient_machine.finishBatch()
updater.finishPass()
gradient_machine.finishPass()

Use a call-chain-like mechanism to separate these operations. For example, the loop above can be split into two RunnerChainItems:

  • GradientMachineOperations
  • UpdaterOperations
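To make the split concrete, here is a minimal sketch, assuming each chain item receives the next item's hook as `next_callback` (the protocol the RunnerItem section describes). The gradient machine and updater are replaced by hypothetical stand-ins that only log their calls, and `run_chain` is a hypothetical helper, not part of py_paddle; hook return values are ignored here for brevity. Each item does its "begin" work before calling `next_callback` and its "end" work after it, so begin calls nest and end calls unwind in reverse, reproducing the order of the training loop above.

```python
calls = []  # records the order in which operations run


class GradientMachineOperations(object):
    """Hypothetical chain item wrapping gradient-machine calls (logged only)."""

    def on_pass_begin(self, next_callback):
        calls.append('gradient_machine.startPass')
        next_callback()

    def on_batch_begin(self, next_callback):
        calls.append('gradient_machine.startBatch')
        next_callback()  # let the updater start its batch first
        calls.append('gradient_machine.train')

    def on_batch_end(self, next_callback):
        next_callback()  # the updater finishes its batch first
        calls.append('gradient_machine.finishBatch')

    def on_pass_end(self, next_callback):
        next_callback()
        calls.append('gradient_machine.finishPass')


class UpdaterOperations(object):
    """Hypothetical chain item wrapping updater calls (logged only)."""

    def on_pass_begin(self, next_callback):
        calls.append('updater.startPass')
        next_callback()

    def on_batch_begin(self, next_callback):
        calls.append('updater.startBatch')
        next_callback()

    def on_batch_end(self, next_callback):
        calls.append('updater.finishBatch')
        next_callback()

    def on_pass_end(self, next_callback):
        calls.append('updater.finishPass')
        next_callback()


def run_chain(items, method):
    """Call `method` on items[0]; each next_callback continues down the chain."""
    def call(index):
        if index < len(items):
            getattr(items[index], method)(lambda: call(index + 1))
    call(0)


items = [GradientMachineOperations(), UpdaterOperations()]
run_chain(items, 'on_pass_begin')
run_chain(items, 'on_batch_begin')  # a single batch, for brevity
run_chain(items, 'on_batch_end')
run_chain(items, 'on_pass_end')
```

The order of `items` decides the nesting: putting the gradient machine first makes its begin work run first and its end work run last.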

Some core concepts

  • Runner
  • RunnerItem
  • RunnerBuilder

Runner

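Runner mainly uses the APIs exposed at the GradientMachine level to wrap the logic that used to live in Trainer.cpp. A Runner contains multiple RunnerItems, each implementing part of the Trainer's logic. Users simply call the Runner's run_one_pass repeatedly; internally, the Runner works through its RunnerItems one by one to do what the separate components used to do, such as the updater, the GradientMachine's forward/backward passes, and parameter saving/loading, so users do not have to manage these details.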

The Runner interface

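class Runner(object):
    def add_item(self, item):
        """Add a runner item to the runner."""

    def run_one_pass(self):
        """Run one pass over the data through the chained RunnerItems."""

    def __enter__(self):
        """Initialize all items; invoked on entering `with runner:`."""

    def __exit__(self, exc_type, exc_value, traceback):
        """Finalize all items; invoked on leaving the `with` block."""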

Constructing a runner:

    runner = Runner()

    runner.add_item(ItemA())
    runner.add_item(ItemB())

    with runner:
        runner.run_one_pass()
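A minimal sketch of how such a Runner could work, assuming the RunnerItem protocol described below: hooks receive the next item's hook as `next_callback`, and batch hooks return True when no batches remain. Each chain is built with `functools.partial`, passing the following item's bound hook as the `next_callback` keyword. The default callbacks (`lambda: False` for batch hooks), the dict used as the shared context, and the `ListReader` and `Recorder` items are all assumptions for illustration; this is not the actual py_paddle implementation.

```python
import functools


class RunnerItem(object):
    """Pass-through defaults, mirroring the RunnerItem class in this document."""

    def initialize(self, context, next_callback):
        next_callback(context)

    def finalize(self, next_callback):
        next_callback()

    def on_pass_begin(self, next_callback):
        next_callback()

    def on_pass_end(self, next_callback):
        next_callback()

    def on_batch_begin(self, next_callback):
        return next_callback()

    def on_batch_end(self, next_callback):
        return next_callback()


class Runner(object):
    def __init__(self):
        self.items = []
        self.context = {}  # shared context; a plain dict here (assumption)

    def add_item(self, item):
        self.items.append(item)

    def _chain(self, method, default):
        """Bind each item's `method` to the next item's as next_callback."""
        callback = default  # used by the last item
        for item in reversed(self.items):
            callback = functools.partial(getattr(item, method),
                                         next_callback=callback)
        return callback

    def __enter__(self):
        self._chain('initialize', lambda context: None)(self.context)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._chain('finalize', lambda: None)()
        return False  # let exceptions from the with-block propagate

    def run_one_pass(self):
        self._chain('on_pass_begin', lambda: None)()
        while True:
            if self._chain('on_batch_begin', lambda: False)():  # True: done
                break
            if self._chain('on_batch_end', lambda: False)():
                break
        self._chain('on_pass_end', lambda: None)()


class ListReader(RunnerItem):
    """Hypothetical data item: serves a fixed list of batches each pass."""

    def __init__(self, batches):
        self.batches = batches
        self.pending = []

    def on_pass_begin(self, next_callback):
        self.pending = list(self.batches)
        next_callback()

    def on_batch_begin(self, next_callback):
        if not self.pending:
            return True  # no more batches in this pass
        self.pending.pop(0)  # a real item would hand this batch onward
        return next_callback()


class Recorder(RunnerItem):
    """Hypothetical item: logs the hooks it sees."""

    def __init__(self, log):
        self.log = log

    def on_pass_begin(self, next_callback):
        self.log.append('pass_begin')
        next_callback()

    def on_batch_end(self, next_callback):
        self.log.append('batch_end')
        return next_callback()


log = []
runner = Runner()
runner.add_item(ListReader(['b0', 'b1']))
runner.add_item(Recorder(log))
with runner:
    runner.run_one_pass()
```

The sketch rebuilds each chain on every call for simplicity; an implementation could instead build the chains once in `__enter__`. Note that with no item ever returning True from a batch hook, the loop would not end, so some item (here `ListReader`) has to own the end-of-data signal.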

RunnerItem

RunnerItem is an item in a Runner. The Runner chains its RunnerItems
together and invokes the first item's methods, passing the next item's
corresponding method as next_callback. If the current item is the last
one, a default next_callback is passed instead.

Context is a global object shared by all items.

class RunnerItem(object):
    """
    RunnerItem is an item in a Runner. The Runner chains its RunnerItems
    together and invokes the first item's methods, passing the next item's
    corresponding method as `next_callback`. If the current item is the
    last one, a default `next_callback` is passed instead.

    Context is a global object shared by items.
    """

    def __init__(self):
        pass

    def initialize(self, context, next_callback):
        """
        Initialize method. Invoked when the Runner starts to run.

        :param context: a global object shared by items.
        :type context: RunnerContext
        :param next_callback: next item's initialize method.
        :type next_callback: callable
        :return: None
        :rtype: None
        """
        next_callback(context)

    def finalize(self, next_callback):
        """
        Finalize method. Invoked when the Runner finishes running; use it to
        clean up any state held by the RunnerItem.

        :param next_callback: next item's finalize method.
        :type next_callback: callable
        :return: None
        :rtype: None
        """
        next_callback()

    def on_pass_begin(self, next_callback):
        """
        Pass Begin Method. Invoked when a pass begins.

        :param next_callback: next item's on_pass_begin method.
        :type next_callback: callable
        :return: None
        :rtype: None
        """

        next_callback()

    def on_pass_end(self, next_callback):
        """
        Pass End Method. Invoked when a pass ends.

        :param next_callback: next item's on_pass_end method.
        :type next_callback: callable
        :return: None
        :rtype: None
        """
        next_callback()

    def on_batch_begin(self, next_callback):
        """
        Batch Begin Method. Invoked when a batch begins. Return True if there
        are no more batches to process.

        :param next_callback: next item's on_batch_begin method.
        :type next_callback: callable
        :return: True if there are no more batches to process.
        :rtype: bool
        """
        return next_callback()

    def on_batch_end(self, next_callback):
        """
        Batch End Method. Invoked when a batch ends. Return True if there
        are no more batches to process.

        :param next_callback: next item's on_batch_end method.
        :type next_callback: callable
        :return: True if there are no more batches to process.
        :rtype: bool
        """
        return next_callback()
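As an illustration of the hook contract, the sketch below defines a hypothetical `StopAfterBatches` item (not one of the implemented items listed in this document). It keeps the shared context from `initialize`, writes a batch counter into it, and returns True from `on_batch_begin` once its limit is reached. It is driven by hand as if it were the last item in a chain; the stand-in default callbacks, in particular a batch-hook default that returns False, and the use of a plain dict as the context are assumptions.

```python
class RunnerItem(object):
    """Trimmed copy of the pass-through defaults from the class above."""

    def initialize(self, context, next_callback):
        next_callback(context)

    def on_pass_begin(self, next_callback):
        next_callback()

    def on_batch_begin(self, next_callback):
        return next_callback()


class StopAfterBatches(RunnerItem):
    """Hypothetical item: ends each pass after `limit` batches."""

    def __init__(self, limit):
        self.limit = limit
        self.context = None

    def initialize(self, context, next_callback):
        self.context = context  # keep the shared context for later hooks
        context['batch_id'] = 0
        next_callback(context)

    def on_pass_begin(self, next_callback):
        self.context['batch_id'] = 0
        next_callback()

    def on_batch_begin(self, next_callback):
        if self.context['batch_id'] >= self.limit:
            return True  # no more batches: the Runner should end the pass
        self.context['batch_id'] += 1  # other items can read this counter
        return next_callback()


# Drive the item by hand as if it were the last item in a chain, using
# stand-ins for the Runner's default callbacks (an assumption).
context = {}  # the shared context; a plain dict in this sketch
item = StopAfterBatches(limit=2)
item.initialize(context, lambda ctx: None)
item.on_pass_begin(lambda: None)
results = [item.on_batch_begin(lambda: False) for _ in range(3)]
# results == [False, False, True]; context['batch_id'] == 2
```

Because the item returns `next_callback()` while batches remain, any later item in the chain can still signal the end of a pass on its own.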

Implemented RunnerItems

  • CreateGradientMachine
  • BasicLocalParameterUpdaterOps
  • BasicGradientMachineTrainOps
  • BasicGradientMachineTestOps
  • SaveParamsOnPassEnd
  • Counter

RunnerBuilder

RunnerBuilder encapsulates the process of building a Runner: helper methods such as with_std_local_trainer assemble a runnable Runner, as in the MNIST example below.

import paddle.trainer.PyDataProvider2 as dp
from paddle.trainer_config_helpers import *

import mnist_provider
from py_paddle.trainer import *


@network(
    inputs={
        'pixel': dp.dense_vector(784),
        'label': dp.integer_value(10),
    },
    learning_rate=1e-4,
    learning_method=AdamOptimizer(),
    batch_size=1000,
    model_average=ModelAverage(average_window=0.5),
    regularization=L2Regularization(rate=0.5))
def mnist_network(pixel, label):
    hidden1 = fc_layer(input=pixel, size=200)
    hidden2 = fc_layer(input=hidden1, size=200)
    inference = fc_layer(input=hidden2, size=10, act=SoftmaxActivation())
    cost = classification_cost(input=inference, label=label)
    return cost


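def main():
    mnist = mnist_network()
    # Build a runner with a local trainer on the training files and a tester
    # on the t10k files, using two devices.
    runner = RunnerBuilder(
        network=mnist, device_count=2).with_std_local_trainer(
            method=mnist_provider.process,
            file_list=['./data/raw_data/train']).with_std_tester(
                method=mnist_provider.process,
                file_list=['./data/raw_data/t10k']).build()
    with runner:  # initialize all items on entry, finalize them on exit
        for _ in range(2):  # train for two passes
            runner.run_one_pass()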


if __name__ == '__main__':
    main()
