
Add greedy CTC evaluator python API #7596

Closed
@wanghaoshuang

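This issue depends on #7527.
The greedy CTC evaluator is composed of three operators: top-k (top_k_op) picks the most probable class at each time step, ctc_align_op merges repeated tokens and removes blanks, and edit_distance_op compares the decoded sequence with the label.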
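For reference, the three-operator pipeline can be approximated in plain NumPy for a single sequence. This is a minimal sketch, not the Paddle kernels: it assumes top-k is used with k=1 (the greedy path), that alignment merges consecutive repeats before dropping the blank index, and that the distance is the unnormalized Levenshtein distance. The helper names `greedy_ctc_decode` and `levenshtein` are hypothetical.

```python
import numpy as np


def greedy_ctc_decode(probs, blank=0):
    """Greedy CTC decoding of one sequence (hypothetical helper).

    probs: array of shape [T, num_classes], one row per time step.
    Step 1 (top-k with k=1): take the most probable class per time step.
    Step 2 (alignment): merge consecutive repeats, then drop blanks.
    """
    best_path = np.argmax(probs, axis=1)
    decoded = []
    prev = None
    for token in best_path:
        token = int(token)
        # A blank between two equal tokens keeps both, as in CTC.
        if token != prev and token != blank:
            decoded.append(token)
        prev = token
    return decoded


def levenshtein(hyp, ref):
    """Unnormalized edit distance between two token sequences (hypothetical helper)."""
    # dist[i, j] = edits needed to turn hyp[:i] into ref[:j]
    dist = np.zeros((len(hyp) + 1, len(ref) + 1), dtype=np.int64)
    dist[:, 0] = np.arange(len(hyp) + 1)
    dist[0, :] = np.arange(len(ref) + 1)
    for i in range(1, len(hyp) + 1):
        for j in range(1, len(ref) + 1):
            cost = 0 if hyp[i - 1] == ref[j - 1] else 1
            dist[i, j] = min(dist[i - 1, j] + 1,         # deletion
                             dist[i, j - 1] + 1,         # insertion
                             dist[i - 1, j - 1] + cost)  # substitution
    return int(dist[len(hyp), len(ref)])


# Example: 4 time steps over 3 classes, class 0 is the blank.
probs = np.array([
    [0.1, 0.8, 0.1],    # argmax 1
    [0.1, 0.7, 0.2],    # argmax 1 (repeat, merged)
    [0.9, 0.05, 0.05],  # argmax 0 (blank, dropped)
    [0.2, 0.1, 0.7],    # argmax 2
])
hyp = greedy_ctc_decode(probs, blank=0)
print(hyp, levenshtein(hyp, [1, 2, 2]))  # [1, 2] 1
```

Comparing such a reference against the operator outputs on small inputs is one way to check the evaluator by hand.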

Test script:

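```python
import numpy as np
import paddle.v2.fluid as fluid
from paddle.v2.fluid import core

# Network: greedy CTC decoding followed by the EditDistance evaluator.
# Both inputs are variable-length sequences, so they are declared with lod_level=1.
x = fluid.layers.data(name='x', shape=[8], dtype='float32', lod_level=1)
# Edit distance compares integer token ids, so the label is int64.
y = fluid.layers.data(name='y', shape=[1], dtype='int64', lod_level=1)
ctc_result = fluid.layers.ctc_greedy_decoder(input=x, blank=0)
edit_distance = fluid.evaluator.EditDistance(input=ctc_result, label=y)
print("step1: network built")

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
print("step2: startup program finished")

# Clear the evaluator's accumulated state before the first batch.
edit_distance.reset(exe)
batch_num = 2
for i in range(batch_num):
    print("step3: batch %d" % i)
    # Labels: 7 token ids split into 3 sequences of length 2, 2 and 3.
    y_data = np.random.randint(0, 8, [7, 1]).astype("int64")
    y_lod = [[0, 2, 4, 7]]
    y_tensor = core.LoDTensor()
    y_tensor.set(y_data, place)
    y_tensor.set_lod(y_lod)

    # Class probabilities: 11 time steps split into 3 sequences of length 3, 2 and 6.
    x_data = np.random.uniform(0.1, 1, [11, 8]).astype("float32")
    x_lod = [[0, 3, 5, 11]]
    x_tensor = core.LoDTensor()
    x_tensor.set(x_data, place)
    x_tensor.set_lod(x_lod)

    # Fetch every metric the evaluator exposes for this batch.
    metrics = exe.run(fluid.default_main_program(),
                      feed={'x': x_tensor, 'y': y_tensor},
                      fetch_list=edit_distance.metrics)
    # eval() reports the error accumulated since the last reset().
    pass_error = edit_distance.eval(exe)
    print("batch metrics: %s" % metrics)
    print("batch_id=%d; accumulated_error=%s" % (i, str(pass_error)))

pass_error = edit_distance.eval(exe)
print("total_pass_error=%s" % str(pass_error))
```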
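In the script, each LoDTensor stacks all sequences of a batch along the first dimension, and the level-0 LoD lists offsets into those rows: x_lod = [[0, 3, 5, 11]] describes three sequences of 3, 2 and 6 time steps, and y_lod = [[0, 2, 4, 7]] describes three label sequences of length 2, 2 and 3. A minimal NumPy sketch of that slicing follows; `split_by_lod` is a hypothetical helper, not a Paddle API, and it only handles a single LoD level.

```python
import numpy as np


def split_by_lod(data, lod):
    """Split a packed [N, ...] array into per-sequence arrays (hypothetical helper).

    lod is a one-level offset list such as [[0, 3, 5, 11]]: sequence k
    occupies rows lod[0][k] (inclusive) to lod[0][k + 1] (exclusive).
    """
    offsets = lod[0]
    if offsets[0] != 0 or offsets[-1] != len(data):
        raise ValueError("LoD offsets must start at 0 and end at len(data)")
    return [data[start:end] for start, end in zip(offsets[:-1], offsets[1:])]


# Same shapes and offsets as the test script above.
x_data = np.random.uniform(0.1, 1, [11, 8]).astype("float32")
y_data = np.random.randint(0, 8, [7, 1]).astype("int64")

x_seqs = split_by_lod(x_data, [[0, 3, 5, 11]])
y_seqs = split_by_lod(y_data, [[0, 2, 4, 7]])
print([s.shape[0] for s in x_seqs])  # [3, 2, 6]
print([s.shape[0] for s in y_seqs])  # [2, 2, 3]
```

Both tensors must describe the same number of sequences (three here), since the evaluator pairs the k-th decoded sequence with the k-th label sequence.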
