Warpctc multi label length (apache#2650)
* multi-length labels

Support labels of multiple lengths

Update README.md

* reset dmlc-core

* fix lint
xlvector authored and tqchen committed Jul 8, 2016
1 parent 1d9d6b1 commit acdae67
Showing 4 changed files with 81 additions and 30 deletions.
5 changes: 5 additions & 0 deletions example/warpctc/README.md
@@ -84,3 +84,8 @@ Following code show detail construction of the net:
return sm
```

## Support multiple label lengths

If your label lengths vary but are all less than or equal to b, you should provide labels of length b; for samples whose label is shorter than b, append 0 to the label data to pad it to length b.

Here, 0 is reserved for the blank label.
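
A minimal sketch of this padding convention (the helper name `pad_labels`, the use of NumPy, and the sample values are assumptions for illustration, not part of the repository code):

```python
import numpy as np

def pad_labels(labels, b):
    """Pad each label string with 0 (the blank) up to length b."""
    padded = np.zeros((len(labels), b))
    for i, label in enumerate(labels):
        assert len(label) <= b, "label longer than the maximum length b"
        # digits are shifted by +1 so that 0 stays reserved for the blank
        padded[i, :len(label)] = [1 + int(c) for c in label]
    return padded

# with b = 4, "123" becomes [2, 3, 4, 0] and "4567" becomes [5, 6, 7, 8]
print(pad_labels(["123", "4567"], 4))
```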
24 changes: 17 additions & 7 deletions example/warpctc/lstm_ocr.py
@@ -30,16 +30,18 @@ def provide_label(self):
return [(n, x.shape) for n, x in zip(self.label_names, self.label)]

def gen_rand():
num = random.randint(0, 9999)
buf = str(num)
while len(buf) < 4:
buf = "0" + buf
buf = ""
max_len = random.randint(3,4)
for i in range(max_len):
buf += str(random.randint(0,9))
return buf

def get_label(buf):
ret = np.zeros(4)
for i in range(4):
for i in range(len(buf)):
ret[i] = 1 + int(buf[i])
if len(buf) == 3:
ret[3] = 0
return ret

class OCRIter(mx.io.DataIter):
@@ -96,15 +98,23 @@ def ctc_label(p):
if c2 == 0 or c2 == c1:
continue
ret.append(c2)
return ret
return ret

def remove_blank(l):
ret = []
for i in range(len(l)):
if l[i] == 0:
break
ret.append(l[i])
return ret

def Accuracy(label, pred):
global BATCH_SIZE
global SEQ_LENGTH
hit = 0.
total = 0.
for i in range(BATCH_SIZE):
l = label[i]
l = remove_blank(label[i])
p = []
for k in range(SEQ_LENGTH):
p.append(np.argmax(pred[k * BATCH_SIZE + i]))
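
The new accuracy path compares the blank-stripped label with a CTC-collapsed prediction. A self-contained sketch of that comparison (the `ctc_collapse` helper is a generic best-path collapse written for illustration rather than the example's `ctc_label`, and the sample sequences are invented):

```python
def remove_blank(label):
    # keep entries up to the first 0, which marks the padding / blank
    out = []
    for v in label:
        if v == 0:
            break
        out.append(int(v))
    return out

def ctc_collapse(seq):
    # generic CTC best-path decode: merge repeats, then drop blanks (0)
    out, prev = [], 0
    for c in seq:
        if c != 0 and c != prev:
            out.append(c)
        prev = c
    return out

# a 3-digit sample "123" padded to length 4, digits stored as value + 1
padded_label = [2, 3, 4, 0]
# per-frame argmax of the network output over the time steps (invented)
prediction = [0, 2, 2, 0, 3, 0, 4, 4]

hit = ctc_collapse(prediction) == remove_blank(padded_label)
print(hit)  # True
```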
2 changes: 1 addition & 1 deletion example/warpctc/toy_ctc.py
@@ -67,7 +67,7 @@ def __iter__(self):
num, img = gen_rand()
data.append(img)
label.append(get_label(num))

data_all = [mx.nd.array(data)] + self.init_state_arrays
label_all = [mx.nd.array(label)]
data_names = ['data'] + init_state_names
80 changes: 58 additions & 22 deletions plugin/warpctc/warpctc-inl.h
@@ -80,6 +80,33 @@ class WarpCTCOp : public Operator {
Softmax(out_tensor, data_tensor);
}

std::vector<int> labelLengths(const int * flat_labels, int minibatch,
int size, int blank, int * total_length) {
CHECK_EQ(param_.label_length * minibatch, size)
<< "label size should = label_length * minibatch";
std::vector<int> ret(minibatch, 0);
for (int i = 0; i < size; i++) {
if (flat_labels[i] == blank) {
continue;
}
int b = i / param_.label_length;
ret[b]++;
(*total_length)++;
}
return ret;
}

void removeBlank(const int * flat_labels, int * cpu_labels,
int size, int blank) {
int k = 0;
for (int i = 0; i < size; i++) {
if (flat_labels[i] != blank) {
cpu_labels[k] = flat_labels[i];
k += 1;
}
}
}

virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
@@ -111,11 +138,37 @@ class WarpCTCOp : public Operator {
for (int i = 0; i < minibatch; i++) {
input_lengths.push_back(T);
}
std::vector<int> label_lengths;
for (int i = 0; i < minibatch; i++) {
label_lengths.push_back(param_.label_length);

#if MXNET_USE_CUDA
cudaError_t cuda_status;
#endif
float* activations = static_cast<float*>(data.dptr_);
int* flat_labels = static_cast<int*>(label.dptr_);
int* cpu_raw_labels = flat_labels;
float* grads = static_cast<float*>(in_grad[warpctc_enum::kData].dptr_);
if (data.dev_mask_ == gpu::kDevMask) {
#if MXNET_USE_CUDA
cpu_raw_labels = reinterpret_cast<int*>(malloc(sizeof(int) * label.Size()));
cuda_status = cudaMemcpyAsync(cpu_raw_labels, flat_labels,
label.Size()*sizeof(int),
cudaMemcpyDeviceToHost,
ctx.get_stream<gpu>()->stream_);
CHECK_EQ(cuda_status, cudaSuccess) << "cuda memcpy label error";
#endif
} else {
LOG(FATAL) << "Unknown device type " << data.dev_mask_;
}

int total_label_length = 0;
std::vector<int> label_lengths = labelLengths(cpu_raw_labels,
minibatch,
label.Size(),
0, &total_label_length);
int* cpu_labels = reinterpret_cast<int*>(
malloc(sizeof(int) * total_label_length));
removeBlank(cpu_raw_labels, cpu_labels, label.Size(), 0);
free(cpu_raw_labels);

size_t alloc_bytes;
throw_on_error(get_workspace_size(label_lengths.data(),
input_lengths.data(),
@@ -125,32 +178,14 @@
"Error: get_workspace_size in inf_test");
void* ctc_workspace;

#if MXNET_USE_CUDA
cudaError_t cuda_status;
#endif
float* activations = static_cast<float*>(data.dptr_);
int* flat_labels = static_cast<int*>(label.dptr_);
int* cpu_labels = flat_labels;
float* grads = static_cast<float*>(in_grad[warpctc_enum::kData].dptr_);

if (data.dev_mask_ == cpu::kDevMask) {
ctc_workspace = malloc(alloc_bytes);
} else if (data.dev_mask_ == gpu::kDevMask) {
#if MXNET_USE_CUDA
cpu_labels = reinterpret_cast<int*>(malloc(sizeof(int) * label.Size()));
cuda_status = cudaMemcpyAsync(cpu_labels, flat_labels,
label.Size()*sizeof(int),
cudaMemcpyDeviceToHost,
ctx.get_stream<gpu>()->stream_);
CHECK_EQ(cuda_status, cudaSuccess) << "cuda memcpy label error";

cuda_status = cudaMalloc(&ctc_workspace, alloc_bytes);
CHECK_EQ(cuda_status, cudaSuccess) << "cuda malloc worksapce fail";
#endif
} else {
LOG(FATAL) << "Unknown device type " << data.dev_mask_;
}

std::vector<float> costs(minibatch);
throw_on_error(compute_ctc_loss(activations,
grads,
@@ -163,12 +198,14 @@
ctc_workspace,
info),
"Error: compute_ctc_loss");

if (data.dev_mask_ == cpu::kDevMask) {
free(ctc_workspace);
} else if (data.dev_mask_ == gpu::kDevMask) {
#if MXNET_USE_CUDA
cuda_status = cudaFree(ctc_workspace);
CHECK_EQ(cuda_status, cudaSuccess) << "cuda free workspace fail";
free(cpu_labels);
#endif
}
}
@@ -207,7 +244,6 @@ class WarpCTCProp : public OperatorProperty {
if (dshape.ndim() == 0) return false;
TShape label_shape(dshape.ndim() - 1);
label_shape[0] = param_.label_length * (dshape[0] / param_.input_length);
std::cout << "infer label shape: " << label_shape[0] << std::endl;
SHAPE_ASSIGN_CHECK(*in_shape, warpctc_enum::kLabel, label_shape);

out_shape->clear();
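
For reference, the transformation that the new `labelLengths` and `removeBlank` helpers apply to the flat, zero-padded label buffer can be sketched in Python (this illustrates the logic only; the buffer contents and `label_length = 4` are invented, and warp-ctc itself consumes the C++ buffers):

```python
# flat buffer for a minibatch of 2 samples, each padded to label_length = 4
label_length = 4
flat_labels = [2, 3, 4, 0,   5, 6, 7, 8]

minibatch = len(flat_labels) // label_length

# per-sample lengths: count the non-blank (non-zero) entries in each row
label_lengths = [0] * minibatch
for i, v in enumerate(flat_labels):
    if v != 0:
        label_lengths[i // label_length] += 1

# dense label buffer with the padding removed, as warp-ctc expects
dense_labels = [v for v in flat_labels if v != 0]

print(label_lengths)   # [3, 4]
print(dense_labels)    # [2, 3, 4, 5, 6, 7, 8]
```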
