@@ -114,30 +114,30 @@ We will follow this roadmap to develop Shardformer:
 - [x] Unit Testing
 - [ ] Policy Implementation

-| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
-| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: |
-| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
-| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
-| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
-| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
-| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
-| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
-| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
-| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
-| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
-| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
-| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
-| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
-| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
+| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
+| :-----------: | :-------------: | :---------------: | :-----------------: | :-----: | :---------: | :----------------: | :-------------: | :---------------: | :-----: |
+| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
+| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
+| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
+| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
+| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
+| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
+| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
+| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
+| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
+| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
+| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
+| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
+| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
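Editor's note (not part of this diff): the feature columns in the table above roughly correspond to switches on the shard configuration. The sketch below shows how such a configuration might be assembled; the keyword names are assumed from ColossalAI's `ShardConfig` and may differ between versions, and it presumes the distributed environment has already been launched.

```python
import torch.distributed as dist
from transformers import BertForSequenceClassification

from colossalai.shardformer import ShardConfig, ShardFormer

# Assumes torch.distributed / colossalai has already been initialized (e.g. via colossalai.launch).
# Each keyword is meant to map onto one feature column of the table above; names are assumed, not verbatim.
shard_config = ShardConfig(
    tensor_parallel_process_group=dist.group.WORLD,  # "tensor parallel"
    enable_fused_normalization=True,                 # "fused layernorm"
    enable_flash_attention=True,                     # "flash attn2"
    enable_jit_fused=True,                           # "jit fused operator"
    enable_sequence_parallelism=True,                # "sequence parallel"
    enable_sequence_overlap=True,                    # "overlap"
)

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(model)
```

Models in the unchecked rows (e.g. roberta, gpt-neo) would still need a policy before such a call can shard them.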

 ## 💡 API Design
@@ -391,6 +391,43 @@ _POLICY_LIST = {
 }
 ```

+#### How to support models in the Hugging Face model hub that are not in the transformers library
+
+There are two cases:
+
+1. The modeling file is in the `transformers` library but the model weights are not. For example, the model structure of "01-ai/Yi-34B" is the same as LLaMA, but its weights are not shipped with `transformers`. In this case, we support llama as usual, and Yi-34B is automatically covered by the llama policy; no new policy is needed for Yi-34B.
+2. The modeling file is not in the `transformers` library, e.g. "THUDM/chatglm2-6b".
+
+Taking "THUDM/chatglm2-6b" as an example, we illustrate how to support such a model in `shardformer`.
+
+Unlike llama, which lives in the `transformers` library, the chatglm2 model cannot be imported directly. Thus, the policy key should be the class name as a string rather than the class itself.
+
+For example, for llama:
+```python
+policy[LlamaDecoderLayer] = ModulePolicyDescription(...)
+```
+
+For chatglm2:
+```python
+policy["GLMBlock"] = ModulePolicyDescription(...)
+```
+
+Then, when registering such models in the autopolicy, we should follow the format below:
+```python
+"transformers_modules.<modeling_filename>.<class_name>": PolicyLocation(
+    file_name="<policy_filename>", class_name="<policy_class_name>"
+)
+```
+
+For the chatglm2 model, the entry should be:
+```python
+"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
+    file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
+)
+```
+
+When using such models, `AutoModel` is supported as usual; the policy will be loaded automatically by the autopolicy.
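Editor's note (not part of this diff): one way to sanity-check the key string for a remote-code model is to inspect the class of the instantiated model. The module path that `transformers` assigns to remote code contains repo-specific segments, so keeping only the last segment as `<modeling_filename>` is an illustrative assumption, not the library's documented contract.

```python
from transformers import AutoModel

# trust_remote_code pulls the custom modeling file into the dynamically
# generated `transformers_modules` package (note: this downloads the weights).
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)

cls = model.__class__
# cls.__module__ looks like "transformers_modules.<repo>.modeling_chatglm" (assumed layout);
# keep only the modeling filename to build a key in the format shown above.
modeling_filename = cls.__module__.split(".")[-1]
policy_key = f"transformers_modules.{modeling_filename}.{cls.__name__}"
print(policy_key)
# expected, per the registration above:
# transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration
```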
+
 ### Write Your Unit Testing

 This section serves as the guideline for testing the `shardformer` module.
@@ -424,13 +461,13 @@ We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate
 We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.

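Editor's note (not part of this diff, and not the linked `performance_benchmark.py`): a minimal timing loop for such a comparison could look roughly like the sketch below, assuming `org_model` and `shard_model` are a Hugging Face causal LM and its Shardformer-optimized copy, already placed on the GPUs and configured with the attention shape stated above.

```python
import time
import torch

def step_time_ms(model, n_ctx, batch_size=4, vocab_size=50257, warmup=3, iters=10):
    """Average time of one forward+backward step in milliseconds (sketch only)."""
    input_ids = torch.randint(0, vocab_size, (batch_size, n_ctx), device="cuda")
    for i in range(warmup + iters):
        if i == warmup:  # start timing only after the warm-up iterations
            torch.cuda.synchronize()
            start = time.time()
        loss = model(input_ids=input_ids, labels=input_ids).loss
        loss.backward()
        model.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    return (time.time() - start) / iters * 1000

# org_model / shard_model are assumed to have been built beforehand.
for n_ctx in (256, 512, 1024, 2048, 4096):
    print(f"N_CTX={n_ctx}: org {step_time_ms(org_model, n_ctx):.1f} ms, "
          f"shard {step_time_ms(shard_model, n_ctx):.1f} ms")
```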
 In the case of using 2 GPUs, the training times are as follows.
-| N_CTX | org_model | shard_model |
-| :------: | :------: | :------: |
-| 256 | 11.2ms | 17.2ms |
-| 512 | 9.8ms | 19.5ms |
-| 1024 | 19.6ms | 18.9ms |
-| 2048 | 46.6ms | 30.8ms |
-| 4096 | 160.5ms | 90.4ms |
+| N_CTX | org_model | shard_model |
+| :---: | :-------: | :---------: |
+| 256   | 11.2ms    | 17.2ms      |
+| 512   | 9.8ms     | 19.5ms      |
+| 1024  | 19.6ms    | 18.9ms      |
+| 2048  | 46.6ms    | 30.8ms      |
+| 4096  | 160.5ms   | 90.4ms      |


 <p align="center">
@@ -440,13 +477,13 @@ In the case of using 2 GPUs, the training times are as follows.

 In the case of using 4 GPUs, the training times are as follows.

-| N_CTX | org_model | shard_model |
-| :------: | :-----: | :-----: |
-| 256 | 10.0ms | 21.1ms |
-| 512 | 11.5ms | 20.2ms |
-| 1024 | 22.1ms | 20.6ms |
-| 2048 | 46.9ms | 24.8ms |
-| 4096 | 160.4ms | 68.0ms |
+| N_CTX | org_model | shard_model |
+| :---: | :-------: | :---------: |
+| 256   | 10.0ms    | 21.1ms      |
+| 512   | 11.5ms    | 20.2ms      |
+| 1024  | 22.1ms    | 20.6ms      |
+| 2048  | 46.9ms    | 24.8ms      |
+| 4096  | 160.4ms   | 68.0ms      |


@@ -475,10 +512,10 @@ warmup_fraction = 0.03


 | accuracy | f1 | loss | GPU number | model sharded |
-| :------: | :-----: | :-----: | :--------: | :---------: |
-| 0.82971 | 0.87713 | 0.23194 | 4 | True |
-| 0.83797 | 0.88006 | 0.22683 | 2 | True |
-| 0.84521 | 0.88700 | 0.21822 | 1 | False |
+| :------: | :-----: | :-----: | :--------: | :-----------: |
+| 0.82971  | 0.87713 | 0.23194 | 4          | True          |
+| 0.83797  | 0.88006 | 0.22683 | 2          | True          |
+| 0.84521  | 0.88700 | 0.21822 | 1          | False         |


 Overall, the results demonstrate that using Shardformer during model training does not affect convergence.