API Reference

Training Loop

rst.DefaultLoop(
    model: torch.nn.Module,
    task: Task,
    loss: Optional[Loss] = None,
    metrics: Optional[Sequence[Union[Metric, EllipsisType]]] = None,
    processors: Optional[Sequence[Processor]] = None,
    optimizer: Union[str, torch.optim.Optimizer] = 'adam',
    adapter: Adapter = Adapter(),
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    scheduler_base: Literal['epoch', 'step'] = 'step',
    *,
    batch_size=32, num_workers=0, amp=False
)

This is the central class in torch.redstone: the default training loop. In most cases it is sufficient to use this class to train and test models. If you need further customization, you can also derive a new class that inherits from DefaultLoop. A minimal usage sketch follows the parameter list below.

Parameters:

  • model: the target model in operation.
  • task: the target task. Please see the interfaces section for details about tasks.
  • loss: the loss for training. If set to None (the default), the DefaultLoss is applied, which finds the Metric whose name is loss (case insensitive). Please see the interfaces section for details about losses.
  • metrics: the metrics to record and monitor. If set to None (the default), the metrics defined by the Task are applied. Otherwise, it should be a list of Metric objects. The list may also contain a ..., which stands for the metrics defined by the Task. The default displayed name for a metric is its class name; this can be overridden by setting metric.name. In the ObjectProxy returned for metrics, attribute names are the lower-cased displayed names. Please see the interfaces section for details about metrics.
  • processors: the list of processors. If set to None (the default), a default Logger will be applied. Please see the interfaces section for details about processors.
  • optimizer: a torch.optim.Optimizer instance or one of the strings sgd, adam, rmsprop or adadelta. sgd is the SGD optimizer with learning rate 0.01 and Nesterov momentum 0.9. The other strings use the default parameters (as suggested in the relevant papers) of the corresponding optimizer.
  • adapter: an adapter intended to offer a chance to unify input/output interfaces of different models (maybe from different sources) without having to make a wrapper around them. Please see the interfaces section for details about adapters.
  • scheduler: a learning rate scheduler. An example is the cyclic learning rate (CLR) scheduler: torch.optim.lr_scheduler.CyclicLR(optimizer, 0.0, 0.2).
  • scheduler_base: the base step of the learning rate scheduler. According to the value (epoch or step), the learning rate schedule will be advanced each epoch or each iteration step, respectively.
  • batch_size: the batch size used to create the data loaders. Ignored if the Task returns data loaders already. Keyword-only argument. Defaults to 32.
  • num_workers: the number of worker processes of the data loaders. Ignored if the Task returns data loaders already. Keyword-only argument. Defaults to 0 (data loaders run in the main process).
  • amp: enable or disable automatic mixed precision. Various operations are run in FP16 instead of the default FP32 to save memory, and also time if you have a moderately recent CUDA GPU with tensor cores. In a few cases, empirically precision-sensitive tasks such as image super-resolution, you may see a significant decrease in model performance metrics with amp; otherwise it is generally worthwhile to enable it. Keyword-only argument. Defaults to False. On supported GPUs it is recommended to pass torch.bfloat16 or torch.float16 explicitly for this argument.
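
A minimal usage sketch (not taken from the library itself): MyTask is a placeholder for a user-defined Task implementing the interface described below, and the model is an arbitrary torch.nn.Module.

import torch
import torch_redstone as rst

task = MyTask()                      # hypothetical Task (see the Tasks interface below)
model = torch.nn.Linear(32, 10)      # any torch.nn.Module

loop = rst.DefaultLoop(
    model, task,
    optimizer='adam',                # default Adam parameters
    metrics=[...],                   # `...` expands to the metrics defined by the task
    batch_size=64,
)
loop.run(num_epochs=10)              # train and validate for 10 epochs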

Methods:

run(self, num_epochs, train=True, val=True, max_steps=None, quiet=False)

Runs the loop for num_epochs epochs.

You may skip training or validation by setting train or val to False, respectively.

You may limit the maximum steps (batches) to run by setting max_steps.

A progress bar will be displayed along with average statistics if quiet is False.

create_data_loader(self, data: Union[Dataset, list], is_train: bool)

Called to create data loaders. Override to customize the process. rst.collate_support_object_proxy is the default collating function.

By default, shuffle is enabled for training and disabled otherwise.
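
A sketch of overriding this method, e.g., to pin memory and drop the last incomplete training batch. The only non-documented assumption is that the loop exposes the constructor's batch_size as self.batch_size.

import torch_redstone as rst
from torch.utils.data import DataLoader

class MyLoop(rst.DefaultLoop):
    def create_data_loader(self, data, is_train: bool):
        return DataLoader(
            data,
            batch_size=self.batch_size,   # assumption: batch_size is stored on the loop
            shuffle=is_train,             # matches the default behavior described above
            drop_last=is_train,
            pin_memory=True,
            collate_fn=rst.collate_support_object_proxy,
        )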

add_metric(self, metric: Metric)

Adds a Metric to track.

epoch(
    self,
    training: bool = False,
    epoch: Optional[int] = None,
    loader: Optional[DataLoader] = None,
    max_steps: Optional[int] = None,
    return_input: bool = False,
    return_pred: bool = False,
    quiet: bool = False
) -> ResultInterface

Runs an epoch.

You may specify the current mode (training or validation) by setting training to True or False, respectively. This affects the mode of the model (e.g., the behavior of batch normalization and dropout layers), as well as the data split used (training or validation set).

The epoch number will be displayed in the progress bar if quiet is not set.

Pass loader if you want to run the model on data other than the dataset provided by the Task.

You may limit the number of batches run with max_steps.

return_input and return_pred return the feeds and/or predictions in .inputs and .preds of the result, as numpy arrays concatenated over batches (as if the inputs/preds were run in a single batch).
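
For example (a sketch, with loop being a constructed DefaultLoop):

result = loop.epoch(training=False, return_pred=True)
print(result.metrics)     # ObjectProxy of metric values for the epoch
preds = result.preds      # numpy array of predictions concatenated over batches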

Interfaces

torch.redstone interfaces are mainly designed around the training loop implementation, rst.DefaultLoop.

Tasks

A Task defines the data and metrics special to the task.

Interface requirements:

data(self) -> Tuple[Union[Dataset, DataLoader, list], Union[Dataset, DataLoader, list]]

You need to return a tuple of training and validation sets. They may be torch.utils.data.Dataset instances, torch.utils.data.DataLoader instances, or lists of data entries.

metrics(self) -> Sequence[Metric]

You need to return a sequence (tuple, list, etc.) of metrics. It is recommended that you place task-related metrics here, such as accuracy and loss components. Model-specific metrics to monitor can be assigned during the construction of the training loop.
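
A sketch of a toy Task, assuming the base class is exposed as rst.Task (as rst.Metric and rst.Loss are below). It returns plain lists of ObjectProxy entries and uses the .redstone() metric wrappers described later on this page.

import torch
import torch_redstone as rst

class ToyTask(rst.Task):
    def data(self):
        def make(n):
            # Lists of data entries are accepted; DefaultLoop builds the data loaders.
            return [rst.ObjectProxy(x=torch.randn(8), y=torch.randint(0, 2, (1,)).float())
                    for _ in range(n)]
        return make(512), make(128)      # (training set, validation set)

    def metrics(self):
        return [
            torch.nn.BCEWithLogitsLoss().redstone(),   # default name 'Loss'
            rst.BinaryAcc().redstone(),                # default name 'Acc'
        ]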

Metrics

A Metric is a callable on feed inputs and model outputs. It also has a name, which defaults to the class name of the metric.

Interface requirements:

__call__(self, inputs, model_return) -> torch.Tensor

You need to return a single-element tensor as the metric result for inputs and the model return.

Optional overwrites:

@property
name(self) -> str

Override the name of the metric. You may also implement this by setting a name class attribute, or by setting an instance attribute, e.g., self.name.
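
A sketch of a custom Metric with an overridden display name; the logits and y attribute names follow the default conventions used elsewhere on this page.

import torch
import torch_redstone as rst

class MeanAbsError(rst.Metric):
    name = "MAE"   # displayed name; appears as metrics.mae in the results

    def __call__(self, inputs, model_return) -> torch.Tensor:
        return (model_return.logits - inputs.y).abs().mean()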

torch.redstone contains implementations of binary and categorical accuracy (BinaryAcc and CategoricalAcc) that have a redstone() method to construct a Metric, with a default name of Acc. torch.redstone also adds a redstone() method to all torch.nn losses that conveniently converts them into Metric instances:

redstone(self, name: str = ..., pred_path: AttrPathType = 'logits', label_path: AttrPathType = 'y')

Parameters:

  • name: the name of the metric. Defaults to Loss for torch.nn losses and Acc for BinaryAcc and CategoricalAcc.
  • pred_path: the prediction attribute to visit to apply the loss. Defaults to logits. For more details about AttrPathType, please see the AttrPath and AttrPathType section below.
  • label_path: the target attribute to visit to apply the loss. Defaults to y.

Example:

torch.nn.CrossEntropyLoss().redstone()

Losses

A Loss is similar to a Metric, but it can also access metric values so metrics can be easily aggregated into a loss function.

Interface requirements:

__call__(self, inputs, model_return, metrics) -> torch.Tensor

You need to return a single-element tensor. In most use cases, the loss should be differentiable.

The DefaultLoss returns metrics.loss, i.e. the Metric whose name is either LOSS, Loss or loss.

You may also implement a class as both a Metric and a Loss by giving metrics a default value, e.g., None.

Example:

import torch
import torch.nn.functional as F
import torch_redstone as rst
class MSELoss(rst.Loss, rst.Metric):
    def __call__(self, inputs, model_return, metrics=None) -> torch.Tensor:
        return F.mse_loss(model_return.logits, inputs.y)

torch.redstone adds a redstone() method to all the torch.nn losses. The return value of the method (described in detail in the Metric section) is both a Metric and a Loss.

Processors

A Processor receives callbacks from the training loop. It can access or modify various aspects of training. torch.redstone has several built-in processors: a training process logger Logger, a checkpoint saver BestSaver, as well as some training techniques like adversarial training AdvTrainingPGD.

Utility methods:

feed(self, model: nn.Module, inputs)

Feeds the inputs to the model, applying the Adapter. (Calling the model directly does not apply the adapter, which may occasionally cause problems.)

Attributes:

  • gscaler: torch.cuda.amp.GradScaler. If you need to run backpropagation (BP) in the processor and you want to best support amp, you need to scale the loss by calling self.gscaler.scale on it before you backpropagate.

Available callbacks:

pre_forward(self, inputs, model: nn.Module) -> Any

Run before the forward pass of the model is called. inputs have already been transformed by the Adapter at this point. If you return anything, inputs will be replaced by the return value in the rest of the processing chain as well as in the model step.

post_forward(self, inputs, model: nn.Module, model_return) -> Any

Run after the forward pass of the model and before the backward pass. If you return anything, model_return will be replaced by the return value in the rest of the processing chain as well as in the metric and loss computations and the backward pass.

pre_step(self, model: nn.Module, optimizer: torch.optim.Optimizer, metrics)

Run after the backward pass of the model and before the optimizer updates the model. New in 0.0.6.

post_step(self, model: nn.Module, optimizer: torch.optim.Optimizer, metrics)

Run after the backward pass of the model and after the optimizer has updated the model.

pre_epoch(self, model: nn.Module, epoch: int)

Run before a new epoch begins.

post_epoch(self, model: nn.Module, epoch: int, epoch_result: EpochResultInterface)

Run after an epoch ends.
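
A sketch of a custom Processor using two of the callbacks above: it clamps the model logits after the forward pass (assuming the model return exposes a logits attribute, the default convention on this page) and prints the validation metrics after each epoch.

import torch_redstone as rst

class ClampLogits(rst.Processor):
    def post_forward(self, inputs, model, model_return):
        # The return value replaces model_return for the metrics, loss and backward pass.
        model_return.logits = model_return.logits.clamp(-10.0, 10.0)
        return model_return

    def post_epoch(self, model, epoch, epoch_result):
        if epoch_result.val is not None:
            print(f"epoch {epoch}: val metrics = {epoch_result.val.metrics}")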

Adapters

Anything that can be implemented with adapters can also be implemented via processors. However, it is quite common to need to adapt the input/output protocol between models from various sources in a project, and an Adapter is the most convenient way to do so.

Available transforms:

transform(self, inputs) -> Any

Transforms the entries from the data loader. Defaults to an identity function. Processors see the result of this stage.

feed(self, net: nn.Module, inputs) -> Any

Feeds the inputs into the network and returns the outputs, possibly transforming them before the forward pass or after obtaining the return value. The inputs have already been transformed by the transform function and processed by the processors. Defaults to simply calling the network and returning its output.
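
A sketch of an Adapter for a model that consumes and produces plain tensors while the data loader yields (x, y) tuples (the same protocol DirectPredictionAdapter below handles out of the box):

import torch_redstone as rst

class TupleAdapter(rst.Adapter):
    def transform(self, inputs):
        # The data loader yields (x, y) tuples; expose them as attributes.
        x, y = inputs
        return rst.ObjectProxy(x=x, y=y)

    def feed(self, net, inputs):
        # The model takes a plain tensor and returns raw logits; wrap the output.
        return rst.ObjectProxy(logits=net(inputs.x))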

Utility Structures

ObjectProxy

It is recommended that you pass around and work with ObjectProxy in torch.redstone. It is essentially a dict with attribute access, and accessing its entries as attributes is the preferred style. The metrics object returned and used in the interfaces is a kind of ObjectProxy.

Example:

inputs = rst.ObjectProxy(x=2)
inputs.y = 3
print(inputs)
# ObjectProxy(x=2, y=3)

ObjectProxy can be nested, and you can also nest it with other kinds of containers.

Utility methods:

@classmethod
ObjectProxy.zip(cls, **kwargs) -> Generator[ObjectProxy]

Zips keyword arguments zip(k1=v1, k2=v2, ...), where v1, v2, ... are iterables, into a sequence of ObjectProxy instances.

Each ObjectProxy has attributes k1, k2, ... whose values are the corresponding items in v1, v2, ...

Example:

[*rst.ObjectProxy.zip(x=[1,2], y=[2,3])]
# [ObjectProxy(x=1, y=2), ObjectProxy(x=2, y=3)]

An advantage of ObjectProxy over a dict is that you can add type annotations to it; you will then have auto-completion and other IDE features available.

Example:

import torch
import torch_redstone as rst

class ModelOutput:
    logits: torch.Tensor

...
class ActivationModule(torch.nn.Module):
    def forward(self, x) -> ModelOutput:
        return rst.ObjectProxy(logits=torch.softmax(x, dim=-1))

class Top5Acc(rst.Metric):
    name = "Acc@5"
    def __call__(self, inputs, model_return: ModelOutput) -> torch.Tensor:
        # Here you have auto-completion on `model_return`!
        ...

ResultInterface and EpochResultInterface

At runtime, these are actually ObjectProxy instances; the classes below only serve as type annotations. The definitions are:

class ResultInterface:
    metrics: ObjectProxy
    inputs: Optional[Any]
    preds: Optional[Any]


class EpochResultInterface:
    train: Optional[ResultInterface]
    val: Optional[ResultInterface]

They are return values of DefaultLoop.epoch and DefaultLoop.run, respectively. EpochResultInterface is also an input to the post_epoch callback of rst.Processor.

AttrPath and AttrPathType

In many places in torch.redstone, you will see AttrPathType. It is a representation for the attribute path in an object. In torch.redstone, it is widely used together with ObjectProxy.

AttrPath.output.logits, "output.logits" and lambda x: x.output.logits are all valid AttrPathType instances, and they refer to the same thing: torch.redstone will visit the x.output.logits attribute on a given object x with this AttrPathType.

None is a special AttrPathType instance. It refers to the object itself, without visiting any attribute, equivalent to lambda x: x.

Available Implemented Interfaces

Logger

rst.Logger(exp_name: str = "training", directory: str = "./logs/")

Subclass of rst.Processor. Implements the post_epoch callback.

A logger that logs all training and validation metrics to a .csv file. The file name will be exp_name_<timestamp>.csv. The directory (and its parents) will be created automatically if it does not already exist.

Utility methods:

get_file_path(self) -> str

Returns the current file path to write logs to.

write_log(self, *data) -> None

Appends a row of data to the .csv file.

BestSaver

rst.BestSaver(
    metric: AttrPathType = "acc",
    model_name: str = "model", directory: str = "./logs/",
    lower_better: bool = False, verbose=1
)

Subclass of rst.Processor. Implements the post_epoch callback.

Saves the best model state_dict according to metric. The file name will be model_name_<timestamp>.dat. The directory (and its parents) will be created automatically if it does not already exist.

If metric is a str, it refers to the validation metric whose name matches after lower-casing and sanitizing (characters other than ASCII letters and digits are replaced by '_'). If validation is turned off (run with val=False), the corresponding training metric is used. Otherwise, the AttrPathType is resolved against the complete EpochResultInterface.

You may pass lambda _: time.time() to always save the latest results.

If verbose is larger than or equal to 1, a message is printed when a new best is seen.

BestLossSaver

A specialized BestSaver with "loss" as the metric to watch and lower_better set to True.

LatestSaver

rst.LatestSaver(fmt = "model_{start_time}", cond = lambda epoch: True, directory = "./logs/")

Saves the latest checkpoints, with an optional cond callable that decides, based on the epoch number (starting at 0), whether to save.

Return False in cond(epoch) to skip a save.

fmt can include the parameters {start_time} and {epoch} to specialize the save path for different training sessions and epochs. Python format specifiers are supported, e.g., {epoch:04}.

AdvTrainingPGD

rst.AdvTrainingPGD(
    loss_metric: Metric,
    no_perturb_attrs: List[AttrPathType]=[],
    eps=0.03, step_scale=0.5, n_steps=8, attack_at_test=False
)

Subclass of rst.Processor. Implements the pre_forward callback.

Processor for L_inf PGD adversarial (robust) training.

Parameters:

  • loss_metric: the Metric to apply PGD on. In the most frequent use case this would be a cross entropy loss.
  • no_perturb_attrs: attributes to skip when applying perturbation. The path is relative to the inputs structure after Adapter transform.
  • eps: L_inf distance limit of perturbation.
  • step_scale: step size of PGD will be step_scale * eps.
  • n_steps: number of steps to go in PGD iteration.
  • attack_at_test: whether to perform the operation in validation passes.
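
A usage sketch, where model and task are placeholders for your own model and Task; the cross-entropy metric doubles as the training loss via the .redstone() wrapper described above.

import torch
import torch_redstone as rst

ce = torch.nn.CrossEntropyLoss().redstone()           # a Metric and a Loss at the same time
pgd = rst.AdvTrainingPGD(ce, eps=0.03, step_scale=0.5, n_steps=8)

loop = rst.DefaultLoop(model, task, metrics=[ce, ...],
                       processors=[pgd, rst.Logger()])
loop.run(num_epochs=20)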

GradientNormOperator

rst.GradientNormOperator(
    clip_norm: float,
    reject_norm: float
)

Subclass of rst.Processor. Implements the pre_step callback. New in 0.0.6.

Processor that clips the gradient norm and rejects the update if the gradient norm of the current iteration is too large.

Parameters:

  • clip_norm: the maximum scale of gradient to clip to.
  • reject_norm: the maximum gradient norm (measured before clipping) that is allowed; if it is exceeded, the update is rejected.

DirectPredictionAdapter

rst.DirectPredictionAdapter()

Subclass of rst.Adapter.

Adapter for a common protocol in PyTorch. Transforms (x, y) pairs from the data loader into ObjectProxy(x=x, y=y) and wraps a single model output into ObjectProxy(logits=output).

TorchMetric

rst.TorchMetric(
    torch_module: nn.Module, name = 'Loss',
    pred_path: AttrPathType = 'logits', label_path: AttrPathType = 'y'
)

Wraps a PyTorch module as a Loss and a Metric. See the documentation of .redstone() for parameters.

Hooks for Hacking Models

In many cases, especially in the initial exploration stage of your research, you may want to track the output of some layer in an existing model, to add a regularization term or to run some analysis. You may also want to modify the input to some layer of that model. However, the layer may be deep inside the nested structures of the implementation, and you end up copying the whole model implementation, tweaking that layer, and propagating the output all the way up the call stack, messing up the code base. The process is also time-consuming.

With the hooks provided by torch.redstone, you can perform these modifications easily. The author does have a concern, however: if more and more code is tweaked by hooks stacked on top of each other, it may become harder to read. This may be mitigated by code generation that applies the hacks automatically, but that feature is still under development. Currently, it is suggested to do such hacking only while exploring, and to reorganize the code before publishing, once you have settled on a final architecture.

The procedure for hooking usually consists of two stages: registration and execution. The registration is done during the initialization, and only once. The execution stage runs every time you call the model.

catch_input_to

rst.catch_input_to(module: nn.Module, check_catch_count: Optional[int] = 1)

Registers a hook to catch inputs to the module. Returns a hook structure.

After a forward step is run on the module, you may call .get() on the hook structure to obtain the caught inputs. If check_catch_count is a positive number, it is checked that exactly that many forward calls occurred before you call get(). You may call get() more than once if more than one forward pass was run; results are caught in execution order.

Only positional arguments are caught. If the input to the module has more than one positional argument, the result will be a tuple; otherwise it will be that single argument object.

You may call dispose() to release the catch. It will also release itself automatically once you no longer keep a reference to the hook structure.

catch_output_from

rst.catch_output_from(module: nn.Module, check_catch_count: Optional[int] = 1, raise_stop_execution: bool = False)

Similar to catch_input_to, with an additional parameter raise_stop_execution to stop the execution by raising an exception (e.g., in a big model where you only want the outputs from some early layer). It is recommended to use it together with catching_scope().

Example:

tch = Net()  # teacher model
# registration stage
tch_catch = rst.catch_output_from(tch.feat.conv3, raise_stop_execution=True)
# execution stage in forward of another model
with rst.catching_scope():
    tch(x)  # stops once we have outputs from the tch.feat.conv3 layer
tch_feat = tch_catch.get()  # gets hooked output features from the conv3 layer

modify_input_to

rst.modify_input_to(module: nn.Module, apply_func)

Modifies the input to the module by applying apply_func. apply_func receives only the positional arguments of the forward call, and should return a single object or a tuple, depending on whether the forward pass takes one or multiple positional arguments.

You may call .dispose() on the returned structure to remove the modification.
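
A sketch, reusing the placeholder Net (with a feat.conv3 submodule) and input x from the example above:

import torch_redstone as rst

net = Net()
# Registration stage: halve whatever is fed into conv3.
hook = rst.modify_input_to(net.feat.conv3, lambda t: t * 0.5)

out = net(x)        # execution stage: conv3 now receives the scaled input

hook.dispose()      # remove the modification when it is no longer needed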

Utility Functions

seed

rst.seed(seed: int)

Seeds Python random, numpy.random and PyTorch RNG like torch.rand and torch.randn.

Equivalent to:

torch.manual_seed(seed)
numpy.random.seed(seed)
random.seed(seed)

torch_to_numpy

rst.torch_to_numpy(data: Any, strict = False)

Recursively converts every torch.Tensor inside list, dict, ObjectProxy, tuple and set containers into a numpy array. Unlike Tensor.numpy(), the function also supports tensors on the GPU and tensors inside gradient graphs.

Any element that is neither a recognized container nor an object with .detach method is returned as-is by default. An error will be raised in this case if strict is set to True.

torch_to

rst.torch_to(data: Any, reference: Union[str, torch.device, torch.Tensor], strict = False)

Recursively sends every torch.Tensor inside list, dict, ObjectProxy, tuple and set containers to reference (calls .to on each tensor).

Any element that is neither a recognized container nor an object with .to method is returned as-is by default. An error will be raised in this case if strict is set to True.
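
For example (a sketch), moving a nested batch to the GPU when available and converting the tensors back to numpy afterwards:

import torch
import torch_redstone as rst

batch = rst.ObjectProxy(x=torch.randn(4, 8), y=[torch.zeros(4), "labels"])
device = 'cuda' if torch.cuda.is_available() else 'cpu'

on_device = rst.torch_to(batch, device)     # every tensor is moved via .to(device)
as_numpy = rst.torch_to_numpy(on_device)    # tensors become numpy arrays;
                                            # the "labels" string is returned as-is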

supercat

rst.supercat(tensors: Sequence[Tensor], dim: int = 0)

Similar to torch.cat, but supports broadcasting. If the tensors have different numbers of dimensions, singleton dimensions are prepended first to align them.

Example:

import torch
import torch_redstone as rst
a = torch.randn(8, 32)
b = torch.randn(16, 1, 64)
c = rst.supercat([a, b], 2)
print(c.shape)
# torch.Size([16, 8, 96])

The supercat executes as follows:

  1. A singleton dimension is prepended to a into shape [1, 8, 32].
  2. a and b are broadcast except for the dimension (2) to be concatenated. Now they are of shape [16, 8, 32] and [16, 8, 64].
  3. The tensors are concatenated on dimension 2.

xcat

rst.xcat(tensors: Sequence[Tensor], dim: int = 0)

An eXtended torch.cat that supports broadcasting. Alias of rst.supercat.

xreshape

rst.xreshape(tensor: torch.Tensor, shape: Sequence[int], s: Optional[int] = None, e: Optional[int] = None, dim: Optional[int] = None)

Similar to torch.reshape, but supports reshaping a section (s-th dim to e-th dim, both ends included) of the shape.

If dim is set, both s and e are set to dim. An error will be raised if dim is set together with s or e. If s or e is None, the reshape starts from the beginning of the shape or runs through its end, respectively.

Sample shapes:

[K, A * B, C] -- xreshape [A, B] dim 1 --> [K, A, B, C]
[K, A * B, C * D] -- xreshape [A, -1, D] s -2 --> [K, A, B * C, D]
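
The second sample shape as code (a sketch, assuming s and e default to None as the description above implies):

import torch
import torch_redstone as rst

K, A, B, C, D = 2, 3, 4, 5, 6
t = torch.randn(K, A * B, C * D)            # shape [2, 12, 30]
out = rst.xreshape(t, [A, -1, D], s=-2)     # reshape the last two dims into [A, B*C, D]
print(out.shape)
# torch.Size([2, 3, 20, 6])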

visit_attr

rst.visit_attr(Q, attr: AttrPathType) -> Any

Visits the attribute specified by attr in object Q. See the section on AttrPathType for details about attr.
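
A short example using the AttrPathType forms described earlier:

import torch_redstone as rst

obj = rst.ObjectProxy(output=rst.ObjectProxy(logits=[1, 2, 3]))

rst.visit_attr(obj, "output.logits")              # [1, 2, 3]
rst.visit_attr(obj, lambda o: o.output.logits)    # [1, 2, 3]
rst.visit_attr(obj, None)                         # the object itself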

sanitize_name

rst.sanitize_name(name: str) -> str

Replaces any character that is not ASCII letters or digits with '_'.

take_first

rst.take_first(iterable, n: int) -> iter

Takes the first n elements of the iterable as a new iterator.

Math Functions

log1m_exp

rst.log1m_exp(arr_x: Tensor) -> Tensor

Computes x -> log(1 - exp(x)) in a numerically stable manner.

Utility NN Modules

MLP

rst.MLP(sizes: List[int], n_group_dims = 0, activation = F.relu, norm = 'batch')

Subclass of torch.nn.Module.

Basic multi-layer perceptron (MLP), also called dense or fully connected (FC) networks. Normalization and activations are configurable.

Note: The output layer is also normalized and activated.

Parameters:

  • sizes: sizes of layers, including the input. [n_in, h_1, h_2, ..., n_out].
  • n_group_dims: 0 if input is of shape [B, C], 1 if [B, C, N], 2 if [B, C, H, W], 3 if [B, C, D, H, W].
  • activation: a Tensor -> Tensor function. Defaults to F.relu.
  • norm: 'batch', 'instance', or None. Normalization layers type.
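
A construction sketch for point-wise features of shape [B, C, N]:

import torch
import torch_redstone as rst

# 3 input channels -> hidden sizes 64 and 128 -> 10 output channels, applied point-wise.
mlp = rst.MLP([3, 64, 128, 10], n_group_dims=1, norm='batch')

x = torch.randn(16, 3, 1024)     # [B, C, N]
print(mlp(x).shape)
# torch.Size([16, 10, 1024])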

Lambda

rst.Lambda(lam: Callable)

Subclass of torch.nn.Module.

Executes a lambda function lam in forward.

Example:

import torch
import torch_redstone as rst

net = torch.nn.Sequential(rst.Lambda(lambda x: x * 2))
print(net(torch.ones(5)))
# tensor([2., 2., 2., 2., 2.])

GetItem

rst.GetItem(index: int)

Subclass of torch.nn.Module.

Indexes the input to the module and returns the result.

BinaryAcc

rst.BinaryAcc(th_logits = 0.0, th_label = 0.5)

Subclass of torch.nn.Module.

Computes the accuracy of a binary classification in forward(self, inputs, targets). The classification threshold on the model output is th_logits; the threshold on the label is th_label. It has a .redstone() method, the same as described in the Metric section.

CategoricalAcc

rst.CategoricalAcc(dim = 1)

Subclass of torch.nn.Module.

Computes the accuracy of a multi-class classification in forward(self, inputs, targets). The classification logits lie along dim (the size of that dimension should be the number of classes). The targets should be integers specifying the class labels rather than one-hot encodings. It has a .redstone() method, the same as described in the Metric section.

Utility Classes

Meter

rst.Meter()

A meter utility to track the average statistics (loss, accuracy, etc.) during an epoch.

Example:

import torch
import torch_redstone as rst

meter = rst.Meter()
meter.u('acc', 1)
meter.u('acc', 0)
print(meter['acc'])
# 0.5

Polyfills

Polyfills are implemented in torch_redstone/polyfill.py in the Polyfill class, making newer PyTorch functions available on older PyTorch versions.

The available polyfills are:

PyTorch Function      Introduced in PyTorch
torch.cdist           1.3
torch.square          1.5
torch.absolute        1.6
torch.arccos          1.7
torch.arcsin          1.7
torch.arctan          1.7
torch.broadcast_to    1.8
torch.autocast        1.10