MSNetrom/in-context-poly-playground

Code for the paper "Polynomial Regression as a Task for Understanding In-context Learning Through Finetuning and Alignment"

Polynomial Regression Reproduction Instructions

A work in progress. See Quickstart for conda setup.

To train the base model, LoRA, and soft-prompting variants, run: python3 src/train_finetuning.py. Alternatively, download checkpoints from here.

To evaluate results, use the eval_finetuning.ipynb notebook.

To visualize the Chebyshev linear regression function class, use the visualize_chebyshev_functionclass.ipynb notebook.

Quickstart

Set up your environment with:

conda init zsh
conda env create -f environment.yaml
conda activate in-context-learning

Alternatively, you can create a codespace (docs), which performs the setup above automatically.

Run a training run specified by <config_file> with:

python src/ --config conf/train/<config_file>.yml

Resume a training run specified by <config_file> starting from <run_id>/<checkpoint_file> with:

python src/ --config conf/train/<config_file>.yml --resume models/<run_id>/<checkpoint_file>

Extending this repo

Objects in this repository follow this flow: [image: flow map]

Initializing a FunctionClass or ContextModel passes the dictionary parsed directly from YAML as arguments. We recommend reading the examples for your respective function class in conf/.

Make sure to initialize your parent class with super().__init__(...) in your __init__ routine for your custom FunctionClass or ContextModel. Omitting this can result in unexpected (and currently unchecked) errors!
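As a rough illustration of the two points above, the sketch below uses a stand-in base class (BaseContextModel is hypothetical; the repository's real ContextModel may look different) and assumes the parsed YAML dictionary is unpacked as keyword arguments. The configuration keys are also illustrative. It shows how skipping super().__init__(...) can silently leave parent-side state unset.

```python
# Hypothetical stand-in for the repository's base class; the real one may differ.
class BaseContextModel:
    def __init__(self, context_length: int, **kwargs):
        # State that the rest of the pipeline is assumed to rely on.
        self.context_length = context_length


class GoodModel(BaseContextModel):
    def __init__(self, hidden_size: int, **kwargs):
        super().__init__(**kwargs)  # forwards context_length to the parent
        self.hidden_size = hidden_size


class ForgetfulModel(BaseContextModel):
    def __init__(self, hidden_size: int, **kwargs):
        # Missing super().__init__(...): parent attributes are never set.
        self.hidden_size = hidden_size


# A dictionary like the one a YAML parser would return (keys are illustrative).
config = {"context_length": 32, "hidden_size": 64}

good = GoodModel(**config)
bad = ForgetfulModel(**config)

print(good.context_length)             # 32
print(hasattr(bad, "context_length"))  # False
```

Construction of the forgetful class succeeds, so the problem would only surface later, when something reads the missing attribute.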

We try to make it easy to contribute by adding models or function classes:


Adding Models

We wrap all sequence architectures and baselines as ContextModels. The order of calls against a ContextModel is as follows:

  • We __init__ the ContextModel directly from the yaml configuration
  • Inside the training loop and during evaluation, we call evaluate to generate predictions

ContextModels thus have only two required methods:

  • __init__(...)
  • evaluate(self, xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor

In addition, we require both .context_length: int and .name: str to be set for evaluation and visualization.

To use your context model, add an identifier to src/models/__init__.py and refer to it as type: <your identifier> in YAML. Examples are in conf/include/models/.
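The interface described in this section can be sketched roughly as follows. To stay dependency-free, the sketch uses nested Python lists where the repository uses torch.Tensor. The base class ContextModelSketch, the MODELS registry, and the build_model helper are hypothetical names chosen for illustration, not the repository's API, and the running-mean baseline is just an example model.

```python
from typing import Dict, List, Type

Batch = List[List[float]]  # stand-in for a (batch, sequence) torch.Tensor


class ContextModelSketch:
    """Hypothetical stand-in for the repository's ContextModel base class."""

    def __init__(self, context_length: int, name: str):
        self.context_length = context_length  # required for evaluation
        self.name = name                      # required for visualization

    def evaluate(self, xs: Batch, ys: Batch) -> Batch:
        raise NotImplementedError


class RunningMeanBaseline(ContextModelSketch):
    """Predicts each y_i as the mean of the labels seen before position i."""

    def __init__(self, context_length: int, **kwargs):
        super().__init__(context_length=context_length, name="running_mean")

    def evaluate(self, xs: Batch, ys: Batch) -> Batch:
        preds = []
        for seq in ys:
            row, total = [], 0.0
            for i, y in enumerate(seq):
                row.append(total / i if i > 0 else 0.0)  # no context yet: predict 0
                total += y
            preds.append(row)
        return preds


# Hypothetical registry: maps the `type:` identifier from YAML to a class.
MODELS: Dict[str, Type[ContextModelSketch]] = {"running_mean": RunningMeanBaseline}


def build_model(config: dict) -> ContextModelSketch:
    config = dict(config)  # avoid mutating the caller's dictionary
    model_cls = MODELS[config.pop("type")]
    return model_cls(**config)


model = build_model({"type": "running_mean", "context_length": 3})
preds = model.evaluate(xs=[[0.1, 0.2, 0.3]], ys=[[2.0, 4.0, 6.0]])
print(model.name, preds)  # running_mean [[0.0, 2.0, 3.0]]
```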

Adding Hybrid Blocks

We also provide an abstraction to compose hybrid architectures as a collection of blocks. The currently supported blocks are in src/models/hybrid.py. To add a block, add an identifier to SUPPORTED_BLOCKS in src/models/hybrid.py and an entry to MAPPING in the function SPEC_TO_MODULE in the same file. To pass a value to the configuration fed into your new block, just specify it as a key/value pair in your model: mapping in YAML. Examples of this are in conf/include/models/composed.yml, under defs -> base.
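The two registration steps might look roughly like the sketch below. Only the names SUPPORTED_BLOCKS, MAPPING, and SPEC_TO_MODULE come from this README; their shapes here, the Block class, the lowercase spec_to_module function, and the block names "attention", "mlp", and "my_block" are assumptions made for illustration. Check src/models/hybrid.py for the actual definitions.

```python
from typing import Callable, Dict, List

# Assumed shape: a list of block identifiers (illustrative entries).
SUPPORTED_BLOCKS: List[str] = ["attention", "mlp"]


class Block:
    """Hypothetical placeholder for a real network block."""

    def __init__(self, kind: str, **config):
        self.kind = kind
        self.config = config  # key/value pairs taken from the YAML mapping


def spec_to_module(spec: str, **config) -> Block:
    """Assumed shape of SPEC_TO_MODULE: look up a constructor by identifier."""
    if spec not in SUPPORTED_BLOCKS:
        raise ValueError(f"unsupported block: {spec}")
    MAPPING: Dict[str, Callable[..., Block]] = {
        "attention": lambda **c: Block("attention", **c),
        "mlp": lambda **c: Block("mlp", **c),
        "my_block": lambda **c: Block("my_block", **c),  # step 2: new entry
    }
    return MAPPING[spec](**config)


SUPPORTED_BLOCKS.append("my_block")  # step 1: register the identifier

block = spec_to_module("my_block", width=8)
print(block.kind, block.config)  # my_block {'width': 8}
```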


Adding Function Classes

We take the approach of framing each function class as an iterator, where the boilerplate of making the iterator work for our training scheme (ContextTrainer) is implemented for you. The order of calls against a FunctionClass is as follows:

  • We __init__ the FunctionClass with the parsed yaml configuration
  • At the start of the training loop and/or evaluation, we produce an iterator with __iter__ to sample batches from
  • During the training loop and/or evaluation, we sample x,y batches with __next__
    • this calls evaluate to deterministically turn a sampled batch of x values and a sampled batch of parameters into a prediction

FunctionClasses thus have three methods you need to implement:

  • __init__(...)
  • _init_param_dist(self)
  • evaluate(self, x_batch: torch.Tensor, *params: torch.Tensor) -> torch.Tensor

In addition, you need to export your function class in src/function_classes/__init__.py.

To use your function class, reference the key you exported in src/function_classes/__init__.py as type: <that key> in a training YAML configuration. Examples are in conf/train/.
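The call order above can be sketched as follows. This is a dependency-free illustration, not the repository's implementation: FunctionClassSketch stands in for the real FunctionClass base class, nested lists replace torch.Tensor, and the param_sampler attribute and the constructor arguments are hypothetical. The example class assumes that "Chebyshev linear regression" means a linear combination of Chebyshev polynomials of the first kind.

```python
import random
from typing import List

Batch = List[List[float]]  # stand-in for a torch.Tensor


def chebyshev_sum(coeffs: List[float], x: float) -> float:
    """Evaluate sum_k coeffs[k] * T_k(x) using T_{k+1} = 2x T_k - T_{k-1}."""
    t_prev, t_curr = 1.0, x  # T_0 and T_1
    total = 0.0
    for c in coeffs:
        total += c * t_prev
        t_prev, t_curr = t_curr, 2.0 * x * t_curr - t_prev
    return total


class FunctionClassSketch:
    """Hypothetical base class that supplies the iterator boilerplate."""

    def __init__(self, batch_size: int, seq_len: int, seed: int = 0):
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.rng = random.Random(seed)
        self._init_param_dist()  # subclass sets self.param_sampler

    def __iter__(self):
        return self

    def __next__(self):
        # Sample x values and parameters, then defer to evaluate.
        x_batch = [[self.rng.uniform(-1.0, 1.0) for _ in range(self.seq_len)]
                   for _ in range(self.batch_size)]
        params = self.param_sampler()
        return x_batch, self.evaluate(x_batch, *params)


class ChebyshevCombination(FunctionClassSketch):
    def __init__(self, degree: int, **kwargs):
        self.degree = degree  # set before the parent calls _init_param_dist
        super().__init__(**kwargs)

    def _init_param_dist(self):
        # One Gaussian coefficient vector per sequence in the batch.
        self.param_sampler = lambda: ([
            [self.rng.gauss(0.0, 1.0) for _ in range(self.degree + 1)]
            for _ in range(self.batch_size)
        ],)

    def evaluate(self, x_batch: Batch, *params: Batch) -> Batch:
        (coeffs,) = params
        return [[chebyshev_sum(c, x) for x in xs]
                for xs, c in zip(x_batch, coeffs)]


fc = ChebyshevCombination(degree=2, batch_size=4, seq_len=5)
xs, ys = next(iter(fc))
print(len(xs), len(xs[0]), len(ys), len(ys[0]))  # 4 5 4 5
```

Keeping evaluate free of randomness, as the call order describes, means the same x values and parameters always yield the same targets, which makes a function class easy to check in isolation.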
