Skip to content

Commit

Permalink
🐛修复使用PaddleDetection训练出现的BUG
Browse files Browse the repository at this point in the history
  • Loading branch information
laugh12321 committed Oct 29, 2022
1 parent 3df3975 commit 924b344
Showing 1 changed file with 32 additions and 60 deletions.
92 changes: 32 additions & 60 deletions dair_v2x/pp_vehicle.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
@File : pp_vehicle.py
@Version : 1.0
@Version : 2.0
@Author : laugh12321
@Contact : laugh12321@vip.qq.com
@Date : 2022/10/21 11:14:24
Expand All @@ -17,7 +17,12 @@


class Loader:
"""加载数据集标签信息"""
"""加载数据集标签信息
将 car, truck, van, bus 类别都设置为vehicle;
忽略 tricyclist, barrowlist, pedestrianignore,
carignore, otherignore, unknown_movable, unknown_unmovable 类别
"""

def __init__(self, data_dir: str, train_split: float = 0.85) -> None:
"""初始化
Expand All @@ -27,10 +32,8 @@ def __init__(self, data_dir: str, train_split: float = 0.85) -> None:
"""
self.data_dir = data_dir
self.__categories, self.__vhicles = self.__get_categories()
self.__data_info_path = os.path.join(data_dir,
"data_info.json") # 数据信息
self.__train_data_info, self.__val_data_info = self.__get_train_val_info(
train_split)
self.__data_info_path = os.path.join(data_dir, "data_info.json") # 数据信息
self.__train_data_info, self.__val_data_info = self.__get_train_val_info(train_split)

@property
def train_info(self) -> dict:
Expand All @@ -56,21 +59,11 @@ def vhicles(self) -> list:
def __get_categories() -> dict:
"""获取类别与id对应关系"""
__categories = [
"vehicle",
"pedestrian",
"cyclist",
"tricyclist",
"motorcyclist",
"barrowlist",
"vehicle", "pedestrian", "cyclist", "motorcyclist", "barrowlist"
]

return {_category: _id
for _id, _category in enumerate(__categories)}, [
"car",
"truck",
"van",
"bus",
]
for _id, _category in enumerate(__categories)}, ["car", "truck", "van", "bus"]

def __get_train_val_info(self, train_split) -> Tuple[dict, dict]:
"""分割训练集和验证集"""
Expand Down Expand Up @@ -115,56 +108,35 @@ def format2coco(self, data_info: dict, json_path: str) -> None:
ids (list): 图片ids
json_path (str): annotations json 保存路径
"""
coco_json = {"images": [], "annotations": [], "categories": []}

item_id = 0
for data in tqdm(data_info):
file_name = data["image_path"]
coco_json = {"images": [], "annotations": [], "categories": []}
for _info in tqdm(data_info):
file_name = _info["image_path"]
img_id, _ = os.path.splitext(os.path.basename(file_name))
annos_dir = os.path.join(self.data_dir,
data["label_camera_std_path"])
annos = self.__get_annotations(annos_dir)

image_dict = {
coco_json["images"].append({
"id": int(img_id),
"file_name": file_name,
"height": 1080,
"width": 1920,
"id": img_id,
}
coco_json["images"].append(image_dict)

for item in annos:
xywh = self.__bbox2xywh(item["2d_box"])
category = item["type"].lower()
if category in self.vhicles: # 将所有的车辆类别都设为vehicle
"height": 1080
})

annos = self.__get_annotations(os.path.join(self.data_dir, _info["label_camera_std_path"])) # 获取标注信息
for _anno in annos:
if category := _anno["type"].lower() in self.vhicles: # 将所有的车辆类别都设为vehicle
category = "vehicle"
category_id = self.categories.get(category)
if category_id is not None:
annotation_dict = {
"area": xywh[-2] * xywh[-1],
"iscrowd": 0,
if category_id := self.categories.get(category) is not None: # 获取类别id
xywh = self.__bbox2xywh(_anno["2d_box"]) # coco的bbox[xmin, ymin, width, height]
coco_json["annotations"].append({
"id": len(coco_json["annotations"]),
"image_id": int(img_id),
"bbox": xywh,
"category_id": category_id,
"id": item_id,
}
"bbox": xywh,
"area": xywh[-2] * xywh[-1],
"iscrowd": 0
})
item_id += 1

coco_json["annotations"].append(annotation_dict)
if category not in coco_json["categories"]:
coco_json["categories"].append(category)
categories_list = [{
"id": self.categories[category],
"name": category
} for category in coco_json["categories"]]

coco_json["categories"] = categories_list
with open(json_path, "w+", encoding="utf-8") as file:
json.dump(coco_json,
file,
indent=4,
sort_keys=False,
ensure_ascii=False)
coco_json["categories"] = [{"id": _id, "name": _name} for _name, _id in self.categories.items()] # 类别信息
json.dump(coco_json, open(json_path, "w+", encoding='utf-8'), indent=4, sort_keys=False, ensure_ascii=False) # 保存json

def processing(self) -> None:
"""处理进程"""
Expand Down

0 comments on commit 924b344

Please sign in to comment.