Skip to content

pokeshun96/CustomImageFolder

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

3 Commits
 
 

Repository files navigation

CustomImageFolder

pytorchで画像分類タスクを行うときに、不正解データの画像ファイルを出力したい!
ImageFolder クラスを使っていて、ファイルパスも返すように DataLoader を調整するために、カスタムの Dataset クラスを作成します.

from torchvision.datasets import ImageFolder

class CustomImageFolder(ImageFolder):
    def __getitem__(self, index):
        # 元の ImageFolder クラスの __getitem__ メソッドを呼び出す
        original_tuple = super(CustomImageFolder, self).__getitem__(index)
        # ファイルパスを取得
        path = self.imgs[index][0]
        # 画像データ、ラベル、ファイルパスを含むタプルを返す
        return (original_tuple + (path,))

そして、この CustomImageFolder クラスを使用してデータセットを作成します.

# カスタムデータセットのインスタンスを作成
image_datasets = {
    x: CustomImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ['train', 'val']
}

# DataLoader の作成
dataloaders = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                   shuffle=True, num_workers=4, pin_memory=True)
    for x in ['train', 'val']
}

これで、DataLoader から得られる各バッチは (inputs, labels, paths) の形式になります。
これを使って、train_model 関数内で不正解の画像のファイルパスを取得して出力することができます.

# 検証フェーズ
if phase == 'val':
    # ...
    for inputs, labels, paths in dataloaders['val']:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # ...
        incorrect_indices = (preds != labels).nonzero(as_tuple=True)[0]
        for idx in incorrect_indices:
            incorrect_samples.append(paths[idx])
    # ...

# 検証終了後に不正解データのファイル名の出力
print("Incorrect samples:")
for file_path in incorrect_samples:
    print(file_path)

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published