Direct Access to Microsoft.ML.GenAI.LLaMA Model #7367

Open
@aforoughi1

I would like to convert a LLaMA model into a multiclass classification model and then fine-tune it on my classification labels.

Currently, Microsoft.ML.GenAI.LLaMA/Module/LlamaModel is internal, so the base model cannot be used directly from user code.

Step 1
Load the pre-trained LLaMA model

using System.IO;
using Microsoft.ML.GenAI.LLaMA;

string device = "cpu"; // or "cuda" if a GPU is available
string weightFolder = @".\Llama3.1-8B";
string originalWeightFolder = Path.Combine(weightFolder, "original"); // tokenizer location (tokenizer loading not shown)
string configName = "config.json";
string modelFile = "tokenizer.model";
string checkPointName = "model.safetensors.index.json";

// Load the pretrained LLaMA model (TorchSharp-based) from the weight folder.
var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt8: false, targetDevice: device);

Step 2
Create a classification head

using TorchSharp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

// Two-layer MLP: pooled hidden state [batch, d_model] -> class logits [batch, num_classes].
public class ClassificationHead : Module<Tensor, Tensor>
{
    private readonly Module<Tensor, Tensor> linear1;
    private readonly Module<Tensor, Tensor> relu;
    private readonly Module<Tensor, Tensor> linear2;

    public ClassificationHead(int d_model, int hiddenSize, int num_classes) : base(nameof(ClassificationHead))
    {
        linear1 = Linear(d_model, hiddenSize);      // Intermediate layer
        relu = ReLU();                              // Activation
        linear2 = Linear(hiddenSize, num_classes);  // Output layer (raw logits, no softmax)

        RegisterComponents();
    }

    public override Tensor forward(Tensor x)
    {
        var output = linear1.forward(x);
        output = relu.forward(output);
        output = linear2.forward(output);
        return output;
    }
}
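The head's d_model argument has to match the width of the backbone's hidden states. A minimal sketch, assuming a Hugging Face-style config.json (the configName file from Step 1) that stores this width under the top-level hidden_size key, which is 4096 for Llama 3.1 8B. LlamaConfigSketch and ReadHiddenSize are hypothetical names; the code uses only System.Text.Json.

```csharp
using System.IO;
using System.Text.Json;

// Hypothetical helper: reads the backbone width from a Hugging Face-style config.json
// so that the classification head's d_model matches the model's hidden states.
public static class LlamaConfigSketch
{
    public static int ReadHiddenSize(string configPath)
    {
        using JsonDocument doc = JsonDocument.Parse(File.ReadAllText(configPath));
        // Assumes a top-level integer "hidden_size" key; GetProperty throws KeyNotFoundException otherwise.
        return doc.RootElement.GetProperty("hidden_size").GetInt32();
    }
}
```

Usage would be along the lines of `new ClassificationHead(LlamaConfigSketch.ReadHiddenSize(Path.Combine(weightFolder, configName)), 512, numClasses)`, where 512 and numClasses are placeholder values.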

Step 3
Integrate the classification head into the LLaMA model

This step is not currently possible. It requires overriding the forward pass so that the input is passed through the LLaMA model and then through the classification head to produce the output logits, but LlamaModel is internal.
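A rough sketch of what Step 3 could look like if the backbone were accessible. Because LlamaModel is internal, the backbone is modeled here as a hypothetical delegate returning the last hidden state [batch, seq, d_model]; LlamaForClassificationSketch is an illustrative name, not an ML.NET type. It pools the hidden state of the last non-padding token (assuming right padding and a 0/1 attention mask), similar to the last-token pooling used by Hugging Face's LlamaForSequenceClassification, and feeds it to the head.

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

// Hypothetical wrapper: backbone delegate (inputIds, attentionMask) -> last hidden state [batch, seq, dModel].
public class LlamaForClassificationSketch : Module<Tensor, Tensor, Tensor>
{
    private readonly Func<Tensor, Tensor, Tensor> backbone; // not a Module, so RegisterComponents ignores it
    private readonly Module<Tensor, Tensor> head;            // registered: its parameters are trainable

    public LlamaForClassificationSketch(Func<Tensor, Tensor, Tensor> backbone, Module<Tensor, Tensor> head)
        : base(nameof(LlamaForClassificationSketch))
    {
        this.backbone = backbone;
        this.head = head;
        RegisterComponents();
    }

    public override Tensor forward(Tensor inputIds, Tensor attentionMask)
    {
        Tensor hidden = backbone(inputIds, attentionMask);                // [batch, seq, dModel]
        // Assumes right padding: the last real token sits at (number of 1s in the mask) - 1.
        Tensor lastIdx = attentionMask.sum(1) - 1;                        // [batch]
        Tensor positions = arange(hidden.shape[1], device: hidden.device); // [seq]
        Tensor isLast = positions.unsqueeze(0).eq(lastIdx.unsqueeze(1)); // [batch, seq], true at the last token
        // Zero out every position except the last real token, then sum over the sequence axis.
        Tensor pooled = (hidden * isLast.unsqueeze(-1).to_type(hidden.dtype)).sum(1); // [batch, dModel]
        return head.forward(pooled);                                      // [batch, numClasses]
    }
}
```

Only the head is registered as a submodule, so in this sketch the backbone stays frozen unless its parameters are handed to the optimizer separately.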

Step 4
Set up the training loop to optimize the model using my data

I intend to use Microsoft.ML.GenAI.Core/Trainer/CasualLMSupervisedFineTuningTrainer
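By its name, CasualLMSupervisedFineTuningTrainer targets causal language modeling (next-token prediction), so it may not accept a classification objective directly; that is an assumption worth checking against the trainer's source. A plain TorchSharp loop is one alternative. The sketch below assumes the backbone and head are wrapped in a single Module<Tensor, Tensor> that maps a batch of inputs to logits [batch, num_classes] (attention mask omitted for brevity) and that labels are int64 class indices; ClassificationTrainingSketch and Train are hypothetical names.

```csharp
using System.Collections.Generic;
using TorchSharp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

public static class ClassificationTrainingSketch
{
    // Hypothetical loop: `model` maps inputs to class logits [batch, numClasses];
    // `labels` holds int64 class indices [batch]. Returns the loss of the last step.
    public static float Train(Module<Tensor, Tensor> model,
                              IEnumerable<(Tensor inputs, Tensor labels)> batches,
                              int epochs, double lr)
    {
        // Only parameters registered on `model` are optimized (e.g. just the head if the backbone is frozen).
        var optimizer = torch.optim.AdamW(model.parameters(), lr);
        var lossFn = CrossEntropyLoss(); // applies log-softmax internally, so the head outputs raw logits
        float lastLoss = float.NaN;

        model.train();
        for (int epoch = 0; epoch < epochs; epoch++)
        {
            foreach (var (inputs, labels) in batches)
            {
                // Dispose this step's intermediate tensors when the scope ends.
                using var scope = torch.NewDisposeScope();
                optimizer.zero_grad();
                Tensor logits = model.forward(inputs);
                Tensor loss = lossFn.forward(logits, labels);
                loss.backward();
                optimizer.step();
                lastLoss = loss.item<float>();
            }
        }
        return lastLoss;
    }
}
```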

Metadata

Labels: enhancement (New feature or request)
Type: No type
Projects: No projects
Relationships: None yet
Development: No branches or pull requests