# ✍🏻 Create a New Model

This guide shows you, step by step, how to plug a **new end-to-end policy model** into the InternManip framework. Follow the checklist below and you will be able to train your custom model with the stock training script (`scripts/train/train.py`), without touching the core code.

## File Structure and Why

Currently, leading manipulation models strive to leverage existing pretrained large models for better generalization. For example, **GR00T-N1** and **Pi-0** typically consist of a pretrained VLM backbone and a compact downstream action expert that maps extracted context to the action space. Reflecting this design, InternManip organizes model files into three main components:

- **Backbone**: The pretrained VLM backbone responsible for understanding visual and textual inputs.
- **Action Head**: The downstream expert that consumes backbone features and predicts actions.
- **Policy Model**: The wrapper that integrates the backbone and action head into a single end-to-end policy.

Model definitions reside in the `internmanip/model` directory, which contains three sub-folders:

```text
internmanip
├── model
│   ├── action_head        # task-specific experts
│   ├── backbone           # pretrained encoders (ViT, CLIP, ...)
│   └── basemodel          # full end-to-end policies
│       └── base.py        # <-- universal interface
├── ...
└── configs
    └── model              # config classes (inherit PretrainedConfig)
scripts
└── train                  # trainers, entry points
```

## 1. Outline

To integrate a new model into the framework, you need to create the following files:

1. A **Config** that stores architecture-related hyperparameters.
2. A **Model** class that inherits `BasePolicyModel` and implements the model structure.
3. A **data\_collator** that shapes raw samples into model-ready tensors.

Finally, you **register** the model with the framework, after which you can start training. We will walk through each step below.

## 2. Create the Model Configuration File

The config file stores the architecture-related hyperparameters. Add it at `internmanip/configs/model/{model_name}_cfg.py`, and make the config class inherit `transformers.PretrainedConfig`.

The following is **an example** of a model configuration file:

```python
from typing import List, Optional, Tuple

from transformers import PretrainedConfig

class CustomPolicyConfig(PretrainedConfig):
    """Configuration for CustomPolicy."""
    model_type = "custom_model"

    def __init__(self,
                 vit_name="google/vit-base-patch16-224-in21k",
                 freeze_vit=True,
                 hidden_dim=256,
                 output_dim=8,
                 dropout=0.0,
                 n_obs_steps=1,
                 horizon=10,
                 **kwargs):
        super().__init__(**kwargs)
        self.vit_name = vit_name
        self.freeze_vit = freeze_vit
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.dropout = dropout
        self.n_obs_steps = n_obs_steps
        self.horizon = horizon

    def transform(self) -> Tuple[Optional[List["Transform"]], List[int], List[int]]:
        # "Transform" refers to the transform classes under internmanip/dataset/transform;
        # return None if no model-specific transform is needed.
        transforms = None
        return transforms, list(range(self.n_obs_steps)), list(range(self.horizon))
```

As shown in the example above, the config class defines key architectural hyperparameters, such as the backbone model name, whether to freeze the backbone, and the hidden/output dimensions of the action head. You are free to extend this config with any additional parameters required by your custom model.

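Because the config inherits `PretrainedConfig`, it serializes through the standard Hugging Face API. A quick sanity check might look like this (the import path assumes the `{model_name}_cfg.py` naming convention above):

```python
from internmanip.configs.model.custom_model_cfg import CustomPolicyConfig

cfg = CustomPolicyConfig(hidden_dim=512)            # override any hyperparameter
cfg.save_pretrained("./checkpoints/custom_model")   # writes config.json
reloaded = CustomPolicyConfig.from_pretrained("./checkpoints/custom_model")
assert reloaded.hidden_dim == 512
```
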
Additionally, you can implement a **model-specific `transform` method** within the config class. This method allows you to apply custom data transformations that are *not* included in the dataset-specific transform list defined in `internmanip/configs/dataset/data_config.py`.

During training, the script `scripts/train/train.py` will automatically call this method and apply your custom transform alongside the default ones. Your `transform` method should follow the same input/output format as the dataset-specific transforms. For implementation guidance, refer to the examples in the `internmanip/dataset/transform` directory.

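If your model does need extra preprocessing, the stub shown in the config example can return actual transforms. Below is a minimal sketch, assuming a callable transform interface like the classes in `internmanip/dataset/transform`; `MyImageTransform` is purely illustrative:

```python
from typing import List, Optional, Tuple


class MyImageTransform:
    """Hypothetical stand-in for a transform class from internmanip/dataset/transform."""

    def __call__(self, sample: dict) -> dict:
        sample["image"] = sample["image"] / 255.0  # toy example: rescale pixels
        return sample


class CustomPolicyConfigWithTransform(CustomPolicyConfig):
    def transform(self) -> Tuple[Optional[List[MyImageTransform]], List[int], List[int]]:
        # Same output format as the dataset-specific transforms:
        # (transform list, observation step indices, action horizon indices).
        return [MyImageTransform()], list(range(self.n_obs_steps)), list(range(self.horizon))
```
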
## 3. Implement the Model

To implement the model, create a class that inherits `BasePolicyModel` and register it with `@BasePolicyModel.register("custom_model")`.

The model configuration is passed to the model class's `__init__` method, within which you should define the model structure and initialize its sub-modules.

You should also implement the `forward` method, which defines the forward pass and returns a dictionary of tensors used to compute the loss, and the `inference` method, which generates actions from the model.

```python
from typing import Dict

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers import ViTConfig, ViTModel

from internmanip.configs.model.custom_model_cfg import CustomPolicyConfig
from internmanip.model.basemodel.base import BasePolicyModel


@BasePolicyModel.register("custom_model")
class CustomPolicyModel(BasePolicyModel):
    """ViT backbone + 2-layer MLP head."""

    def __init__(self, config: CustomPolicyConfig):
        super().__init__(config)
        self.config = config
        self.name = "custom_model"

        # 1. Backbone: pretrained ViT visual encoder
        vit_conf = ViTConfig.from_pretrained(config.vit_name)
        self.vit = ViTModel.from_pretrained(config.vit_name, config=vit_conf)
        if config.freeze_vit:
            for p in self.vit.parameters():
                p.requires_grad = False

        # 2. Action head: two-layer MLP
        self.mlp = nn.Sequential(
            nn.Linear(vit_conf.hidden_size, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, config.output_dim),
        )

    # ---- Training / inference ----
    def forward(self, batch: Dict[str, Tensor], train: bool = True,
                noise=None, time=None) -> Dict[str, Tensor]:
        # `noise` and `time` are unused in this simple example; they are kept
        # for interface compatibility.
        imgs, tgt = batch["images"], batch.get("actions")
        feats = self.vit(imgs).last_hidden_state[:, 0]  # CLS token
        pred = self.mlp(feats)
        out = {"prediction": pred}
        if train and tgt is not None:
            out["loss"] = F.mse_loss(pred, tgt.view_as(pred))
        return out

    def inference(self, batch: Dict[str, Tensor], **kwargs) -> Tensor:
        return self.forward(batch, train=False)["prediction"]
```

In the example above, the model is composed of a ViT backbone and a simple 2-layer MLP action head. The `forward` method handles loss computation during training, while the `inference` method generates actions during evaluation.

When designing your own model, you can follow this backbone-plus-head pattern or adopt a completely different architecture. If needed, you can define custom `backbone` and `action_head` modules, typically by subclassing `nn.Module`. Just ensure that your model's `inference` output has the shape `(n_actions, action_dim)`.

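As a quick shape check, you can run the example end to end on random tensors. This is only a sketch; the dimensions come from the example config, whose default `output_dim` is 8:

```python
import torch

cfg = CustomPolicyConfig()                  # defaults: output_dim=8, ViT-base
model = CustomPolicyModel(cfg).eval()
batch = {
    "images": torch.randn(2, 3, 224, 224),  # (B, C, H, W), ViT input size
    "actions": torch.randn(2, 8),           # ground-truth actions
}
out = model(batch)                          # training pass: {"prediction", "loss"}
actions = model.inference(batch)            # (2, 8), i.e. (n_actions, action_dim)
```
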
## 4. Write a Data Collator

Define a `data_collator` function that converts a list of raw samples from the default data loader into a single batch dictionary compatible with the model's `forward` method, and register it under the same model name:

```python
import torch

# Note: adjust this import to wherever DataCollatorRegistry is defined in
# your checkout of the framework.
from internmanip.dataset import DataCollatorRegistry


@DataCollatorRegistry.register("custom_model")
def custom_data_collator(samples):
    imgs = torch.stack([s["image"] for s in samples])   # (B, C, H, W)
    acts = torch.stack([s["action"] for s in samples])  # (B, action_dim)
    return {"images": imgs, "actions": acts}
```

> **Why?** The built-in `BaseTrainer` accepts any callable named `data_collator` so long as it returns a dictionary of tensors compatible with your model's `forward` signature.

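In practice, the registered collator is simply handed to a PyTorch `DataLoader`. Wiring it up manually looks like this (a sketch, where `train_dataset` is a placeholder for whatever dataset object you load):

```python
from torch.utils.data import DataLoader

loader = DataLoader(train_dataset, batch_size=32, collate_fn=custom_data_collator)
batch = next(iter(loader))   # {"images": (32, C, H, W), "actions": (32, action_dim)}
```
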
## 5. Register Everything

Add the following **one-time** registration lines (typically at the end of your model file) to enable seamless dynamic loading with `AutoConfig` and `AutoModel`:

```python
from transformers import AutoConfig, AutoModel

AutoConfig.register("custom_model", CustomPolicyConfig)
AutoModel.register(CustomPolicyConfig, CustomPolicyModel)
```

Make sure the string `"custom_model"` passed to `AutoConfig.register` matches the model name used in both your `CustomPolicyModel` definition and the data collator registration.

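Once those registration lines have executed, the standard Auto classes can resolve your model purely by its `model_type` string:

```python
from transformers import AutoConfig, AutoModel

cfg = AutoConfig.for_model("custom_model")   # resolves to CustomPolicyConfig
model = AutoModel.from_config(cfg)           # builds a CustomPolicyModel
```
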
Don't forget to register the module in your `__init__.py`, so that your custom model gets imported and initialized properly at runtime. For example:

```python
# In internmanip/model/basemodel/__init__.py
from internmanip.model.basemodel.base import BasePolicyModel

__all__ = ["BasePolicyModel"]

# Import all model modules to ensure their registration logic is executed.
from internmanip.model.basemodel.custom import custom_model  # <- your custom model module
```

Once registered, InternManip's trainer can instantiate your model and you can start training.

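For example, a single-GPU run might be launched as below. This is a sketch: the flags follow the stock `scripts/train/train.py`, so check its `--help` for the authoritative list:

```bash
torchrun --nnodes 1 --nproc_per_node 1 \
    scripts/train/train.py \
    --model_name custom_model \
    --dataset-path genmanip-demo \
    --data-config genmanip-v1
```
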
📚 For more details related to training and evaluation, please refer to [train_eval.md](./train_eval.md) and [training.md](../tutorials/training.md).