Commit
Merge branch 'main' of github.com:stanfordnlp/pyvene into main
PinetreePantry committed Mar 3, 2024
2 parents c63d976 + 35a3433 commit 431974c
Showing 17 changed files with 4,748 additions and 534 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -201,7 +201,7 @@ The function solves a 3-digit sum problem. Let's say, we trained a neural network

Translating the above steps into API calls with the library takes just a single call,
```py
-intervenable.evaluate(
+intervenable.eval_alignment(
    eval_dataloader=test_dataloader,
    compute_metrics=compute_metrics,
    inputs_collator=inputs_collator
@@ -232,7 +232,7 @@ Instead of activation swapping in the original representation space, we first **

You can now also make a single API call to train your intervention,
```py
-intervenable.train(
+intervenable.train_alignment(
    train_dataloader=train_dataloader,
    compute_loss=compute_loss,
    compute_metrics=compute_metrics,
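For orientation, here is a minimal sketch of how the renamed pair fits together end to end; the callback bodies below are assumptions inferred from the argument names at the call sites, not the library's documented API.

```py
import torch

# Assumed callback shapes, inferred from the call sites above.
def compute_loss(outputs, labels):
    # e.g., cross-entropy between intervened last-token logits and labels
    return torch.nn.functional.cross_entropy(outputs.logits[:, -1], labels)

def compute_metrics(eval_preds, eval_labels):
    return {"accuracy": (eval_preds == eval_labels).float().mean().item()}

intervenable.train_alignment(
    train_dataloader=train_dataloader,
    compute_loss=compute_loss,
    compute_metrics=compute_metrics,
    inputs_collator=inputs_collator,
)
result = intervenable.eval_alignment(
    eval_dataloader=test_dataloader,
    compute_metrics=compute_metrics,
    inputs_collator=inputs_collator,
)
```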
20 changes: 20 additions & 0 deletions pyvene/models/gpt2/modelings_intervenable_gpt2.py
@@ -64,6 +64,13 @@

gpt2_lm_type_to_dimension_mapping = gpt2_type_to_dimension_mapping

"""gpt2 model with classifier head"""
gpt2_classifier_type_to_module_mapping = {}
for k, v in gpt2_type_to_module_mapping.items():
gpt2_classifier_type_to_module_mapping[k] = (f"transformer.{v[0]}", ) + v[1:]

gpt2_classifier_type_to_dimension_mapping = gpt2_type_to_dimension_mapping
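The prefixing loop works because GPT2ForSequenceClassification stores the base GPT2Model under its transformer attribute, so every module path from the plain-GPT2 mapping gains that prefix. A sketch with a hypothetical entry (the real keys and hook constants live in gpt2_type_to_module_mapping):

```py
# Hypothetical single-entry mapping, for illustration only.
gpt2_type_to_module_mapping = {
    "block_output": ("h[%s]", "CONST_OUTPUT_HOOK"),
}
gpt2_classifier_type_to_module_mapping = {
    k: (f"transformer.{v[0]}",) + v[1:]
    for k, v in gpt2_type_to_module_mapping.items()
}
print(gpt2_classifier_type_to_module_mapping)
# {'block_output': ('transformer.h[%s]', 'CONST_OUTPUT_HOOK')}
```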


def create_gpt2(name="gpt2", cache_dir=None):
    """Creates a GPT2 model, config, and tokenizer from the given name and revision"""
@@ -88,3 +95,16 @@ def create_gpt2_lm(name="gpt2", config=None, cache_dir=None):
        gpt = GPT2LMHeadModel(config=config)
    print("loaded model")
    return config, tokenizer, gpt

+def create_gpt2_classifier(name="gpt2", config=None, cache_dir=None):
+    """Creates a GPT2ForSequenceClassification, config, and tokenizer from the given name and revision"""
+    from transformers import GPT2ForSequenceClassification, GPT2Tokenizer, GPT2Config
+
+    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+    if config is None:
+        config = GPT2Config.from_pretrained(name)
+        gpt = GPT2ForSequenceClassification.from_pretrained(name, config=config, cache_dir=cache_dir)
+    else:
+        gpt = GPT2ForSequenceClassification(config=config)
+    print("loaded model")
+    return config, tokenizer, gpt
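A usage sketch for the new factory; the num_labels override is an assumption to show how a custom config flows into the else branch (a fresh, randomly initialized head):

```py
from transformers import GPT2Config

# Assumed customization: a binary classification head.
config = GPT2Config.from_pretrained("gpt2", num_labels=2)
config, tokenizer, gpt = create_gpt2_classifier(config=config)
```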
40 changes: 30 additions & 10 deletions pyvene/models/intervenable_base.py
@@ -300,6 +300,16 @@ def get_trainable_parameters(self):
            if isinstance(v[0], TrainableIntervention):
                ret_params += [p for p in v[0].parameters()]
        return ret_params

+    def named_parameters(self, recurse=True):
+        """
+        Like get_trainable_parameters, but yields (name, parameter) pairs
+        for compatibility with HuggingFace-style training utilities.
+        """
+        ret_params = []
+        for k, v in self.interventions.items():
+            if isinstance(v[0], TrainableIntervention):
+                ret_params += [(k + '.' + n, p) for n, p in v[0].named_parameters()]
+        return ret_params

    def get_cached_activations(self):
        """
@@ -1247,7 +1257,7 @@ def forward(
        unit_locations: Optional[Dict] = None,
        source_representations: Optional[Dict] = None,
        subspaces: Optional[List] = None,
-        output_original_output: Optional[bool] = None,
+        output_original_output: Optional[bool] = False,
        return_dict: Optional[bool] = None,
    ):
"""
@@ -1340,9 +1350,11 @@ def forward(
            activations_sources,
            subspaces,
        )

-        # returning un-intervened output with gradients
-        base_outputs = self.model(**base)
+        base_outputs = None
+        if output_original_output:
+            # returning un-intervened output with gradients
+            base_outputs = self.model(**base)

        try:
            # intervene
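The practical effect on callers: forward now returns (None, intervened_outputs) unless the flag is set. A sketch of both call patterns, reusing the pv_gpt2 naming from the notebook changes below:

```py
# Default: no extra clean pass; the first tuple slot is None.
_, intervened = pv_gpt2(base=inputs, unit_locations={"base": 3})

# Opt in to the un-intervened baseline (costs one additional forward pass).
original, intervened = pv_gpt2(
    base=inputs, unit_locations={"base": 3}, output_original_output=True
)
```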
@@ -1416,6 +1428,7 @@ def generate(
        source_representations: Optional[Dict] = None,
        intervene_on_prompt: bool = False,
        subspaces: Optional[List] = None,
+        output_original_output: Optional[bool] = False,
        **kwargs,
    ):
"""
@@ -1471,9 +1484,10 @@
            activations_sources,
            subspaces,
        )

-        # returning un-intervened output without gradients
-        with torch.inference_mode():
-            base_outputs = self.model.generate(inputs=base["input_ids"], **kwargs)
+        base_outputs = None
+        if output_original_output:
+            # returning un-intervened output
+            base_outputs = self.model.generate(inputs=base["input_ids"], **kwargs)
        set_handlers_to_remove = None
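generate mirrors forward: the clean continuation is computed only on request. A sketch in the spirit of the notebook cell updated below:

```py
# First element is the baseline story, second the intervened one.
unintervened_story, intervened_story = pv_tinystory.generate(
    prompt,
    source_representations=emb_happy * 0.3,
    max_length=256,
    output_original_output=True,
)
```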
@@ -1636,8 +1650,14 @@ def _batch_process_unit_location(self, inputs):
        )

        return batched_location_dict

+    def train(self):
+        self.model.train()
+
+    def eval(self):
+        self.model.eval()
+
-    def train(
+    def train_alignment(
        self,
        train_dataloader,
        compute_loss,
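The rename also frees train/eval to act as standard mode switches via the passthroughs added above; a minimal sketch:

```py
intervenable.train()  # enable dropout etc. in the wrapped model
# ... run train_alignment or a custom loop here ...
intervenable.eval()   # back to deterministic inference
```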
@@ -1728,7 +1748,7 @@ def train(
                    self.set_temperature(temperature_schedule[total_step])
                total_step += 1

-    def evaluate(
+    def eval_alignment(
        self,
        eval_dataloader,
        compute_metrics,
@@ -1763,4 +1783,4 @@ def evaluate(
                all_num_examples += [b_s]
        result = weighted_average(all_metrics, all_num_examples)

-        return result
\ No newline at end of file
+        return result
2 changes: 2 additions & 0 deletions pyvene/models/intervenable_modelcard.py
@@ -35,6 +35,7 @@
type_to_module_mapping = {
    hf_models.gpt2.modeling_gpt2.GPT2Model: gpt2_type_to_module_mapping,
    hf_models.gpt2.modeling_gpt2.GPT2LMHeadModel: gpt2_lm_type_to_module_mapping,
+    hf_models.gpt2.modeling_gpt2.GPT2ForSequenceClassification: gpt2_classifier_type_to_module_mapping,
    hf_models.llama.modeling_llama.LlamaModel: llama_type_to_module_mapping,
    hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_module_mapping,
    hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_module_mapping,
@@ -58,6 +59,7 @@
type_to_dimension_mapping = {
    hf_models.gpt2.modeling_gpt2.GPT2Model: gpt2_type_to_dimension_mapping,
    hf_models.gpt2.modeling_gpt2.GPT2LMHeadModel: gpt2_lm_type_to_dimension_mapping,
+    hf_models.gpt2.modeling_gpt2.GPT2ForSequenceClassification: gpt2_classifier_type_to_dimension_mapping,
    hf_models.llama.modeling_llama.LlamaModel: llama_type_to_dimension_mapping,
    hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_dimension_mapping,
    hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_dimension_mapping,
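Registering the classifier in both tables is presumably what lets the library resolve hook paths and dimensions from a model instance alone; a hedged sketch of that lookup (the helper name is hypothetical):

```py
# Hypothetical dispatch on the wrapped model's concrete class.
def resolve_mappings(model):
    return (
        type_to_module_mapping[type(model)],
        type_to_dimension_mapping[type(model)],
    )

module_map, dim_map = resolve_mappings(gpt)  # gpt from create_gpt2_classifier
```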
4 changes: 2 additions & 2 deletions pyvene/models/llama/modelings_intervenable_llama.py
@@ -61,7 +61,7 @@


def create_llama(
name="sharpbai/alpaca-7b-merged", cache_dir=None
name="sharpbai/alpaca-7b-merged", cache_dir=None, dtype=torch.bfloat16
):
"""Creates a LLaMA Causal LM model, config, and tokenizer from the given name and revision"""
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
@@ -72,7 +72,7 @@ def create_llama(
        name,
        config=config,
        cache_dir=cache_dir,
-        torch_dtype=torch.bfloat16, # save memory
+        torch_dtype=dtype, # save memory
    )
print("loaded model")
return config, tokenizer, llama
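Since dtype just threads through to from_pretrained, callers can now trade precision for memory; a sketch:

```py
import torch

# Assumed usage: load the checkpoint in float16 instead of the bfloat16 default.
config, tokenizer, llama = create_llama(dtype=torch.float16)
```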
25 changes: 7 additions & 18 deletions pyvene_101.ipynb
@@ -316,7 +316,8 @@
"intervened_outputs = pv_gpt2(\n",
" base = tokenizer(\"The capital of Spain is\", return_tensors=\"pt\"), \n",
" # we define the intervening token dynamically\n",
" unit_locations={\"base\": 3}\n",
" unit_locations={\"base\": 3},\n",
" output_original_output=True # False then the first element in the tuple is None\n",
")"
]
},
@@ -1067,14 +1068,16 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"id": "f718e2d6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/sailhome/wuzhengx/.local/lib/python3.8/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
" return self.fget.__get__(instance, owner)()\n",
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
@@ -1083,21 +1086,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"loaded model\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loaded model\n",
"Once upon a time there was a little girl named Lucy. She was three years old and loved to explore. One day, Lucy was walking in the park when she saw a big, red balloon. She was so excited and wanted to play with it.\n",
"\n",
"But then, a big, mean man came and said, \"That balloon is mine! You can't have it!\" Lucy was very sad and started to cry.\n",
@@ -1138,7 +1127,7 @@
"# prompt and generate\n",
"prompt = tokenizer(\n",
" \"Once upon a time there was\", return_tensors=\"pt\")\n",
"_, intervened_story = pv_tinystory.generate(\n",
"unintervened_story, intervened_story = pv_tinystory.generate(\n",
" prompt, source_representations=emb_happy*0.3, max_length=256\n",
")\n",
"\n",
tests/integration_tests/InterventionWithGPT2TestCase.py
@@ -62,7 +62,7 @@ def test_clean_run_positive(self):
        intervenable.set_device(self.device)
        base = {"input_ids": torch.randint(0, 10, (10, 5)).to(self.device)}
        golden_out = self.gpt2(**base).logits
-        our_output = intervenable(base)[0][0]
+        our_output = intervenable(base, output_original_output=True)[0][0]
        self.assertTrue(torch.allclose(golden_out, our_output))
        # make sure the toolkit also works
        self.assertTrue(
4 changes: 3 additions & 1 deletion tests/integration_tests/InterventionWithMLPTestCase.py
@@ -112,7 +112,8 @@ def test_clean_run_positive(self):
        )
        base = {"inputs_embeds": torch.rand(10, 1, 3)}
        self.assertTrue(
-            torch.allclose(ONE_MLP_CLEAN_RUN(base, self.mlp), intervenable(base)[0][0])
+            torch.allclose(ONE_MLP_CLEAN_RUN(base, self.mlp), intervenable(
+                base, output_original_output=True)[0][0])
        )

    def test_with_subspace_positive(self):
@@ -255,6 +256,7 @@ def test_no_intervention_link_positive(self):
            [source_1, source_2],
            {"sources->base": ([[[0]] * b_s, [[0]] * b_s], [[[0]] * b_s, [[0]] * b_s])},
            subspaces=[[[0]] * b_s, [[1]] * b_s],
+            output_original_output=True,
        )

        self.assertTrue(torch.allclose(golden_out_inplace, our_out_inplace[0]))