
Commit

typo
mcshih committed Feb 25, 2022
1 parent df921be commit 21e613b
Showing 3 changed files with 89 additions and 17 deletions.
IAM_brush.py: 4 changes (2 additions, 2 deletions)
@@ -6,7 +6,7 @@
 import os
 from tqdm import tqdm

-for root, dirnames, filenames in os.walk("/home/user/ACM/shih/IAM/words/"):
+for root, dirnames, filenames in os.walk("/mnt/baf69772-7c2f-4570-a192-06c62f849660/data/shih/IAM/sentences/"):
     pbar = tqdm(filenames)
     for filename in pbar:
         path = os.path.join(root, filename)
@@ -20,7 +20,7 @@
                     if img[i,j,0] > 200 and img[i,j,1] > 200 and img[i,j,2] > 200:
                         result[i,j,3] = 0

-            root = root.replace("/words/", "/words_a/")
+            root = root.replace("/sentences/", "/sentences_a/")
             pbar.set_description("Processing %s" % os.path.join(root, filename))
             cv2.imwrite(os.path.join(root, filename), result, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])
         except:
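
Note (not part of the commit): IAM_brush.py makes the near-white background of each crop transparent by adding an alpha channel and zeroing it wherever all three BGR values exceed 200. A minimal vectorized sketch of the same masking, with placeholder file names ("word.png", "word_a.png"):

import cv2
import numpy as np

img = cv2.imread("word.png")                      # BGR crop with a white-ish background
result = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)    # add an alpha channel
mask = np.all(img > 200, axis=2)                  # pixels whose B, G and R are all > 200
result[mask, 3] = 0                               # make those pixels fully transparent
cv2.imwrite("word_a.png", result, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])
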
handwriting_stamp_generator.py: 20 changes (13 additions, 7 deletions)
@@ -8,17 +8,23 @@

 # Word List
 word_file_list = []
-for root, dirnames, filenames in os.walk("/home/user/ACM/shih/IAM/words_a/"):
+for root, dirnames, filenames in os.walk("/mnt/baf69772-7c2f-4570-a192-06c62f849660/data/shih/IAM/words_a/"):
     for filename in filenames:
         path = os.path.join(root, filename)
         word_file_list.append(path)

+hk_path = "/mnt/baf69772-7c2f-4570-a192-06c62f849660/data/shih/HK_dataset/img_a"
+files = os.listdir(hk_path)
+for file in files:
+    path = os.path.join(hk_path, file)
+    word_file_list.append(path)
+
 def images_process(image_path, image_final_name):
-    result_path = "/home/user/ACM/shih/DDI-100/my_dataset/"
+    result_path = "/mnt/baf69772-7c2f-4570-a192-06c62f849660/data/shih/DDI-100/my_train_dataset/"

     img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
-    canvas = np.copy(img)
-    # print(img.shape)
+    canvas = cv2.cvtColor(img, cv2.COLOR_GRAY2BGRA)

     height, width = img.shape[:2]
     N_word = random.randint(1, 20)
     N_selected_words = random.sample(word_file_list, N_word)
@@ -67,13 +73,13 @@ def images_process(image_path, image_final_name):

 #images_process("/home/user/ACM/shih/DDI-100/dataset_v1.3/01/orig_texts/0.png")

-label_file = "/home/user/ACM/shih/DDI-100/05_my_labels.json"
+label_file = "/mnt/baf69772-7c2f-4570-a192-06c62f849660/data/shih/DDI-100/04_gen_my_labels.json"
 labels_dict = {}
-origin_img_folder = "/home/user/ACM/shih/DDI-100/dataset_v1.3/05/orig_texts/"
+origin_img_folder = "/mnt/baf69772-7c2f-4570-a192-06c62f849660/data/shih/DDI-100/dataset_v1.3/04/gen_imgs/"
 pbar = tqdm(os.listdir(origin_img_folder))
 for doc in pbar:
     #print(doc)
-    save_name = "05_"+doc
+    save_name = "04_"+doc
     labels_dict[save_name] = images_process(os.path.join(origin_img_folder,doc), save_name)
 with open(label_file, "w") as outfile:
     json.dump(labels_dict, outfile, indent = 4)
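
Note (not part of the commit): handwriting_stamp_generator.py loads a grayscale document page, converts it to BGRA, and stamps a random number of transparent word crops onto it; the overlay itself sits outside the shown hunk. A hedged sketch of how such an alpha-weighted paste could look, with hypothetical names (paste_stamp, "page.png", "word_a.png"):

import random

import cv2
import numpy as np

def paste_stamp(canvas, stamp, x, y):
    """Alpha-blend a BGRA stamp onto a BGRA canvas at top-left corner (x, y)."""
    h, w = stamp.shape[:2]
    roi = canvas[y:y + h, x:x + w, :3].astype(np.float32)
    alpha = stamp[:, :, 3:4].astype(np.float32) / 255.0      # per-pixel weight in [0, 1]
    blended = alpha * stamp[:, :, :3].astype(np.float32) + (1.0 - alpha) * roi
    canvas[y:y + h, x:x + w, :3] = blended.astype(np.uint8)
    return canvas

page = cv2.cvtColor(cv2.imread("page.png", cv2.IMREAD_GRAYSCALE), cv2.COLOR_GRAY2BGRA)
word = cv2.imread("word_a.png", cv2.IMREAD_UNCHANGED)         # BGRA crop with transparent background
x = random.randint(0, page.shape[1] - word.shape[1])
y = random.randint(0, page.shape[0] - word.shape[0])
page = paste_stamp(page, word, x, y)
cv2.imwrite("stamped_page.png", page)
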
test_my_dataset.py: 82 changes (74 additions, 8 deletions)
@@ -1,39 +1,105 @@
from operator import gt
import cv2
import numpy as np
import os

os.environ['USE_TORCH'] = '1'

import torch
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
from tqdm import tqdm
from doctr.models.predictor import OCRPredictor
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.recognition.predictor import RecognitionPredictor
from doctr.models.preprocessor import PreProcessor
-from doctr.models import crnn_vgg16_bn, db_resnet50
+from doctr.models import crnn_vgg16_bn, db_resnet50, ocr_predictor
from doctr.io import DocumentFile
from doctr.utils.visualization import visualize_page
from doctr.utils.metrics import LocalizationConfusion, OCRMetric, TextMatch

+def _pct(val):
+    return "N/A" if val is None else f"{val:.2%}"
+
+def xml_parser(xmL_file):
+    tree = ET.parse(xmL_file)
+    root = tree.getroot()
+    total_height, total_width = int(root.attrib['height']), int(root.attrib['width'])
+
+    gt_boxes = []
+    for word in root.iter('word'):
+        x_min, x_max = total_width, 0
+        y_min, y_max = total_height, 0
+        for cmp in word:
+            x_left, x_right = int(cmp.attrib['x']), int(cmp.attrib['x'])+int(cmp.attrib['width'])
+            y_up, y_down = int(cmp.attrib['y']), int(cmp.attrib['y'])+int(cmp.attrib['height'])
+            x_min = min(x_left, x_min)
+            x_max = max(x_right, x_max)
+            y_min = min(y_up, y_min)
+            y_max = max(y_down, y_max)
+        #print("{} bbox:[ {}, {}, {}, {}]".format(word.attrib['text'],x_min, y_min, x_max, y_max))
+        gt_boxes.append([x_min, y_min, x_max, y_max])
+    return gt_boxes
+
+device = torch.device("cuda")
+torch.cuda.set_device(0)

+def pred_boxes_list(result):
+    pred_boxes = []
+    height, width = result.pages[0].dimensions
+    for block in result.pages[0].blocks:
+        for line in block.lines:
+            for word in line.words:
+                (a, b), (c, d) = word.geometry
+                pred_boxes.append([int(a * width), int(b * height), int(c * width), int(d * height)])
+    return pred_boxes

# Instantiate your model here
det_model = db_resnet50(pretrained=False)
reco_model = crnn_vgg16_bn(pretrained=True)
det_params = torch.load("/home/user/ACM/shih/doctr/IMGUR5K_shrink.pt", map_location="cpu")
det_params = torch.load("/mnt/baf69772-7c2f-4570-a192-06c62f849660/data/shih/doctr/baseline_mergedataset_2.pt", map_location='cpu')
det_model.load_state_dict(det_params)

det_predictor = DetectionPredictor(PreProcessor((1024, 1024), batch_size=1), det_model)
reco_predictor = RecognitionPredictor(PreProcessor((32, 128), preserve_aspect_ratio=True, batch_size=32), reco_model)

predictor = OCRPredictor(det_predictor, reco_predictor)
predictor.cuda(0)

imgs_folder = "/home/user/ACM/shih/FUNSD/dataset/testing_data/images/"
save_folder = "/home/user/ACM/shih/FUNSD/demo/"
pretrained_model = ocr_predictor('db_resnet50', 'crnn_vgg16_bn', pretrained=True)
# predictor = predictor.cuda()
det_metric = LocalizationConfusion(iou_thresh=0.5)
det_pretrain_metric = LocalizationConfusion(iou_thresh=0.5)

imgs_folder = "/mnt/baf69772-7c2f-4570-a192-06c62f849660/data/shih/IAM/forms/"
xml_folder = "/mnt/baf69772-7c2f-4570-a192-06c62f849660/data/shih/IAM/xml/"
save_folder = "/mnt/baf69772-7c2f-4570-a192-06c62f849660/data/shih/demo(IMGUR5K_shrink)/"

files = os.listdir(imgs_folder)
pbar = tqdm(files)
-for file in pbar:
+for idx, file in enumerate(pbar):
+    '''
+    if idx > 0:
+        break
+    '''
     pbar.set_description("Processing %s" % file)
-    img = DocumentFile.from_images(imgs_folder + file)
+    xml_file = file.replace('.png','.xml')
+    img = DocumentFile.from_images(os.path.join(imgs_folder, file))
     result = predictor(img)
     #print(type(img))
+    pretrained_result = pretrained_model(img)
+    pred_boxes = pred_boxes_list(result)
+    pred_pretrained_boxes = pred_boxes_list(pretrained_result)
+    #print(pred_boxes)
+    gt_boxes = xml_parser(os.path.join(xml_folder+xml_file))
+    #print(gt_boxes)
+    det_metric.update(np.asarray(gt_boxes), np.asarray(pred_boxes))
+    det_pretrain_metric.update(np.asarray(gt_boxes), np.asarray(pred_pretrained_boxes))
     # save file
+    '''
     output = visualize_page(result.pages[0].export(), np.asarray(img[0]))
     output.savefig(save_folder + file)
     plt.close(output)
+    '''
+recall, precision, mean_iou = det_metric.summary()
+print(f"Text Detection - Recall: {_pct(recall)}, Precision: {_pct(precision)}, Mean IoU: {_pct(mean_iou)}")
+recall, precision, mean_iou = det_pretrain_metric.summary()
+print(f"Text Detection - Recall: {_pct(recall)}, Precision: {_pct(precision)}, Mean IoU: {_pct(mean_iou)}")
