Skip to content

YongheeChoi/sae

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Sparse Autoencoder (SAE) 프로젝트

이 저장소는 CIFAR-10, CIFAR-100, TinyImageNet, ImageNet 등의 이미지 데이터셋을 활용해 거대한 잠재 공간(latent space)을 학습하는 심플한 Sparse Autoencoder(SAE) 실험용 코드베이스입니다. 모델은 이미지 입력/출력을 유지하면서 중앙에 하나의 활성화 계층을 두도록 설계되었으며, 노이즈 이미지를 번갈아 학습에 포함해 잠재 공간이 의미 있는 축으로 정렬되도록 돕습니다.

주요 특징

  • PyTorch 기반의 대용량 선형 오토인코더
  • YAML/CLI 구성으로 잠재 차원, 중간 차원, 활성화 함수, 드롭아웃 등 간단히 조정
  • 노이즈 배치를 교차 학습시켜 잠재 표현을 0 벡터로 제한하는 옵션
  • CUDA + AMP 학습 및 Gradient Clipping 지원
  • CIFAR-10/100, TinyImageNet, ImageNet 데이터 파이프라인 포함

환경 구성

conda env create -f environment.yml
conda activate try

데이터 준비

모든 데이터는 data_root (기본 ./data) 아래에 정리합니다.

  • CIFAR-10/100: 자동 다운로드 지원 (--download true)
  • TinyImageNet: 공식 사이트에서 tiny-imagenet-200 압축을 내려받아 data_root/tiny-imagenet-200에 압축을 해제합니다.
  • ImageNet: ImageNet 1k를 수동 배치합니다.
    data_root/
      imagenet/
        train/<class>/*.JPEG
        val/<class>/*.JPEG
    

구성 파일

configs/default.yaml 예시는 아래와 같습니다.

dataset: cifar10
image_size: 32
batch_size: 512
latent_dim: 8192
hidden_dim: 16384
activation: gelu
enable_noise: true
noise_latent_weight: 1.0
noise_recon_weight: 0.05
output_dir: ./outputs

필요에 따라 값을 수정하거나 --config 없이 직접 CLI 인자를 전달할 수 있습니다.

실행 방법

python main.py --config configs/default.yaml

주요 CLI 인자:

  • --dataset {cifar10,cifar100,tinyimagenet,imagenet}
  • --latent-dim, --hidden-dim, --activation, --dropout
  • --enable-noise {true,false}, --noise-latent-weight, --noise-recon-weight
  • --image-size, --batch-size, --num-workers

노이즈 학습 전략

--enable-noise true일 때 각 학습 스텝은 다음 순서로 진행됩니다.

  1. 실제 이미지 배치에 대한 재구성 + 희소성 손실
  2. 무작위 노이즈 배치를 인코딩하고, 잠재 벡터를 0으로 수렴시키는 손실
  3. 0 잠재 벡터를 디코딩할 때도 0 이미지(의미 없음)로 복원하도록 보조 손실 적용

이를 통해 잡음 데이터가 잠재 공간에 의미 없는 축으로 맵핑되도록 강제합니다.

프로젝트 구조

sae/
  __init__.py       # 패키지 초기화
  args.py           # YAML/CLI 파서
  data.py           # 데이터셋 및 Noise 데이터로더
  model.py          # Sparse Autoencoder 정의
  train.py          # 학습 루프 (노이즈 교차 학습 포함)
  utils.py          # 공통 유틸 (시드, 체크포인트 등)
configs/
  default.yaml      # 기본 설정
environment.yml     # Conda 환경
main.py             # 엔트리 포인트

로그 & 체크포인트

--output-dir (기본 ./outputs) 아래에 실행 시각 기준 디렉터리가 생성되고, 다음이 저장됩니다.

  • config.json: 실제 사용된 하이퍼파라미터 스냅샷
  • checkpoints/epoch-XXX.pt, checkpoints/best.pt

참고 사항

  • 실제 ImageNet, TinyImageNet 경로는 사용자가 직접 준비해야 합니다.
  • 학습 진행 중 tensorboardwandb를 추가로 연동해도 됩니다.

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages