Skip to content

Commit

Permalink
fix numerous bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Ladbaby committed Aug 27, 2024
1 parent dbcfa72 commit d738c98
Show file tree
Hide file tree
Showing 8 changed files with 289 additions and 176 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/data/full
/data/CROHME
/data/MyDataset
/other
# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
11 changes: 6 additions & 5 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Data paths
# NOTE(review): data_name and vocab_path are each assigned twice below; the
# later assignment is the one that takes effect at import time.
data_name = 'CROHME' # Model name; only used when saving
vocab_path = './data/CROHME/vocab.json'
train_set_path = './data/CROHME/train.json'
val_set_path = './data/CROHME/val.json'
dataset_dir = "data/MyDataset"
data_name = 'MyDataset' # Model name; only used when saving
vocab_path = 'data/MyDataset/vocab.txt'
train_set_path = './data/small/train.json'
val_set_path = './data/small/val.json'


# Model parameters
Expand All @@ -19,7 +20,7 @@

# Training parameters
start_epoch = 0
# NOTE(review): epochs is assigned twice; the second value (30) wins.
epochs = 250 # Maximum number of epochs when early stopping is not triggered
epochs = 30 # Maximum number of epochs when early stopping is not triggered
epochs_since_improvement = 0 # Tracks the number of epochs without improvement of the validation score
batch_size = 1 # Training batch size
test_batch_size = 2 # Validation batch size
Expand Down
95 changes: 91 additions & 4 deletions model/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
import os
from pathlib import Path

import torchvision
import torch
import json
import cv2
import numpy as np
from scipy.ndimage import zoom

from config import vocab_path,buckets
from torch.utils.data import Dataset
from model.utils import load_json

# Build the token -> index vocabulary from a plain-text file holding one token
# per line.  The encoding is given explicitly so the file decodes the same way
# on every platform instead of depending on the locale default.
with open(vocab_path, encoding="utf-8") as f:
    words = f.readlines()
# The sequence markers are appended after the file's own tokens, so they get
# the two highest indices.
words.append("<start>")
words.append("<end>")
# Indices start at 1 because index 0 is reserved for the padding token.
vocab = {value.strip(): index + 1 for index, value in enumerate(words)}
vocab["<pad>"] = 0

def get_new_size(old_size, buckets=buckets,ratio = 2):
"""Computes new size from buckets
Expand Down Expand Up @@ -56,6 +66,7 @@ def label_transform(text,start_type = '<start>',end_type = '<end>',pad_type = '<
text = [start_type] + text + [end_type]
# while len(text)<max_len:
# text += [pad_type]
# print(f"{vocab=}")
text = [i for i in map(lambda x:vocab[x],text)]
return text
# return torch.LongTensor(text)
Expand All @@ -69,10 +80,86 @@ def img_transform(img,size,ratio = 1):
to_tensor = torchvision.transforms.ToTensor()
return to_tensor(new_im)

class MyDataset(Dataset):
    """Dataset over image/label samples stored in ``.npy`` shards of a directory.

    Every ``.npy`` file under ``dataset_dir`` is loaded with
    ``allow_pickle=True`` and indexed as a sequence of dicts with the keys
    ``"image"`` and ``"label"``.  The sorted file list is split 90% / 10% into
    a train part and an eval part.  Only one shard is held in memory at a
    time; the shard containing a requested index is loaded on demand, so the
    dataset works for sequential as well as random access (sequential access
    avoids reloading shards).

    Args:
        dataset_dir: Directory containing the ``.npy`` shards.
        img_transform: Stored on the instance; not applied by ``__getitem__``.
        label_transform: Callable mapping a raw label to a list of token ids.
        ratio: Down-sampling factor; images are rescaled by ``1 / ratio``.
        is_train: Select the train split (first 90% of the files) when true,
            otherwise the eval split (last 10%).

    Raises:
        FileNotFoundError: If the selected split contains no ``.npy`` file.
    """

    def __init__(
        self,
        dataset_dir,
        img_transform=img_transform,
        label_transform=label_transform,
        ratio=2,
        is_train=True,
    ):
        # Kept on the instance because __getitem__ has to reload shards later.
        self.dataset_dir = Path(dataset_dir)
        self.img_transform = img_transform      # image preprocessing
        self.label_transform = label_transform  # label preprocessing
        self.ratio = ratio                      # down-sampling rate

        # os.listdir returns entries in arbitrary order; sort so the
        # train/eval split is reproducible across runs and machines.
        npy_files_list = sorted(
            file_name
            for file_name in os.listdir(dataset_dir)
            if file_name.endswith(".npy")
        )
        split_dict = {
            "train": (0, 0.9),
            "eval": (0.9, 1),
        }
        split_tuple = split_dict["train"] if is_train else split_dict["eval"]
        n_files = len(npy_files_list)
        self.npy_files_list = npy_files_list[
            int(split_tuple[0] * n_files): int(split_tuple[1] * n_files)
        ]
        self.n_files = len(self.npy_files_list)
        print(f"files {self.npy_files_list} are used as {'train' if is_train else 'eval'} set")

        if not self.npy_files_list:
            raise FileNotFoundError(
                f"No .npy file found under the directory {Path(dataset_dir)}"
            )

        # Cumulative sample counts used as shard delimiters: n files give
        # n + 1 delimiters, and shard k covers indices
        # [n_samples_list[k], n_samples_list[k + 1]).
        self.n_samples_list = [0]
        for file_name in self.npy_files_list:
            shard = self._read_shard(file_name)
            self.n_samples_list.append(self.n_samples_list[-1] + len(shard))
            del shard  # only the length is needed here; release the memory

        # Index (into self.npy_files_list) of the shard currently in memory.
        self.current_file_idx = 0
        self.images_and_labels = self._read_shard(self.npy_files_list[0])

    def _read_shard(self, file_name):
        """Load one shard file from the dataset directory."""
        # NOTE(security): allow_pickle=True executes pickled code on load;
        # only point this dataset at trusted files.
        return np.load(self.dataset_dir / file_name, allow_pickle=True)

    def __getitem__(self, idx):
        """Return ``(image, label_ids, label_length)`` for sample ``idx``.

        ``image`` is a float tensor with a leading channel axis of size 1,
        ``label_ids`` a ``LongTensor`` of token ids and ``label_length`` a
        one-element tensor holding ``len(label_ids)``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_total = self.n_samples_list[-1]
        if idx < 0:
            idx += n_total
        if not 0 <= idx < n_total:
            raise IndexError(f"index {idx} is out of range for {n_total} samples")

        # Find the shard whose index range contains idx; side="right" makes
        # a delimiter value belong to the shard that starts there and skips
        # over empty shards.
        file_idx = int(np.searchsorted(self.n_samples_list, idx, side="right")) - 1
        if file_idx != self.current_file_idx:
            self.images_and_labels = self._read_shard(self.npy_files_list[file_idx])
            self.current_file_idx = file_idx

        idx_in_list = idx - self.n_samples_list[self.current_file_idx]
        sample = self.images_and_labels[idx_in_list]

        image = sample["image"]
        if image.shape[-1] == 3:
            # Collapse the three channels into one with a weighted sum and
            # divide by 255 (assumes channel-last RGB data in the 0-255
            # range -- TODO confirm against the data producer).
            image = (
                0.299 * image[:, :, 0]
                + 0.587 * image[:, :, 1]
                + 0.114 * image[:, :, 2]
            ) / 255
        # Down-sample with scipy's zoom; order=1 is (bi)linear interpolation.
        image = zoom(image, 1 / self.ratio, order=1)
        # Add the leading channel axis: (d0, d1) -> (1, d0, d1).
        image = torch.tensor(image).float().unsqueeze(0)

        label = sample["label"]
        label_list = self.label_transform(label)

        return image, torch.LongTensor(label_list), torch.tensor([len(label_list)])

    def __len__(self):
        """Return the total number of samples across all shards of the split."""
        return self.n_samples_list[-1]

class formuladataset(object):
#公式数据集,负责读取图片和标签,同时自动对进行预处理
#:param json_path 包含图片文件名和标签的json文件
#:param pic_transform,label_transform分别是图片预处理和标签预处理(主要是padding)
'''
公式数据集,负责读取图片和标签,同时自动对进行预处理
:param json_path 包含图片文件名和标签的json文件
:param pic_transform,label_transform分别是图片预处理和标签预处理(主要是padding)
'''
def __init__(self, data_json_path, img_transform=img_transform,label_transform = label_transform,ratio = 2,batch_size = 2):
self.img_transform = img_transform # 传入图片预处理
self.label_transform = label_transform # 传入图片预处理
Expand Down
10 changes: 6 additions & 4 deletions model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def add_timing_signal_nd(self, x, min_timescale=1.0, max_timescale=1.0e4):
signal = signal.unsqueeze(0)
for _ in range(num_dims - 1 - dim): # 1, 0
signal = signal.unsqueeze(-2)
x += signal # [1, 14, 1, 512]; [1, 1, 14, 512]
# don't use +=, or the in-place calculation will raise error in backward
x = x + signal # [1, 14, 1, 512]; [1, 1, 14, 512]
return x

class Attention(nn.Module):
Expand Down Expand Up @@ -231,8 +232,9 @@ def forward(self, encoder_out, encoded_captions, caption_lengths,p = 1):
# Sort input data by decreasing lengths; why? apparent below
caption_lengths, sort_ind = caption_lengths.sort(dim=0, descending=True)
# print('sort_ind',sort_ind,'encoder_out',encoder_out.shape,'encoder_captions',encoded_captions.shape)
encoder_out = encoder_out[sort_ind]
encoded_captions = encoded_captions[sort_ind]
encoder_out = encoder_out[sort_ind][:, 0]
encoded_captions = encoded_captions[sort_ind][:, 0]
# encoded_captions = torch.stack([encoded_captions for _ in range(num_pixels)], dim=2)

# Embedding
embeddings = self.embedding(encoded_captions) # (batch_size, max_caption_length, embed_dim)
Expand All @@ -243,7 +245,7 @@ def forward(self, encoder_out, encoded_captions, caption_lengths,p = 1):

# 我们一旦生成了<end>就已经完成了解码
# 因此需要解码的长度实际是 lengths - 1
decode_lengths = (caption_lengths - 1).tolist()
decode_lengths = caption_lengths - 1
# 新建两个张量用于存放 word predicion scores and alphas
predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device)
alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device)
Expand Down
24 changes: 24 additions & 0 deletions model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,30 @@
import cv2
import torch

def collate_fn(batch_images, max_area=1600 * 320):
    """Pad a list of samples into dense image and label batches.

    Each element of ``batch_images`` is indexed as ``item[0]`` (a 3-D image
    tensor whose axes 1 and 2 are treated as height and width) and
    ``item[1]`` (a 1-D label tensor).  Images are zero-padded at the
    bottom/right to the largest height and width in the batch; labels are
    zero-padded at the end to the longest label.

    Items are visited in order.  An item is left out of the padded tensors
    when its height times the running maximum width, or its width times the
    running maximum height, exceeds ``max_area``.  The first item is
    therefore always kept, because both running maxima start at zero.

    Args:
        batch_images: Non-empty list of samples as described above.
        max_area: Size limit used by the skip rule.

    Returns:
        A tuple ``(images, labels, batch_images)``: the padded float image
        tensor of shape ``(kept, channels, max_height, max_width)``, the
        padded ``long`` label tensor of shape ``(kept, max_length)``, and the
        unmodified input list (including any skipped items).
    """
    max_width, max_height, max_length = 0, 0, 0
    channel = batch_images[0][0].shape[0]

    proper_items = []
    for item in batch_images:
        _, height, width = item[0].shape
        if height * max_width > max_area or width * max_height > max_area:
            continue
        max_height = max(max_height, height)
        max_width = max(max_width, width)
        max_length = max(max_length, item[1].shape[0])
        proper_items.append(item)

    images = torch.zeros((len(proper_items), channel, max_height, max_width))
    labels = torch.zeros((len(proper_items), max_length)).long()

    for i, (image, label, *_) in enumerate(proper_items):
        _, h, w = image.shape
        images[i, :, :h, :w] = image
        length = label.shape[0]
        labels[i, :length] = label

    # The third element is the raw input batch, as callers already receive.
    return images, labels, batch_images

def load_json(path):
with open(path,'r')as f:
data = json.load(f)
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Distance
nltk
1 change: 1 addition & 0 deletions train.ps1
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python train.py
Loading

0 comments on commit d738c98

Please sign in to comment.