# Base class

The base class BaseModel implements common methods for loading and saving models, either from a local file or directory or from a library-provided pretrained model configuration (downloaded from the BAAI modelhub, hosted on Kingsoft S3).
The supported models cover the three most common model types (encoder, decoder, and encoder-decoder). The GLM classes can load every model in the GLM series; see https://github.com/THUDM/GLM.

## from_pretrain

Models that share the same structure can be loaded with the same class; for example, BERT-base and RoBERTa-base models can both be loaded with the BertModel class. from_pretrain is optimized for data/model-parallel loading and avoids the wasted resources caused by repeated downloads.
For every model currently supported by our model hub, calling ClassName.from_pretrain() downloads the model configuration file (config.json), the model weights (pytorch_model.bin), and the vocabulary file (vocab.txt). Example:

````python
from flagai.model.glm_model import GLMForSingleTokenCloze
model = GLMForSingleTokenCloze.from_pretrain(download_path="./state_dict", model_name="GLM-large-ch")
````

If the model weights are already available locally, they can also be loaded through ClassName.from_pretrain(). Example:
load the model file `pytorch_model.bin` from the `./state_dict/GLM-large-ch` directory:

````python
from flagai.model.glm_model import GLMForSingleTokenCloze
model = GLMForSingleTokenCloze.from_pretrain(download_path="./state_dict",
                                             model_name="GLM-large-ch")
````

## All supported models

| ClassName                         | ModelName       | Language | Model Type |
|-----------------------------------|-----------------|----------|------------|
| flagai.model.glm_model.GLMModel   | GLM-10b-ch      | chinese  | encoder    |
| flagai.model.glm_model.GLMModel   | GLM-large-ch    | chinese  | encoder    |
| flagai.model.bert_model.BertModel | RoBERTa-base-ch | chinese  | encoder    |
| flagai.model.gpt2_model.GPT2Model | GPT2_base_ch    | chinese  | decoder    |
| flagai.model.t5_model.T5Model     | T5-base-ch      | chinese  | enc2dec    |
| flagai.model.t5_model.T5Model     | T5-base-en      | english  | enc2dec    |
| flagai.model.bert_model.BertModel | BERT-base-en    | english  | encoder    |
| flagai.model.glm_model.GLMModel   | GLM-large-en    | english  | encoder    |
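
Any row of this table can be loaded with the class named in its ClassName column. For instance, here is a short sketch that loads the RoBERTa-base-ch weights with BertModel, reusing the from_pretrain pattern shown above (the download directory is an arbitrary local cache path):

````python
from flagai.model.bert_model import BertModel

# RoBERTa-base-ch shares the BERT structure, so BertModel loads it directly
model = BertModel.from_pretrain(download_path="./state_dict",
                                model_name="RoBERTa-base-ch")
````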

## Supported models + tasks

We also provide models fine-tuned on downstream tasks, as listed in the table below. Their weights can likewise be loaded through ClassName.from_pretrain(); for example, the following automatically downloads and loads a GLM-large-ch model trained on the title-generation task:

````python
from flagai.model.glm_model import GLMForSeq2Seq
model = GLMForSeq2Seq.from_pretrain(model_name='GLM-large-ch')
````

We also provide the AutoLoader class to help load models, here loading GLM-large-ch for a seq2seq task. The design is task- and model-independent, so in principle tasks and models can be swapped freely.

````python
from flagai.auto_model.auto_loader import AutoLoader
auto_loader = AutoLoader("seq2seq",
                         model_name="GLM-large-ch",
                         model_dir="./state_dict")
model = auto_loader.get_model()
````
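
The loader can also return the tokenizer that matches the loaded model; a one-line addition, assuming the installed FlagAI version exposes AutoLoader.get_tokenizer() alongside get_model():

````python
# Fetch the tokenizer paired with the model loaded above
tokenizer = auto_loader.get_tokenizer()
````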

| ClassName                                        | ModelName       | Language | Task              |
|--------------------------------------------------|-----------------|----------|-------------------|
| flagai.model.glm_model.GLMForSeq2Seq             | GLM-large-ch    | chinese  | title generation  |
| flagai.model.glm_model.GLMForSeq2Seq             | GLM-large-ch    | chinese  | poetry generation |
| flagai.model.bert_model.BertForSequenceLabeling  | RoBERTa-base-ch | chinese  | title generation  |
| flagai.model.bert_model.BertForSequenceLabeling  | RoBERTa-base-ch | chinese  | NER               |
| flagai.model.bert_model.BertForSequenceLabeling  | RoBERTa-base-ch | chinese  | semantic matching |
| flagai.model.t5_model.T5Model                    | T5-base-ch      | chinese  | title generation  |
| flagai.model.bert_model.BertForSequenceLabeling  | BERT-base-en    | english  | title generation  |
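
To illustrate the task/model independence mentioned above, the same AutoLoader call can target another row of this table simply by changing model_name; a sketch that swaps in the RoBERTa-base-ch title-generation model, following the exact pattern shown earlier (the directory is the same arbitrary cache path):

````python
from flagai.auto_model.auto_loader import AutoLoader

# Same "seq2seq" task as before, now backed by RoBERTa-base-ch
auto_loader = AutoLoader("seq2seq",
                         model_name="RoBERTa-base-ch",
                         model_dir="./state_dict")
model = auto_loader.get_model()
````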

## Model design

The main construction logic of a model is `layer -> block -> model`:

`flagai.model.layer`: individual layers such as MLP, layer norm, activation, and attention layers.

`flagai.model.block`: assembles these layers into a transformer block, e.g. a BERT block.

`flagai.model`: builds the model from the embedding layer and a stack of blocks.
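
To make the `layer -> block -> model` split concrete, below is a minimal, self-contained sketch of the same composition pattern built on plain PyTorch; the class names, dimensions, and mask convention are illustrative and do not mirror flagai's actual modules:

````python
import torch
import torch.nn as nn


class Attention(nn.Module):
    """A "layer": self-attention."""

    def __init__(self, hidden, heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden, heads, batch_first=True)

    def forward(self, x, attention_mask=None):
        # attention_mask uses the 1 = keep / 0 = pad convention; convert it to the
        # boolean "ignore" mask that nn.MultiheadAttention expects
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        out, _ = self.attn(x, x, x, key_padding_mask=key_padding_mask)
        return out


class MLP(nn.Module):
    """A "layer": position-wise feed-forward network."""

    def __init__(self, hidden):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(hidden, 4 * hidden),
                                 nn.GELU(),
                                 nn.Linear(4 * hidden, hidden))

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    """A "block": layers assembled with layer norm and residual connections."""

    def __init__(self, hidden, heads):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(hidden), nn.LayerNorm(hidden)
        self.attn, self.mlp = Attention(hidden, heads), MLP(hidden)

    def forward(self, x, attention_mask=None):
        x = x + self.attn(self.ln1(x), attention_mask)
        x = x + self.mlp(self.ln2(x))
        return x


class ToyModel(nn.Module):
    """A "model": embedding layer plus stacked blocks."""

    def __init__(self, vocab_size, hidden=256, heads=4, layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.blocks = nn.ModuleList(Block(hidden, heads) for _ in range(layers))
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        # Redundant keyword arguments are swallowed by **kwargs and ignored
        x = self.embed(input_ids)
        for block in self.blocks:
            x = block(x, attention_mask)
        return {"logits": self.head(x), "hidden_states": x}
````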

## forward function

A model's forward function takes keyword arguments, including input_ids, position_ids, attention_mask, and so on; redundant arguments are ignored automatically.
For example, GLM's forward function:

````python
def forward(self,
            input_ids=None,
            position_ids=None,
            attention_mask=None,
            mems=None,
            return_memory=False,
            detach_memory=True,
            prompt_pos=None,
            **kwargs)
````

The output is a dictionary that must include the logits and hidden states; for example, the GLM forward function returns:

````python
return {'loss': loss, 'logits': logits, 'hidden_states': mems}
````
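
Putting the two conventions together, the call site passes a dictionary of keyword arguments and reads the returned dictionary. Here is a quick check using the illustrative ToyModel from the Model design sketch above (not a flagai class), which follows the same forward contract:

````python
import torch

# Dummy batch: the extra token_type_ids key is swallowed by **kwargs and ignored
model = ToyModel(vocab_size=100)
data = {
    "input_ids": torch.randint(0, 100, (2, 8)),              # batch of 2, sequence length 8
    "token_type_ids": torch.zeros(2, 8, dtype=torch.long),   # redundant argument
}
output = model(**data)
print(output["logits"].shape)         # torch.Size([2, 8, 100])
print(output["hidden_states"].shape)  # torch.Size([2, 8, 256])
````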

## init_from_json

A model's init_from_json function takes a JSON configuration file (read into a dictionary) as input and returns an initialized model.
For example, GLMModel is invoked as follows:

````python
GLMModel.init_from_json(config_file="./config.json", **kwargs)
````

`**kwargs` is reserved for compatibility with new initialization parameters that some models may introduce.
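
Conceptually, init_from_json reads the JSON file into a dictionary, merges in any overriding keyword arguments, and constructs the model from the result. Below is a rough sketch of that contract only; it assumes nothing about flagai's actual implementation:

````python
import json


class BaseModelSketch:
    """Illustrative stand-in only; not flagai's actual BaseModel."""

    def __init__(self, **config):
        self.config = config

    @classmethod
    def init_from_json(cls, config_file, **kwargs):
        # Read the JSON configuration file into a dictionary
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)
        # Reserved kwargs override or extend the stored configuration, keeping
        # older config files compatible with newly introduced parameters
        config.update(kwargs)
        # Build and return an initialized model from the merged dictionary
        return cls(**config)
````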