[Research] Steering vectors #1472


Draft · wants to merge 9 commits into master
35 changes: 35 additions & 0 deletions examples/common.cpp
@@ -338,6 +338,36 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.input_suffix = argv[i];
} else if (arg == "--steering-add") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.steering_add = argv[i];
} else if (arg == "--steering-sub") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.steering_sub = argv[i];
} else if (arg == "--steering-mul") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.steering_mul = std::stof(argv[i]);
} else if (arg == "--steering-source") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.steering_source = std::stoi(argv[i]);
} else if (arg == "--steering-layer") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.steering_layer = std::stoi(argv[i]);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, default_params);
@@ -423,6 +453,11 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
}
fprintf(stderr, " -ngl N, --n-gpu-layers N\n");
fprintf(stderr, " number of layers to store in VRAM\n");
fprintf(stderr, " --steering-add add positive steering prompt\n");
fprintf(stderr, " --steering-sub add negative steering prompt\n");
fprintf(stderr, " --steering-mul steering strength (negative is reverse, default %.1f)\n", params.steering_mul);
fprintf(stderr, " --steering-source layer for steering source (default %d)\n", params.steering_source);
fprintf(stderr, " --steering-layer layer for steering insertion (default %d)\n", params.steering_layer);
fprintf(stderr, " --mtest compute maximum memory usage\n");
fprintf(stderr, " --verbose-prompt print prompt before generation\n");
fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
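These five flags drive the feature end to end: --steering-add and --steering-sub supply the paired prompts whose activation difference becomes the steering vector, --steering-source selects the layer the activations are captured from, --steering-layer selects where the vector is added back in, and --steering-mul scales it (a negative value steers away from the "add" prompt). A hypothetical invocation, with model path and prompt strings purely illustrative rather than taken from this PR, might be:

./main -m models/7B/ggml-model-q4_0.bin -p "I went up to my friend and said" --steering-add "Love" --steering-sub "Hate" --steering-mul 5.0 --steering-source 2 --steering-layer 15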
6 changes: 6 additions & 0 deletions examples/common.h
@@ -71,6 +71,12 @@ struct gpt_params {
bool use_mlock = false; // use mlock to keep model in memory
bool mem_test = false; // compute maximum memory usage
bool verbose_prompt = false; // print prompt tokens before generation

std::string steering_add; // positive steering prompt
std::string steering_sub; // negative steering prompt
float steering_mul = 1.0f; // steering strength (negative reverses direction)
int steering_layer = 15; // layer where the steering vector is injected
int steering_source = 2; // layer whose activations are captured
};

bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
32 changes: 32 additions & 0 deletions examples/main/main.cpp
@@ -173,6 +173,36 @@ int main(int argc, char ** argv) {
return 1;
}

if (!params.steering_add.empty() || !params.steering_sub.empty())
{
fprintf(stderr, "%s: steering: ('%s' - '%s') * %f\n",
__func__, params.steering_add.c_str(), params.steering_sub.c_str(), params.steering_mul);

// match the tokenizer convention used for the main prompt: a leading space
params.steering_add.insert(0, 1, ' ');
params.steering_sub.insert(0, 1, ' ');

auto add_tokens = ::llama_tokenize(ctx, params.steering_add, true);
auto sub_tokens = ::llama_tokenize(ctx, params.steering_sub, true);

// pad the shorter prompt with newline tokens so the two prompts
// cover the same token positions
if (add_tokens.size() != sub_tokens.size()) {
while (add_tokens.size() < sub_tokens.size()) {
add_tokens.push_back(llama_token_nl());
}
while (sub_tokens.size() < add_tokens.size()) {
sub_tokens.push_back(llama_token_nl());
}
}

// capture pass 1: accumulate activations of the positive prompt
llama_set_steering_write(ctx, params.steering_source, +1.0f);
llama_eval(ctx, add_tokens.data(), std::min((int)add_tokens.size(), n_ctx), 0, params.n_threads);

// capture pass 2: subtract activations of the negative prompt
llama_set_steering_write(ctx, params.steering_source, -1.0f);
llama_eval(ctx, sub_tokens.data(), std::min((int)sub_tokens.size(), n_ctx), 0, params.n_threads);

// from here on, every eval adds steering_mul * (add - sub) at the chosen layer
llama_set_steering_read(ctx, params.steering_layer, params.steering_mul);
}

// debug message about similarity of saved session, if applicable
size_t n_matching_session_tokens = 0;
if (session_tokens.size()) {
@@ -399,6 +429,8 @@ int main(int argc, char ** argv) {
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
}

//llama_set_steering_off(ctx);

llama_token id = 0;

{
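Two details of the driver are easy to miss. First, the steering mode is a latch on the context: llama_set_steering_read is called once before the main loop, and every subsequent llama_eval keeps injecting the vector until the mode is changed, so nothing steering-related needs to happen inside the generation loop. Second, the commented-out llama_set_steering_off call above appears to mark the spot where steering could be disabled after prompt processing, for experiments that steer the prompt but not the sampled continuation.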
54 changes: 54 additions & 0 deletions llama.cpp
@@ -32,6 +32,7 @@
#include <mutex>
#include <sstream>
#include <numeric>
#include <iostream>

#define LLAMA_USE_SCRATCH
#define LLAMA_MAX_SCRATCH_BUFFERS 16
@@ -229,6 +230,15 @@ struct llama_context {
// input embedding (1-dimensional array: [n_embd])
std::vector<float> embedding;

std::vector<float> steering_vector; // [n_ctx, n_embd]
int steering_layer = 0; // source layer (write mode) or insertion layer (read mode)
int steering_mode = 0; // STEERING_OFF / STEERING_WRITE / STEERING_READ
float steering_mul = 0.0f; // accumulation sign (write) or strength (read)

#define STEERING_OFF 0
#define STEERING_WRITE 2
#define STEERING_READ 3

// memory buffers used to evaluate the model
// TODO: move in llama_state
llama_ctx_buffer buf_compute;
@@ -269,6 +279,24 @@ struct llama_context {
}
};

void llama_set_steering_off(struct llama_context * ctx) {
ctx->steering_mode = STEERING_OFF;
}

void llama_set_steering_write(struct llama_context * ctx, int layer, float mul) {
ctx->steering_mode = STEERING_WRITE;
ctx->steering_mul = mul;
ctx->steering_layer = layer;
}

void llama_set_steering_read(struct llama_context * ctx, int layer, float mul) {
ctx->steering_mode = STEERING_READ;
ctx->steering_mul = mul;
ctx->steering_layer = layer;
//FILE* steeringbin = fopen("steering.bin", "wb");
//fwrite(ctx->steering_vector.data(), sizeof(float), ctx->steering_vector.size(), steeringbin);
//fclose(steeringbin);
}

template <typename T>
static T checked_mul(T a, T b) {
T ret = a * b;
@@ -1152,6 +1180,13 @@ static bool llama_eval_internal(
ggml_set_name(embd, "embd");
memcpy(embd->data, tokens, N*ggml_element_size(embd));

struct ggml_tensor * steer = NULL;
if (lctx.steering_mode != STEERING_OFF) {
// stage the slice of the steering vector covering positions [n_past, n_past + N)
steer = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
memcpy(steer->data, lctx.steering_vector.data() + n_past * n_embd, ggml_nbytes(steer));
}

struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);

for (int il = 0; il < n_layer; ++il) {
@@ -1161,6 +1196,17 @@

lctx.use_buf(ctx0, 0);

if (lctx.steering_mode != STEERING_OFF && il == lctx.steering_layer) {
struct ggml_tensor * scal = ggml_new_f32(ctx0, lctx.steering_mul);
if (lctx.steering_mode == STEERING_WRITE) {
ggml_build_forward_expand(&gf, ggml_cpy(ctx0,
ggml_add(ctx0, ggml_scale(ctx0, inpL, scal), steer), steer));
break;
}
// std::cout << "\nAdding steering vector to inpL " << il << "\n";
inpSA = ggml_add(ctx0, ggml_scale(ctx0, steer, scal), inpSA);
}

// norm
{
cur = ggml_rms_norm(ctx0, inpL);
@@ -1374,6 +1420,12 @@ static bool llama_eval_internal(
memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
}


if (lctx.steering_mode == STEERING_WRITE) {
// persist the updated slice so the next pass (the second write, or reads)
// starts from the accumulated vector
memcpy(lctx.steering_vector.data() + n_past * n_embd, steer->data, ggml_nbytes(steer));
}

if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
}
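Mechanically, every context position t keeps a running vector s_t inside lctx.steering_vector. A write pass with multiplier m performs s_t = s_t + m * h_t, where h_t is the residual-stream activation (inpL) entering the steering layer; the break skips all later layers, since a capture pass only needs activations up to the source layer. One write with m = +1 over the positive prompt followed by one with m = -1 over the negative prompt therefore leaves s_t = h_t(add) - h_t(sub). A read pass then adds steering_mul * s_t back into the residual stream (inpSA) at the insertion layer, which is essentially the activation-addition ("ActAdd") construction from the steering-vector literature, implemented as two extra nodes in the ggml graph.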
@@ -2195,6 +2247,8 @@ struct llama_context * llama_init_from_file(

ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0().at(ctx->model.type));
ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));

ctx->steering_vector.resize(hparams.n_ctx * hparams.n_embd); // one n_embd slot per context position, zero-initialized
}

return ctx;
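One sizing note: the steering buffer is allocated unconditionally and holds n_ctx * n_embd floats. With the default n_ctx = 512 and n_embd = 4096 for a 7B model, that is 512 * 4096 * 4 bytes = 8 MiB, small next to the KV cache but growing linearly with context length.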
4 changes: 4 additions & 0 deletions llama.h
@@ -190,6 +190,10 @@ extern "C" {
LLAMA_API llama_token llama_token_eos();
LLAMA_API llama_token llama_token_nl();

// Steering control. Write mode accumulates mul * residual activations of the
// next eval(s) into an internal buffer at the given layer; read mode adds
// mul * buffer back into the residual stream at the given layer; off disables both.
LLAMA_API void llama_set_steering_off(struct llama_context * ctx);
LLAMA_API void llama_set_steering_write(struct llama_context * ctx, int layer, float mul);
LLAMA_API void llama_set_steering_read(struct llama_context * ctx, int layer, float mul);

// Sampling functions

/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
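For callers outside examples/main, here is a minimal sketch of driving the three new entry points. It is illustrative only: it assumes a llama_context created via llama_init_from_file, the ::llama_tokenize helper from examples/common.h, and made-up prompts, layer indices, and strength.

    // Sketch: capture (add - sub) activations at layer 2, then steer at layer 15.
    // Assumes ctx is a valid llama_context * and 4 threads are available.
    std::vector<llama_token> add_toks = ::llama_tokenize(ctx, " Love", true);
    std::vector<llama_token> sub_toks = ::llama_tokenize(ctx, " Hate", true);

    // pad the shorter prompt with newline tokens so positions line up
    while (add_toks.size() < sub_toks.size()) add_toks.push_back(llama_token_nl());
    while (sub_toks.size() < add_toks.size()) sub_toks.push_back(llama_token_nl());

    llama_set_steering_write(ctx, 2, +1.0f);   // accumulate positive-prompt activations
    llama_eval(ctx, add_toks.data(), (int) add_toks.size(), 0, 4);

    llama_set_steering_write(ctx, 2, -1.0f);   // subtract negative-prompt activations
    llama_eval(ctx, sub_toks.data(), (int) sub_toks.size(), 0, 4);

    llama_set_steering_read(ctx, 15, 5.0f);    // steer all later evals
    // ...evaluate the user prompt and run the usual sampling loop from here...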