License plate detection and recognition example based on CCPD #744


Merged · 13 commits · Oct 28, 2024
152 changes: 152 additions & 0 deletions examples/license_plate_detection_and_recognition/db_r50_ccpd.yaml
@@ -0,0 +1,152 @@
system:
  mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
  distribute: False
  amp_level: 'O0'
  seed: 42
  log_interval: 10
  val_while_train: False
  drop_overflow_update: False

model:
  type: det
  transform: null
  backbone:
    name: det_resnet50
    pretrained: False
  neck:
    name: DBFPN
    out_channels: 256
    bias: False
  head:
    name: DBHead
    k: 50
    bias: False
    adaptive: True
  pretrained: https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50_synthtext-40655acb.ckpt

postprocess:
  name: DBPostprocess
  box_type: quad # whether to output a polygon or a box
  binary_thresh: 0.3 # binarization threshold
  box_thresh: 0.7 # box score threshold
  max_candidates: 1000
  expand_ratio: 1.5 # coefficient for expanding predictions

metric:
  name: DetMetric
  main_indicator: f-score

loss:
  name: DBLoss
  eps: 1.0e-6
  l1_scale: 10
  bce_scale: 5
  bce_replace: bceloss

scheduler:
  scheduler: polynomial_decay
  lr: 0.007
  num_epochs: 1200
  decay_rate: 0.9
  warmup_epochs: 3

optimizer:
  opt: SGD
  filter_bias_and_bn: false
  momentum: 0.9
  weight_decay: 1.0e-4

# only used for mixed precision training
loss_scaler:
  type: dynamic
  loss_scale: 512
  scale_factor: 2
  scale_window: 1000

train:
  ckpt_save_dir: './dbnet_ccpd'
  dataset_sink_mode: True
  dataset:
    type: DetDataset
    dataset_root: path/to/DBNet_DataSets
    data_dir: train/images
    label_file: train/train_det_gt.txt
    sample_ratio: 1.0
    transform_pipeline:
      - DecodeImage:
          img_mode: RGB
          to_float32: False
      - DetLabelEncode:
      - RandomColorAdjust:
          brightness: 0.1255 # 32.0 / 255
          saturation: 0.5
      - RandomHorizontalFlip:
          p: 0.5
      - RandomScale:
          scale_range: [ 0.5, 3.0 ]
          p: 1.0
      - RandomCropWithBBox:
          max_tries: 10
          min_crop_ratio: 0.1
          crop_size: [ 640, 640 ]
          p: 1.0
      - ValidatePolygons:
      - ShrinkBinaryMap:
          min_text_size: 8
          shrink_ratio: 0.55
      - BorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - NormalizeImage:
          bgr_to_rgb: False
          is_hwc: True
          mean: imagenet
          std: imagenet
      - ToCHWImage:
    # order of the dataloader output list: the network inputs, the labels consumed by the loss function, and optional data for debugging/visualization
    output_columns: [ 'image', 'binary_map', 'mask', 'thresh_map', 'thresh_mask' ] #'img_path']
    # output_columns: ['image'] # for debugging op performance
    net_input_column_index: [0] # indices in output_columns fed to the network forward func
    label_column_index: [1, 2, 3, 4] # indices in output_columns treated as labels

  loader:
    shuffle: True
    batch_size: 16
    drop_remainder: True
    num_workers: 20

eval:
  ckpt_load_path: dbnet_ccpd/best.ckpt
  dataset_sink_mode: False
  dataset:
    type: DetDataset
    dataset_root: path/to/DBNet_DataSets
    data_dir: val/images
    label_file: val/val_det_gt.txt
    sample_ratio: 1.0
    transform_pipeline:
      - DecodeImage:
          img_mode: RGB
          to_float32: False
      - DetLabelEncode:
      - DetResize:
          target_size: [ 1024, 1024 ] # h, w
          keep_ratio: True
          padding: True
      - NormalizeImage:
          bgr_to_rgb: False
          is_hwc: True
          mean: imagenet
          std: imagenet
      - ToCHWImage:
    # order of the dataloader output list: the network inputs and the labels used for evaluation
    output_columns: [ 'image', 'polys', 'ignore_tags', 'shape_list' ]
    net_input_column_index: [0] # indices in output_columns fed to the network forward func
    label_column_index: [1, 2] # indices in output_columns treated as labels

  loader:
    shuffle: False
    batch_size: 16 # TODO: due to the dynamic shape of polygons (the number of boxes varies), the batch size may need to be reduced to 1
    drop_remainder: True
    num_workers: 2
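
The config above follows MindOCR's standard detection-config layout, and in MindOCR it would typically be passed to `tools/train.py` via `--config`. As a minimal sanity check of the file itself (a sketch, assuming PyYAML is installed and that the path below matches this PR's file layout):

```python
import yaml  # pip install pyyaml

# Load the config above and spot-check a few fields
# (the file path is an assumption based on this PR's layout).
with open("examples/license_plate_detection_and_recognition/db_r50_ccpd.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["backbone"]["name"])      # det_resnet50
print(cfg["scheduler"]["num_epochs"])        # 1200
print(cfg["train"]["loader"]["batch_size"])  # 16
```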
46 changes: 46 additions & 0 deletions examples/license_plate_detection_and_recognition/generate_data.py
@@ -0,0 +1,46 @@
import json
import os

from PIL import Image


def read_annotations(file_path):
    """Read a ground-truth file whose lines are an image name, a tab, then a JSON annotation."""
    annotations = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            image_name, description = line.strip().split("\t")
            description_data = json.loads(description)
            annotations.append((image_name, description_data))
    return annotations


def crop_images(annotations, source_folder, target_folder, output_txt):
    """Crop each plate region out of its source image and write a recognition gt file."""
    if not os.path.exists(target_folder):
        os.makedirs(target_folder)
    with open(output_txt, "w", encoding="utf-8") as out_file:
        for image_name, data in annotations:
            image_path = os.path.join(source_folder, image_name)
            with Image.open(image_path) as img:
                # bbox holds the top-left and bottom-right corners of the plate
                bbox = data[0]["bbox"]
                x1, y1 = bbox[0]
                x2, y2 = bbox[1]
                cropped_img = img.crop((x1, y1, x2, y2))
                cropped_image_name = f"{image_name}"
                cropped_img.save(os.path.join(target_folder, cropped_image_name))
                transcription = data[0]["transcription"]
                out_file.write(f"{cropped_image_name}\t{transcription}\n")


def main():
    # Convert the detection-style splits into recognition-style (SVTR) datasets
    datasets = ["train", "test", "val"]
    for dataset in datasets:
        annotations_file = f"path/to/DBNet_DataSets/{dataset}/{dataset}_det_gt.txt"
        source_folder = f"path/to/DBNet_DataSets/{dataset}/images"
        target_folder = f"path/to/SVTR_DataSets/{dataset}/"
        output_txt = f"path/to/SVTR_DataSets/gt_{dataset}.txt"
        annotations = read_annotations(annotations_file)
        crop_images(annotations, source_folder, target_folder, output_txt)


if __name__ == "__main__":
    main()
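
For clarity, this is the line format `read_annotations` expects: one image per line, with a JSON list whose first element carries the plate's `transcription` and a two-point `bbox` (top-left, bottom-right). The file name and coordinates below are made up for illustration:

```python
import json

# Hypothetical ground-truth line in the format generate_data.py parses:
# "<image name>\t<JSON list>", one plate per image.
record = [{"transcription": "皖A12345", "bbox": [[154, 383], [386, 473]]}]
line = "0001.jpg\t" + json.dumps(record, ensure_ascii=False)
print(line)  # -> 0001.jpg\t[{"transcription": "皖A12345", "bbox": [[154, 383], [386, 473]]}]
```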
107 changes: 107 additions & 0 deletions examples/license_plate_detection_and_recognition/generate_dict.py
@@ -0,0 +1,107 @@
# Character sets used on Chinese license plates, following the CCPD convention:
# provincial abbreviations, the letters allowed in the second position, and the
# letters/digits used in the remaining positions ("O" marks an empty position).
provinces = [
    "皖",
    "沪",
    "津",
    "渝",
    "冀",
    "晋",
    "蒙",
    "辽",
    "吉",
    "黑",
    "苏",
    "浙",
    "京",
    "闽",
    "赣",
    "鲁",
    "豫",
    "鄂",
    "湘",
    "粤",
    "桂",
    "琼",
    "川",
    "贵",
    "云",
    "藏",
    "陕",
    "甘",
    "青",
    "宁",
    "新",
    "警",
    "学",
    "O",
]
alphabets = [
    "A",
    "B",
    "C",
    "D",
    "E",
    "F",
    "G",
    "H",
    "J",
    "K",
    "L",
    "M",
    "N",
    "P",
    "Q",
    "R",
    "S",
    "T",
    "U",
    "V",
    "W",
    "X",
    "Y",
    "Z",
    "O",
]
ads = [
    "A",
    "B",
    "C",
    "D",
    "E",
    "F",
    "G",
    "H",
    "J",
    "K",
    "L",
    "M",
    "N",
    "P",
    "Q",
    "R",
    "S",
    "T",
    "U",
    "V",
    "W",
    "X",
    "Y",
    "Z",
    "0",
    "1",
    "2",
    "3",
    "4",
    "5",
    "6",
    "7",
    "8",
    "9",
    "O",
]

# Merge the three sets, sort for a deterministic order, and assign indices
unique_characters = set(provinces + alphabets + ads)
unique_dict = {char: index for index, char in enumerate(sorted(unique_characters))}
with open("ccpd.txt", "w", encoding="utf-8") as file:
    for char, index in unique_dict.items():
        line = f"{char}:{index}\n"
        file.write(line)
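
A quick way to verify the output is to load `ccpd.txt` back into a mapping and check its size: the sets above yield 68 distinct characters (33 Chinese characters, 25 letters including the placeholder "O", and 10 digits). A minimal round-trip check, assumed to run after `generate_dict.py`:

```python
# Rebuild the char -> index mapping from the generated ccpd.txt
char_to_idx = {}
with open("ccpd.txt", encoding="utf-8") as f:
    for line in f:
        char, index = line.rstrip("\n").split(":")
        char_to_idx[char] = int(index)

print(len(char_to_idx))                  # 68
print(sorted(char_to_idx.values())[:3])  # [0, 1, 2] -- indices are contiguous
```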