Skip to content

Commit

Permalink
add V4 rec distill (#9921)
Browse files Browse the repository at this point in the history
* support min_area_rect crop

* add check_install

* fix requirement.txt

* fix check_install

* add lanms-neo for drrg

* fix

* fix doc

* fix

* support set gpu_id when inference

* fix #8855

* fix #8855

* opt slim doc

* fix doc bug

* add v4_rec_distill config

* delete debug

* fix comment

* fix comment
  • Loading branch information
LDOUBLEV authored May 15, 2023
1 parent 4251664 commit 1643f26
Show file tree
Hide file tree
Showing 4 changed files with 766 additions and 5 deletions.
231 changes: 231 additions & 0 deletions configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_distill.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
Global:
debug: false
use_gpu: true
epoch_num: 200
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_dkd_400w_svtr_ctc_lcnet_blank_dkd0.1/
save_epoch_step: 40
eval_batch_step:
- 0
- 2000
cal_metric_during_train: true
pretrained_model: null
checkpoints: ./output/rec_dkd_400w_svtr_ctc_lcnet_blank_dkd0.1/latest
save_inference_dir: null
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
max_text_length: 25
infer_mode: false
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_ppocrv3.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 2
regularizer:
name: L2
factor: 3.0e-05
Architecture:
model_type: rec
name: DistillationModel
algorithm: Distillation
Models:
Teacher:
pretrained:
freeze_params: true
return_all_feats: true
model_type: rec
algorithm: SVTR
Transform: null
Backbone:
name: SVTRNet
img_size:
- 48
- 320
out_char_num: 40
out_channels: 192
patch_merging: Conv
embed_dim:
- 64
- 128
- 256
depth:
- 3
- 6
- 3
num_heads:
- 2
- 4
- 8
mixer:
- Conv
- Conv
- Conv
- Conv
- Conv
- Conv
- Global
- Global
- Global
- Global
- Global
- Global
local_mixer:
- - 5
- 5
- - 5
- 5
- - 5
- 5
last_stage: false
prenorm: true
Head:
name: MultiHead
head_list:
- CTCHead:
Neck:
name: svtr
dims: 120
depth: 2
hidden_dims: 120
kernel_size: [1, 3]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 384
max_text_length: *max_text_length
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: rec
algorithm: SVTR
Transform: null
Backbone:
name: PPLCNetV3
scale: 0.95
Head:
name: MultiHead
head_list:
- CTCHead:
Neck:
name: svtr
dims: 120
depth: 2
hidden_dims: 120
kernel_size: [1, 3]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 384
max_text_length: *max_text_length
Loss:
name: CombinedLoss
loss_config_list:
- DistillationDKDLoss:
weight: 0.1
model_name_pairs:
- - Student
- Teacher
key: head_out
multi_head: true
alpha: 1.0
beta: 2.0
dis_head: gtc
name: dkd
- DistillationCTCLoss:
weight: 1.0
model_name_list:
- Student
key: head_out
multi_head: true
- DistillationNRTRLoss:
weight: 1.0
smoothing: false
model_name_list:
- Student
key: head_out
multi_head: true
- DistillCTCLogits:
weight: 1.0
reduction: mean
model_name_pairs:
- - Student
- Teacher
key: head_out
PostProcess:
name: DistillationCTCLabelDecode
model_name:
- Student
key: head_out
multi_head: true
Metric:
name: DistillationMetric
base_metric_name: RecMetric
main_indicator: acc
key: Student
ignore_space: false
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
label_file_list:
- ./train_data/train_list.txt
ratio_list:
- 1.0
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug:
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
loader:
shuffle: true
batch_size_per_card: 128
drop_last: true
num_workers: 8
use_shared_memory: true
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
num_workers: 4
profiler_options: null
76 changes: 76 additions & 0 deletions ppocr/losses/basic_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,79 @@ def forward(self, predicts, batch):
elif self.reduction == 'sum':
loss = paddle.sum(loss)
return {'loss': loss}


class KLDivLoss(nn.Layer):
"""
KLDivLoss
"""

def __init__(self):
super().__init__()

def _kldiv(self, x, target, mask=None):
eps = 1.0e-10
loss = target * (paddle.log(target + eps) - x)
if mask is not None:
loss = loss.flatten(0, 1).sum(axis=1)
loss = loss.masked_select(mask).mean()
else:
# batch mean loss
loss = paddle.sum(loss) / loss.shape[0]
return loss

def forward(self, logits_s, logits_t, mask=None):
log_out_s = F.log_softmax(logits_s, axis=-1)
out_t = F.softmax(logits_t, axis=-1)
loss = self._kldiv(log_out_s, out_t, mask)
return loss


class DKDLoss(nn.Layer):
"""
KLDivLoss
"""

def __init__(self, temperature=1.0, alpha=1.0, beta=1.0):
super().__init__()
self.temperature = temperature
self.alpha = alpha
self.beta = beta

def _cat_mask(self, t, mask1, mask2):
t1 = (t * mask1).sum(axis=1, keepdim=True)
t2 = (t * mask2).sum(axis=1, keepdim=True)
rt = paddle.concat([t1, t2], axis=1)
return rt

def _kl_div(self, x, label, mask=None):
y = (label * (paddle.log(label + 1e-10) - x)).sum(axis=1)
if mask is not None:
y = y.masked_select(mask).mean()
else:
y = y.mean()
return y

def forward(self, logits_student, logits_teacher, target, mask=None):
gt_mask = F.one_hot(
target.reshape([-1]), num_classes=logits_student.shape[-1])
other_mask = 1 - gt_mask
logits_student = logits_student.flatten(0, 1)
logits_teacher = logits_teacher.flatten(0, 1)
pred_student = F.softmax(logits_student / self.temperature, axis=1)
pred_teacher = F.softmax(logits_teacher / self.temperature, axis=1)
pred_student = self._cat_mask(pred_student, gt_mask, other_mask)
pred_teacher = self._cat_mask(pred_teacher, gt_mask, other_mask)
log_pred_student = paddle.log(pred_student)
tckd_loss = self._kl_div(log_pred_student,
pred_teacher) * (self.temperature**2)
pred_teacher_part2 = F.softmax(
logits_teacher / self.temperature - 1000.0 * gt_mask, axis=1)
log_pred_student_part2 = F.log_softmax(
logits_student / self.temperature - 1000.0 * gt_mask, axis=1)
nckd_loss = self._kl_div(log_pred_student_part2,
pred_teacher_part2) * (self.temperature**2)

loss = self.alpha * tckd_loss + self.beta * nckd_loss

return loss
6 changes: 3 additions & 3 deletions ppocr/losses/combined_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from .ace_loss import ACELoss
from .rec_sar_loss import SARLoss

from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationSARLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationCTCLoss, DistillCTCLogits
from .distillation_loss import DistillationSARLoss, DistillationNRTRLoss
from .distillation_loss import DistillationDMLLoss, DistillationKLDivLoss, DistillationDKDLoss
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
from .distillation_loss import DistillationVQASerTokenLayoutLMLoss, DistillationSERDMLLoss
from .distillation_loss import DistillationLossFromOutput
Expand Down
Loading

0 comments on commit 1643f26

Please sign in to comment.