Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
8b0fcd0
Update json_to_dataset.py
Mr47121836 May 15, 2023
0fc68af
Update requirements.txt
Mr47121836 May 15, 2023
834eeb4
Update voc_annotation.py
Mr47121836 May 17, 2023
c7f4017
Delete 1.jpg
Mr47121836 May 17, 2023
1d03504
Create README.md
Mr47121836 May 17, 2023
0c74c44
Update requirements.txt
Mr47121836 May 17, 2023
9ebc012
Update dataloader.py
Mr47121836 May 17, 2023
301bf46
Delete 1.jpg
Mr47121836 May 17, 2023
298c2ab
Delete 1.json
Mr47121836 May 17, 2023
93aaf29
Delete 1.png
Mr47121836 May 17, 2023
ca04db8
Create README.md
Mr47121836 May 17, 2023
f902ac1
Update README.md
Mr47121836 May 17, 2023
41d45aa
Create README.md
Mr47121836 May 17, 2023
7d780df
Update README.md
Mr47121836 May 17, 2023
9b6360e
Update callbacks.py
Mr47121836 May 18, 2023
560a6eb
Update train.py
Mr47121836 May 18, 2023
e1d0d1f
Update train.py
Mr47121836 May 18, 2023
9703bff
Update train.py
Mr47121836 May 18, 2023
ee21e1a
Update train.py
Mr47121836 May 18, 2023
e7d643c
Update requirements.txt
Mr47121836 May 22, 2023
9c266da
Add files via upload
Mr47121836 May 24, 2023
90f66c0
Update predict.py
Mr47121836 May 25, 2023
92ade43
Update predict.py
Mr47121836 May 25, 2023
38d83b3
Update README.md
Mr47121836 May 27, 2023
ec17ffc
Update imgs_resize.py
Mr47121836 May 27, 2023
a268762
Update imgs_resize.py
Mr47121836 May 27, 2023
5976335
Update README.md
Mr47121836 May 27, 2023
c9e1dad
Update README.md
Mr47121836 May 27, 2023
8bcb4c8
Update predict.py
Mr47121836 May 27, 2023
647abec
Update train.py
Mr47121836 May 27, 2023
02c4bac
Update train.py
Mr47121836 May 27, 2023
adbdb72
Update predict.py
Mr47121836 May 28, 2023
a59e39b
添加Jupyter文件,包含训练的过程
Mr47121836 May 28, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,13 @@ VOC拓展数据集的百度网盘如下:
2、运行train.py进行训练,默认参数已经对应voc数据集所需要的参数了。

#### 二、训练自己的数据集
1、本文使用VOC格式进行训练。
2、训练前将标签文件放在VOCdevkit文件夹下的VOC2007文件夹下的SegmentationClass中。
3、训练前将图片文件放在VOCdevkit文件夹下的VOC2007文件夹下的JPEGImages中。
4、在训练前利用voc_annotation.py文件生成对应的txt。
5、注意修改train.py的num_classes为分类个数+1。
6、运行train.py即可开始训练。
1、本文使用VOC格式进行训练。
2、如果需要修改图片尺寸请使用imgs_resize.py文件进行修改。
3、训练前将标签文件放在VOCdevkit文件夹下的VOC2007文件夹下的SegmentationClass中。
4、训练前将图片文件放在VOCdevkit文件夹下的VOC2007文件夹下的JPEGImages中。
5、在训练前利用voc_annotation.py文件生成对应的txt。
6、注意修改train.py的num_classes为分类个数+1。
7、运行train.py即可开始训练。

#### 三、训练医药数据集
1、下载VGG的预训练权重到model_data下面。
Expand Down
2,443 changes: 2,443 additions & 0 deletions Unet_pytorch_2023_5_27.ipynb

Large diffs are not rendered by default.

Binary file removed datasets/JPEGImages/1.jpg
Binary file not shown.
1 change: 1 addition & 0 deletions datasets/JPEGImages/README.MD
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Binary file removed datasets/SegmentationClass/1.png
Binary file not shown.
1 change: 1 addition & 0 deletions datasets/SegmentationClass/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#mask文件存放位置
Binary file removed datasets/before/1.jpg
Binary file not shown.
135 changes: 0 additions & 135 deletions datasets/before/1.json

This file was deleted.

2 changes: 2 additions & 0 deletions datasets/before/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

存放jpg和json
43 changes: 43 additions & 0 deletions imgs_resize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import os
import argparse
from tqdm import tqdm
from PIL import Image

def parse_args():
    """Parse the command-line options of the image-cropping script.

    Options are read from ``sys.argv``.

    Returns:
        argparse.Namespace: carries two attributes,
            ``input_path``  -- directory the source images are read from
                               (default ``datasets/SegmentationClass``);
            ``output_path`` -- directory the cropped images are written to
                               (default ``datasets/ReSize_SegmentationClass``).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--input_path', default='datasets/SegmentationClass',
                        help='directory containing the images to crop')
    # The help text used to read "number of total epochs to run", which was
    # left over from a training script and did not describe this option.
    parser.add_argument('--output_path', default='datasets/ReSize_SegmentationClass',
                        help='directory the cropped images are written to')
    config = parser.parse_args()

    return config

"""程序功能是对图片进行裁剪,把图片多余的部分进行裁剪,留下含有数据的部分"""

if __name__ == '__main__':
    # Crop every .jpg/.png file found in the input directory to one fixed
    # window and save the result, under the same file name, in the output
    # directory.
    config = vars(parse_args())

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(config['output_path'], exist_ok=True)

    # Fixed 1024x1024 crop window in pixels (left, top, right, bottom).
    # It does not depend on the image, so it is computed once, outside the
    # loop.  To centre the window on each image instead, derive the box from
    # im.size: (width // 2 - 512, height // 2 - 512,
    #           width // 2 + 512, height // 2 + 512).
    left, top = 645, 430
    crop_box = (left, top, left + 1024, top + 1024)

    for img in tqdm(os.listdir(config['input_path'])):
        if not img.endswith(("jpg", "png")):
            continue
        try:
            # The context manager closes the underlying file handle even if
            # cropping or saving fails.
            with Image.open(os.path.join(config['input_path'], img)) as im:
                cut_name = os.path.join(config['output_path'], img)
                im.crop(crop_box).save(cut_name)
        except (OSError, RuntimeError) as e:
            # Pillow reports unreadable/unwritable images as OSError
            # subclasses; report the failure and carry on with the next
            # file instead of aborting the whole batch.
            print(e)
    print("转换完成!!")
4 changes: 2 additions & 2 deletions json_to_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
if __name__ == '__main__':
jpgs_path = "datasets/JPEGImages"
pngs_path = "datasets/SegmentationClass"
classes = ["_background_","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
# classes = ["_background_","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
# classes = ["_background_","cat","dog"]

classes = ["_background_","quesun","youwu","huahen","maocao"]
count = os.listdir("./datasets/before/")
for i in range(0, len(count)):
path = os.path.join("./datasets/before", count[i])
Expand Down
13 changes: 11 additions & 2 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@
# count、name_classes仅在mode='predict'时有效
#-------------------------------------------------------------------------#
count = False
name_classes = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
# name_classes = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
# name_classes = ["background","cat","dog"]
name_classes = ["_background_","quesun","youwu","huahen","maocao"]
#----------------------------------------------------------------------------------------------------------#
# video_path 用于指定视频的路径,当video_path=0时表示检测摄像头
# 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。
Expand Down Expand Up @@ -90,14 +91,22 @@
'''
while True:
img = input('Input image filename:')
img_name = img.split("/")[-1]
try:
image = Image.open(img)
except:
print('Open Error! Try again!')
continue
else:
r_image = unet.detect_image(image, count=count, name_classes=name_classes)
r_image.show()
#r_image.show()
plt.figure(figsize=(24.48, 20.48)) # 设置窗口大小
plt.suptitle('predict result') # 图片名称
plt.subplot(1, 2, 1), plt.title('Source: ' + img_name)
plt.imshow(image), plt.axis('off')
plt.subplot(1, 2, 2), plt.title('Result: ' + img_name)
plt.imshow(r_image), plt.axis('off')
plt.show()

elif mode == "video":
capture=cv2.VideoCapture(video_path)
Expand Down
19 changes: 10 additions & 9 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
scipy==1.2.1
numpy==1.17.0
matplotlib==3.1.2
opencv_python==4.1.2.30
torch==1.2.0
torchvision==0.4.0
tqdm==4.60.0
Pillow==8.2.0
h5py==2.10.0
scipy
numpy
matplotlib
opencv_python
torch
torchvision
tqdm
Pillow
h5py

12 changes: 6 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
# num_classes 训练自己的数据集必须要修改的
# 自己需要的分类个数+1,如2+1
#-----------------------------------------------------#
num_classes = 21
num_classes = 5
#-----------------------------------------------------#
# 主干网络选择
# vgg
Expand Down Expand Up @@ -102,7 +102,7 @@
#-----------------------------------------------------#
# input_shape 输入图片的大小,32的倍数
#-----------------------------------------------------#
input_shape = [512, 512]
input_shape = [1024, 1024]

#----------------------------------------------------------------------------------------------------------------------------#
# 训练分为两个阶段,分别是冻结阶段和解冻阶段。设置冻结阶段是为了满足机器性能不足的同学的训练需求。
Expand Down Expand Up @@ -219,7 +219,7 @@
# 种类多(十几类)时,如果batch_size比较大(10以上),那么设置为True
# 种类多(十几类)时,如果batch_size比较小(10以下),那么设置为False
#------------------------------------------------------------------#
dice_loss = False
dice_loss = True
#------------------------------------------------------------------#
# 是否使用focal loss来防止正负样本不平衡
#------------------------------------------------------------------#
Expand All @@ -238,7 +238,7 @@
# keras里开启多线程有些时候速度反而慢了许多
# 在IO为瓶颈的时候再开启多线程,即GPU运算速度远大于读取图片的速度。
#------------------------------------------------------------------#
num_workers = 4
num_workers = 0

#------------------------------------------------------#
# 设置用到的显卡
Expand Down Expand Up @@ -471,9 +471,9 @@
if distributed:
batch_size = batch_size // ngpus_per_node

gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=False,
drop_last = True, collate_fn = unet_dataset_collate, sampler=train_sampler)
gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=False,
drop_last = True, collate_fn = unet_dataset_collate, sampler=val_sampler)

UnFreeze_flag = True
Expand Down
2 changes: 1 addition & 1 deletion utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_d
self.eval_flag = eval_flag
self.period = period

self.image_ids = [image_id.split()[0] for image_id in image_ids]
self.image_ids = [image_id.split('\n')[0] for image_id in image_ids]
self.mious = [0]
self.epoches = [0]
if self.eval_flag:
Expand Down
2 changes: 1 addition & 1 deletion utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __len__(self):

def __getitem__(self, index):
annotation_line = self.annotation_lines[index]
name = annotation_line.split()[0]
name = annotation_line.split("\n")[0]

#-------------------------------#
# 从文件中读取图像
Expand Down
4 changes: 2 additions & 2 deletions voc_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@

print("Check datasets format, this may take a while.")
print("检查数据集格式是否符合要求,这可能需要一段时间。")
classes_nums = np.zeros([256], np.int)
classes_nums = np.zeros([256], np.int_)
for i in tqdm(list):
name = total_seg[i]
png_file_name = os.path.join(segfilepath, name)
Expand Down Expand Up @@ -95,4 +95,4 @@

print("JPEGImages中的图片应当为.jpg文件、SegmentationClass中的图片应当为.png文件。")
print("如果格式有误,参考:")
print("https://github.com/bubbliiiing/segmentation-format-fix")
print("https://github.com/bubbliiiing/segmentation-format-fix")