Pathological image analysis is a crucial field in computer vision. Due to the scarcity of annotations in the pathological field, most recent works have leveraged self-supervised learning (SSL) on unlabeled pathological images, hoping to mine representations effectively. However, current SSL-based pathological pre-training has two core defects: (1) it does not explicitly explore the essential focuses of the pathological field, and (2) it does not effectively bridge with, and thus take advantage of, knowledge from natural images. To address them explicitly, we propose our large-scale PuzzleTuning framework, containing the following innovations. First, we define three task focuses that can effectively bridge knowledge of the pathological and natural domains: appearance consistency, spatial consistency, and restoration understanding. Second, we devise a novel multiple-puzzle-restoring task, which explicitly pre-trains the model on these focuses. Third, we introduce an explicit prompt-tuning process to incrementally integrate domain-specific knowledge, building a bridge across the large domain gap between natural and pathological images. Additionally, a curriculum-learning training strategy is designed to regulate task difficulty, making the model adaptive to the puzzle-restoring complexity. Experimental results show that our PuzzleTuning framework outperforms previous state-of-the-art methods in various downstream tasks on multiple datasets.
Samples illustrating the focuses and relationships in pathological images: pancreatic liquid samples (a and b) and colonic epithelium tissue samples (c and d), under normal (a and c) and cancerous (b and d) conditions. The patches of each image are numbered from 1 to 9. Grouping the patches of each image as a bag and then intermixing patches among the bags highlights the three pathological focuses: appearance consistency, spatial consistency, and restoration understanding.
Overview of PuzzleTuning. Three steps are designed in PuzzleTuning: 1) Puzzle making, where an image batch is divided into bags of patches, fix-position and relation identities are randomly assigned, and the relation patches are then in-place shuffled with each other, composing the puzzle state. 2) Puzzle understanding, where grouping, junction, and restoration relationships within the puzzles are learned by prompt tokens attached to the encoder; through these prompt tokens, the pathological focuses are explicitly seamed with general vision knowledge. 3) Puzzle restoring, where the decoder restores the relation patches, with the fix-position patches as hints, under SSL supervision against the original images.
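Below is a minimal sketch of the puzzle-making step in plain PyTorch. The function name make_puzzle, the fix_ratio parameter, and the (B, C, H, W) input layout are illustrative assumptions, not the official PuzzleTuning API.

import torch

def make_puzzle(images, patch_size=16, fix_ratio=0.25, group_shuffle_size=4):
    # Divide each image into a bag of flattened patches: (B, n, C*p*p)
    B, C, H, W = images.shape
    n = (H // patch_size) * (W // patch_size)
    patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, n, -1)
    # Randomly assign fix-position identities; the rest become relation patches.
    # The mask is shared across the batch so swapped patches keep their positions.
    fix_mask = torch.rand(n) < fix_ratio
    rel_idx = torch.where(~fix_mask)[0]
    puzzle = patches.clone()
    # In-place shuffle the relation patches among the images of each group
    for g in range(0, B, group_shuffle_size):
        group = puzzle[g:g + group_shuffle_size]        # view into the batch
        perm = torch.randperm(group.shape[0])
        group[:, rel_idx] = group[perm][:, rel_idx]     # swap same-position patches across images
    return puzzle, fix_mask

The decoder is then trained to restore the relation patches, with the fix-position patches as hints, under reconstruction supervision against the original images. To experiment with the released checkpoints, the PuzzleTuning-VPT weights can be loaded from Hugging Face as follows.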
from huggingface_hub import hf_hub_download
import torch
from PuzzleTuning.Backbone.GetPromptModel import build_promptmodel
# Define the repo ID
repo_id = "Tianyinus/PuzzleTuning_VPT"
# Download the base state dictionary file
base_state_dict_path = hf_hub_download(repo_id=repo_id, filename="PuzzleTuning/Archive/ViT_b16_224_Imagenet.pth")
# Download the prompt state dictionary file
prompt_state_dict_path = hf_hub_download(
    repo_id=repo_id,
    filename="PuzzleTuning/Archive/ViT_b16_224_timm_PuzzleTuning_SAE_CPIAm_Prompt_Deep_tokennum_20_E_199_promptstate.pth")
# Load these weights into your model
base_state_dict = torch.load(base_state_dict_path, map_location='cpu')
prompt_state_dict = torch.load(prompt_state_dict_path, map_location='cpu')
# Build your model using the loaded state dictionaries
model = build_promptmodel(prompt_state_dict=prompt_state_dict,
                          base_state_dict=base_state_dict,
                          num_classes=0)
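As a quick smoke test, assuming the returned model is a standard nn.Module and that num_classes=0 yields feature embeddings in the timm-style convention (an assumption, not confirmed here):

imgs = torch.randn(2, 3, 224, 224)  # dummy batch at the expected 224x224 input size
with torch.no_grad():
    features = model(imgs)  # assumed to return encoded representations when num_classes=0
print(features.shape)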
We have updated the pre-trained weights of PuzzleTuning and all counterparts at
https://drive.google.com/file/d/1-mddejIdCRP5AscnlWAyEcGzfgBIRCSf/view?usp=share_link
We have also updated a demo for illustration at
https://github.com/sagizty/PuzzleTuning/blob/main/PuzzleTuning%20Colab%20Demo.ipynb
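PuzzleTuning pre-training itself can be launched with PyTorch distributed training. The example below uses 8 GPUs on a single node, an ImageNet-initialized ViT-B/16 checkpoint, and deep prompt tuning; the paths under /home/ are placeholders to adapt to your environment. The --group_shuffle_size flag sets how many images exchange patches within each puzzle group, and --strategy selects the curriculum schedule that regulates puzzle difficulty.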
python -m torch.distributed.launch --nproc_per_node=8 --nnodes 1 --node_rank 0 PuzzleTuning.py --DDP_distributed --batch_size 64 --group_shuffle_size 8 --blr 1.5e-4 --epochs 2000 --accum_iter 2 --print_freq 5000 --check_point_gap 100 --input_size 224 --warmup_epochs 100 --pin_mem --num_workers 32 --strategy loop --PromptTuning Deep --basic_state_dict /home/saved_models/ViT_b16_224_Imagenet.pth --data_path /home/datasets/All
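As a sketch of how such a curriculum might regulate difficulty, assuming difficulty is controlled by the share of fix-position patches (a hypothetical schedule, not the repository's exact implementation):

def fix_ratio_schedule(epoch, total_epochs, easy=0.75, hard=0.25):
    # Linearly decay the share of fixed patches so puzzles get harder over training
    t = min(epoch / max(total_epochs - 1, 1), 1.0)
    return easy + (hard - easy) * t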
The CPIA dataset used for pre-training is available at
https://github.com/zhanglab2021/CPIA_Dataset