3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "wip/flasht5/flashT5"]
	path = wip/flasht5/flashT5
	url = https://github.com/catie-aq/flashT5.git
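For a local (non-Docker) checkout, the new submodule can be pulled with standard git commands; a minimal example using the path declared above:

```bash
# Fetch the FlashT5 submodule into wip/flasht5/flashT5
git submodule update --init wip/flasht5/flashT5
```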
32 changes: 32 additions & 0 deletions Dockerfile
@@ -0,0 +1,32 @@
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-devel

# Set working directory
WORKDIR /app

# Install git and other dependencies
RUN apt-get update && apt-get install -y \
    git \
    gcc \
    g++ \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better caching
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir $(grep -v "flash_attn" requirements.txt) && \
    pip install --no-cache-dir "flash_attn>=2.5.6"

# Copy the rest of the code
COPY . .

# Initialize git submodules with a safer approach
RUN git config --global advice.detachedHead false && \
    if [ -f .gitmodules ]; then \
        git submodule init && \
        git submodule update --remote || echo "Warning: Some submodules may not have been updated correctly"; \
    fi

# Set the default command
CMD ["bash"]
26 changes: 26 additions & 0 deletions README.md
@@ -59,3 +59,29 @@ Note:

## Upcoming releases
A Fast version of ANKH is in progress. It is functional but still uses native attention; we are waiting for bias gradient support in [FlexAttention](https://pytorch.org/blog/flexattention/).

## Docker Setup

To use the Docker container (for testing):

```bash
# Build and start the container
docker-compose up -d

# Execute commands inside the container
docker-compose exec esmplusplus bash

# Or run specific Python scripts
docker-compose exec esmplusplus python wip/flasht5/test_t5_flash_attention.py
docker-compose exec esmplusplus python wip/flasht5/flashT5/benchmarks/bench_fa2_bias.py
docker-compose exec esmplusplus python -m wip.flasht5.flashT5.benchmarks.bench_fa2_bias
```

The Docker container automatically:
- Installs all required dependencies from requirements.txt
- Initializes the Git submodules
- Provides GPU support (if available and Docker is configured for GPU access)

### Prerequisites for GPU support
- Install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)
- Ensure Docker is configured to use the NVIDIA runtime
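
To confirm that the container actually sees the GPU, a quick check from the host (using the `esmplusplus` service defined in `docker-compose.yml`):

```bash
# Should print "True" when the NVIDIA runtime is wired up correctly
docker-compose exec esmplusplus python -c "import torch; print(torch.cuda.is_available())"
```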
19 changes: 19 additions & 0 deletions docker-compose.yml
@@ -0,0 +1,19 @@
version: '3'

services:
  esmplusplus:
    build:
      context: .
      dockerfile: Dockerfile
    volumes:
      - .:/app
    tty: true
    stdin_open: true
    # If you need GPU support
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
19 changes: 13 additions & 6 deletions requirements.txt
@@ -1,12 +1,19 @@
-torch>=2.6.0
-matplotlib>=3.5.0
+torch>=2.4.0
 transformers>=4.47.0
-numpy>=1.26.2
-datasets>=2.16.0
 safetensors>=0.4.2
-accelerate>=1.1.0
 evaluate>=0.4.0
-peft>=0.5.0
 triton>=3.0.0
 flash_attn>=2.5.6
 einops
+numpy>=1.26.4
+matplotlib>=3.5.0
+esm
+datasets>=2.14.0
+scikit-learn>=1.0.0
+scipy>=1.7.0
+seaborn>=0.12.0
+peft>=0.5.0
+accelerate>=1.1.0
+attr>=0.3.2
+clearml>=1.14.2
+tqdm>=4.66.2
50 changes: 50 additions & 0 deletions wip/flasht5/convert_ankh.py
@@ -0,0 +1,50 @@
import re
import os
from safetensors import safe_open
from safetensors.torch import save_file


def convert_ankh(current_path, save_path):
    tensors = {}
    with safe_open(current_path, framework="pt") as f:
        for k in f.keys():
            # Map Hugging Face T5 parameter names onto the FlashT5 naming scheme
            new_k = re.sub(".layer.*.SelfAttention.q", ".self_attention_layer.self_attention.Wq", k)
            new_k = re.sub(".layer.*.SelfAttention.k", ".self_attention_layer.self_attention.Wk", new_k)
            new_k = re.sub(".layer.*.SelfAttention.v", ".self_attention_layer.self_attention.Wv", new_k)
            new_k = re.sub(".layer.*.SelfAttention.o", ".self_attention_layer.self_attention.o", new_k)
            new_k = re.sub(".layer.*.EncDecAttention.q", ".cross_attention_layer.cross_attention.Wq", new_k)
            new_k = re.sub(".layer.*.EncDecAttention.k", ".cross_attention_layer.cross_attention.Wk", new_k)
            new_k = re.sub(".layer.*.EncDecAttention.v", ".cross_attention_layer.cross_attention.Wv", new_k)
            new_k = re.sub(".layer.*.EncDecAttention.o", ".cross_attention_layer.cross_attention.o", new_k)
            new_k = re.sub(".layer.*.SelfAttention.relative_attention_bias.", ".self_attention_layer.self_attention.pe_encoding.relative_attention_bias.", new_k)
            new_k = new_k.replace(".layer.0.layer_norm.", ".self_attention_layer.layer_norm.")
            if "encoder" in new_k:
                new_k = new_k.replace(".layer.1.layer_norm.", ".ff_layer.layer_norm.")
            else:
                new_k = new_k.replace(".layer.1.layer_norm.", ".cross_attention_layer.layer_norm.")
                new_k = new_k.replace(".layer.2.layer_norm.", ".ff_layer.layer_norm.")
            new_k = re.sub(".layer.*.DenseReluDense.", ".ff_layer.", new_k)
            new_k = new_k.replace(".wi_", ".act.wi_")
            tensors[new_k] = f.get_tensor(k).clone()

    save_file(tensors, save_path)


if __name__ == "__main__":
    import shutil
    from transformers import T5EncoderModel

    model_path_base = 'Synthyra'
    model_path = os.path.join(model_path_base, 'ANKH_base')
    save_path_base = os.path.join(model_path_base, 'ANKH_base_flash')
    save_path = os.path.join(save_path_base, 'model.safetensors')

    if os.path.exists(model_path_base):
        shutil.rmtree(model_path_base)

    model = T5EncoderModel.from_pretrained('Synthyra/ANKH_base')
    model.save_pretrained(model_path, push_to_hub=False)

    current_safetensors = os.path.join(model_path, 'model.safetensors')
    os.makedirs(save_path_base, exist_ok=True)
    convert_ankh(current_safetensors, save_path)
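As a quick illustration of the renaming above, here is a minimal sketch using a hypothetical Hugging Face-style T5 parameter name (not taken from an actual checkpoint):

```python
import re

# Hypothetical HF-style key; the sublayer index (".0") is absorbed by the greedy ".*" match.
key = "encoder.block.0.layer.0.SelfAttention.q.weight"
new_key = re.sub(".layer.*.SelfAttention.q", ".self_attention_layer.self_attention.Wq", key)
print(new_key)  # encoder.block.0.self_attention_layer.self_attention.Wq.weight
```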
1 change: 1 addition & 0 deletions wip/flasht5/flashT5
Submodule flashT5 added at b5d08a