This project uses REFT (Representation Fine-Tuning) to improve the Polish language capabilities of Meta's Llama-3.1-8B-Instruct model.
REFT (Representation Fine-Tuning) is a parameter-efficient fine-tuning method that applies interventions to the hidden representations at specific layers of a large language model to steer its behavior. Unlike traditional fine-tuning or PEFT (Parameter-Efficient Fine-Tuning) methods such as LoRA, which modify weights and therefore affect every token position, REFT:
- Intervenes only at selected token positions
- Targets representations rather than weights
- Requires very few trainable parameters
- Can be trained with a small number of examples
This makes REFT ideal for enhancing language models for specific languages or tasks with minimal computational resources.
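Concretely, the LoReFT intervention used in this project edits a hidden state h as h' = h + Rᵀ(Wh + b − Rh): R projects h into a rank-r subspace, and the edit steers that subspace toward learned target values while leaving the rest of the representation untouched. Below is a minimal PyTorch sketch of the idea, for illustration only; the project uses pyreft's pyreft.LoreftIntervention as the actual implementation.

import torch

class LoreftSketch(torch.nn.Module):
    """Illustrative (non-pyreft) version of the LoReFT edit:
        h' = h + R^T (W h + b - R h)
    R spans a rank-r subspace of the hidden state; W and b define the
    values the representation is steered toward inside that subspace."""

    def __init__(self, embed_dim: int, rank: int):
        super().__init__()
        # In pyreft, R is kept (near-)orthonormal during training
        self.R = torch.nn.Parameter(torch.empty(rank, embed_dim))
        torch.nn.init.orthogonal_(self.R)
        self.proj = torch.nn.Linear(embed_dim, rank)  # W and b

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: (..., embed_dim) hidden states at the intervened token positions
        delta = self.proj(h) - h @ self.R.T  # (..., rank) low-rank edit
        return h + delta @ self.R            # map the edit back to embed_dim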
- prepare_polish_data.py: Creates a dataset of Polish instruction-response pairs
- train_polish_reft.py: Trains Llama-3.1-8B-Instruct with REFT for Polish language improvement
- evaluate_polish_reft.py: Compares the original and REFT-enhanced models on Polish tasks
- polish_data/: Directory containing the Polish training examples
- polish_reft_model/: Directory containing the saved REFT model after training
- evaluation_results/: Directory containing evaluation results
1. Data Preparation: We create a dataset of diverse Polish instruction-response pairs covering grammar, vocabulary, cultural content, and more.
2. REFT Configuration: We configure two REFT interventions at different layers (see the configuration sketch after this list):
   - Layer 15: For improving Polish language understanding
   - Layer 25: For improving Polish language generation
3. Training: We train the REFT interventions on our Polish examples, which requires only a fraction of the parameters of the full model.
4. Evaluation: We compare the original and REFT-enhanced models on various Polish language tasks.
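In pyreft, this two-intervention setup corresponds to a configuration like the following sketch, which mirrors the settings saved with the trained model (rank-8 LoReFT interventions on the block outputs of layers 15 and 25):

import transformers
import pyreft

# Base model whose layer-15 and layer-25 block outputs are intervened on
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", device_map="auto"
)

reft_config = pyreft.ReftConfig(
    representations=[
        {
            "layer": layer,                # 15: understanding, 25: generation
            "component": "block_output",   # the residual stream after the block
            "low_rank_dimension": 8,
            "intervention": pyreft.LoreftIntervention(
                embed_dim=model.config.hidden_size,  # 4096 for Llama-3.1-8B
                low_rank_dimension=8,
            ),
        }
        for layer in (15, 25)
    ]
)
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.print_trainable_parameters()  # a tiny fraction of the 8B weights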
- Python 3.8+
- PyTorch
- Transformers
- REFT (pyreft)
- CUDA-capable GPU (recommended)
Install the required packages:
pip install pyreft transformers torch huggingface-hub datasets accelerate sentencepiece
python prepare_polish_data.py
This creates a dataset of Polish examples in polish_data/polish_examples.json.
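The exact schema is defined by prepare_polish_data.py; as an illustration only (the field names here are hypothetical), an entry might look like:

# Hypothetical shape of an entry in polish_data/polish_examples.json
example = {
    "instruction": "Przetłumacz na polski: 'The weather is beautiful today.'",
    "response": "Dzisiaj jest piękna pogoda.",
}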
python train_polish_reft.py
This trains the REFT model and saves it to polish_reft_model/.
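Internally, training with pyreft follows the library's standard recipe. The sketch below is a condensed, illustrative version continuing from the configuration sketch above (model, reft_config, reft_model); the hyperparameters are assumptions, train_polish_reft.py is the authoritative script, and a two-intervention setup may need per-intervention location handling in the data module.

import json
import transformers
import pyreft

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", model_max_length=2048, padding_side="right"
)
tokenizer.pad_token = tokenizer.eos_token  # needed for padded training batches

with open("polish_data/polish_examples.json", encoding="utf-8") as f:
    polish_examples = json.load(f)

# pyreft's helper packs (prompt, response) pairs and marks the last prompt
# token as the intervention position
data_module = pyreft.make_last_position_supervised_data_module(
    tokenizer, model,
    [ex["instruction"] for ex in polish_examples],
    [ex["response"] for ex in polish_examples],
)

# Illustrative hyperparameters only
training_args = transformers.TrainingArguments(
    num_train_epochs=50,
    per_device_train_batch_size=4,
    learning_rate=4e-3,
    logging_steps=10,
    output_dir="./polish_reft_model",
)
trainer = pyreft.ReftTrainerForCausalLM(
    model=reft_model, tokenizer=tokenizer, args=training_args, **data_module
)
trainer.train()
reft_model.save(save_directory="./polish_reft_model")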
python evaluate_polish_reft.py
This compares the original and REFT-enhanced models on Polish tasks and saves the results to evaluation_results/comparison_results.json.
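Conceptually, the comparison reduces to generating from both models on the same prompts and saving the pairs. In the sketch below, generate_base and generate_reft are hypothetical helpers wrapping the two generation paths (plain model.generate versus reft_model.generate, as shown in the next section):

import json

prompts = [
    "Opisz tradycje Bożego Narodzenia w Polsce.",
    "Wyjaśnij odmianę rzeczownika 'dom' przez przypadki.",
]

# generate_base / generate_reft are hypothetical helpers around plain
# model.generate and reft_model.generate respectively (see next section)
results = [
    {"prompt": p, "base": generate_base(p), "reft": generate_reft(p)}
    for p in prompts
]

with open("evaluation_results/comparison_results.json", "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)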
The REFT-enhanced model can be loaded and used for Polish language tasks as follows:
import torch
import transformers
import pyreft
# Load the base model (bfloat16 roughly halves memory versus float32 and is
# the native precision of Llama 3.1)
model = transformers.AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B-Instruct",
torch_dtype=torch.bfloat16,
device_map="auto"
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
"meta-llama/Llama-3.1-8B-Instruct",
model_max_length=2048,
padding_side="right"
)
# Load the trained REFT model. pyreft.ReftModel.load is a static method that
# takes the save directory and the base model; it restores the intervention
# configuration saved at training time (the two rank-8 LoReFT interventions
# on layers 15 and 25), so the config does not need to be rebuilt by hand.
reft_model = pyreft.ReftModel.load("./polish_reft_model", model)
reft_model.set_device("cuda")
# Use the model
prompt = "Jak opisać polską kulturę w kilku zdaniach?"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
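# Intervene at the last prompt token (must match the position used in training)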
base_unit_location = inputs["input_ids"].shape[-1] - 1
_, reft_response = reft_model.generate(
inputs,
# one location list per trained intervention (layers 15 and 25)
unit_locations={"sources->base": (None, [[[base_unit_location]]] * 2)},
intervene_on_prompt=True,
max_new_tokens=512,
do_sample=True,
temperature=0.7
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))
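Note that Llama-3.1-8B-Instruct is a chat-tuned model: if the training prompts were wrapped in its chat template, the same formatting should be applied at inference. For example:

# Wrap the raw prompt in the model's chat template (only do this if the
# training data was formatted the same way)
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Jak opisać polską kulturę w kilku zdaniach?"}],
    tokenize=False,
    add_generation_prompt=True,
)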
- Resource Efficiency: Trains only the intervention parameters (~32K, versus the full model's 8 billion weights)
- Data Efficiency: Requires only a small number of examples (10-20 can be sufficient)
- Targeted Improvement: Specifically enhances Polish language understanding and generation
- Preservation of Original Capabilities: The base weights are left untouched, so the model's capabilities in other languages are maintained
- The enhancements are focused on Polish language capabilities only
- Performance depends on the quality and diversity of the training examples
- The REFT approach might not be as comprehensive as full fine-tuning with large datasets
This project uses the same license as the base Llama 3.1 model. Please refer to Meta's licensing terms for the Llama model.