Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Align-Anything aims to align any modality large models (any-to-any models), incl

## 📣 News
* **`Coming Soon`** ⚡️⚡️⚡️ We plan to separate the evaluation component from align-anything and establish eval-anything as a dedicated repository for large-scale evaluation of any-to-any models. Meanwhile, align-anything will remain focused on the post-training alignment of any-to-any models.
* **[2025.03.12]** 🛠️🛠️🛠️ We have supported resume training for DPO and SFT, see [here](https://github.com/PKU-Alignment/align-anything/pull/153).
* **[2025.03.11]** 🎉🎉🎉 We support the installation of **Huawei Ascend** dependencies through pre-set Docker image.
* **[2025.03.02]** 🎉🎉🎉 We have implemented alignment training for Vision-Language-Action Models in embodied intelligence, see [VLA Trainer](https://github.com/PKU-Alignment/align-anything/tree/main/align_anything/trainers/text_video_to_action), with more features coming soon!
* **[2025.02.28]** 🤝🤝🤝 We supported the training and inference of align-anything on Huawei Ascend NPU.
Expand Down
5 changes: 1 addition & 4 deletions align_anything/architecture/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
# ==============================================================================
from abc import ABC, abstractmethod
from typing import List, Tuple, Any, Dict
from typing import Any, Dict, List, Tuple


class AbstractAgent(ABC):
Expand All @@ -34,7 +34,6 @@ def reset(self) -> None:
This method should be implemented by each subclass to reset the agent's internal state.
It is typically called at the beginning of each episode.
"""
pass

@abstractmethod
def get_action_list(self) -> List[str]:
Expand All @@ -47,7 +46,6 @@ def get_action_list(self) -> List[str]:
Returns:
List[str]: A list of action names.
"""
pass

@abstractmethod
def get_action(self, observations: Dict[str, Any], goal: str) -> Tuple[str, Any]:
Expand All @@ -64,4 +62,3 @@ def get_action(self, observations: Dict[str, Any], goal: str) -> Tuple[str, Any]
Returns:
Tuple[str, Any]: The chosen action and the action probabilities.
"""
pass
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# Copyright 2024 Allen Institute for AI

# Copyright 2024-2025 Align-Anything Team. All Rights Reserved.
Expand All @@ -20,7 +19,8 @@
from .early_fusion_gru_models import EarlyFusionCnnRNN
from .early_fusion_tsfm_models import EarlyFusionCnnTransformer


REGISTERED_MODELS = {
"EarlyFusionCnnTransformer": EarlyFusionCnnTransformer,
"EarlyFusionCnnRNN": EarlyFusionCnnRNN,
'EarlyFusionCnnTransformer': EarlyFusionCnnTransformer,
'EarlyFusionCnnRNN': EarlyFusionCnnRNN,
}

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# Copyright 2024 Allen Institute for AI

# Copyright 2024-2025 Align-Anything Team. All Rights Reserved.
Expand Down Expand Up @@ -29,9 +28,9 @@

@dataclass
class ClipResNetConfig:
model: str = "RN50"
model: str = 'RN50'
pool: bool = False
device: str = "cpu"
device: str = 'cpu'
output_size: Tuple[int, int, int] = (2048, 7, 12)


Expand Down Expand Up @@ -68,22 +67,22 @@ def stem(x):

@dataclass
class Dinov2Config:
model: str = "dinov2_vits14"
model: str = 'dinov2_vits14'
output_size: Tuple[int, int, int] = (384, 7, 12)


class Dinov2(nn.Module):
def __init__(self, cfg: Dinov2Config):
super().__init__()
self.cfg = cfg
self.model = torch.hub.load("facebookresearch/dinov2", cfg.model)
self.model = torch.hub.load('facebookresearch/dinov2', cfg.model)
self.pool = nn.AdaptiveAvgPool2d(cfg.output_size[1:])
self.eval()

def forward(self, x):
assert x.shape[-2:] == (224, 384), f"Expected shape is 224x384; got {x.shape}"
assert x.shape[-2:] == (224, 384), f'Expected shape is 224x384; got {x.shape}'
with torch.no_grad():
x = self.model.forward_features(x[:, :, :, 3:-3])["x_norm_patchtokens"]
x = self.model.forward_features(x[:, :, :, 3:-3])['x_norm_patchtokens']
B, _, D = x.shape # Bx432x384
x = x.permute(0, 2, 1) # Bx384x432
x = x.reshape(B, D, 16, 27)
Expand All @@ -93,22 +92,22 @@ def forward(self, x):

@dataclass
class SigLIPConfig:
model: str = "ViT-B-16-SigLIP-256"
model: str = 'ViT-B-16-SigLIP-256'
output_size: Tuple[int, int, int] = (768, 7, 12)


class SigLIP(nn.Module):
def __init__(self, cfg: Dinov2Config):
super().__init__()
self.cfg = cfg
siglip_full_model = create_model_from_pretrained("hf-hub:timm/{}".format(cfg.model))
siglip_full_model = create_model_from_pretrained(f'hf-hub:timm/{cfg.model}')
self.model = siglip_full_model[0].visual.trunk
self.context_length = siglip_full_model[0].context_length
self.pool = nn.AdaptiveAvgPool2d(cfg.output_size[1:])
self.eval()

def forward(self, x):
assert x.shape[-2:] == (256, 256), f"Expected shape is 256x256; got {x.shape}"
assert x.shape[-2:] == (256, 256), f'Expected shape is 256x256; got {x.shape}'
with torch.no_grad():
x = self.model.forward_features(x)
B, _, D = x.shape # Bx256x768
Expand All @@ -120,8 +119,8 @@ def forward(self, x):

IMAGE_ENCODERS = dict(
Dinov2Small=(Dinov2, Dinov2Config()),
Dinov2Base=(Dinov2, Dinov2Config(model="dinov2_vitb14", output_size=(768, 7, 12))),
Dinov2Base=(Dinov2, Dinov2Config(model='dinov2_vitb14', output_size=(768, 7, 12))),
ClipResNet50=(ClipResNet, ClipResNetConfig()),
SigLIPBase=(SigLIP, SigLIPConfig()),
SigLIPLarge=(SigLIP, SigLIPConfig(model="ViT-L-16-SigLIP-256", output_size=(1024, 7, 12))),
SigLIPLarge=(SigLIP, SigLIPConfig(model='ViT-L-16-SigLIP-256', output_size=(1024, 7, 12))),
)
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# Copyright 2024 Allen Institute for AI

# Copyright 2024-2025 Align-Anything Team. All Rights Reserved.
Expand Down Expand Up @@ -432,7 +431,7 @@ def forward(self, tokens: torch.Tensor, start_pos: int):

mask = None
if seqlen > 1:
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
mask = torch.full((seqlen, seqlen), float('-inf'), device=tokens.device)

mask = torch.triu(mask, diagonal=1)

Expand Down
Loading