API Reference
rst.DefaultLoop( model: torch.nn.Module, task: Task, loss: Optional[Loss] = None, metrics: Optional[Sequence[Union[Metric, EllipsisType]]] = None, processors: Sequence[Processor] = None, optimizer: Union[str, torch.optim.Optimizer] = 'adam', adapter: Adapter = Adapter(), scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, scheduler_base: Literal['epoch', 'step'] = 'step', *, batch_size=32, num_workers=0, amp=False )
This is a central class in torch.redstone, the default training loop. In most cases, it would be sufficient to use this class to train and test models. If you need to do more customization, you may also derive a new class that inherits this DefaultLoop.
Parameters:
- model: the target model in operation.
- task: the target task. Please see the interfaces section for details about tasks.
- loss: the loss for training. If set to None (the default), the DefaultLoss is applied, which finds the Metric whose name is loss (case insensitive). Please see the interfaces section for details about losses.
- metrics: the metrics to record and monitor. If set to None (the default), the metrics defined by the Task will be applied. Otherwise, it should be a Metric list. The list may also contain a ... which stands for the metrics defined by the Task. The default displayed name for a metric is its class name. This can be overridden by setting metric.name. In the returned ObjectProxy for metrics, attribute names are lower-case displayed names. Please see the interfaces section for details about metrics.
- processors: the list of processors. If set to None (the default), a default Logger will be applied. Please see the interfaces section for details about processors.
- optimizer: a torch.optim.Optimizer instance or a str: sgd, adam, rmsprop or adadelta. sgd is the SGD optimizer with learning rate 0.01 and Nesterov momentum 0.9. The other strings select the corresponding optimizer with its default parameters (as suggested in the relevant papers).
- adapter: an adapter intended to offer a chance to unify input/output interfaces of different models (maybe from different sources) without having to make a wrapper around them. Please see the interfaces section for details about adapters.
- scheduler: a learning rate scheduler. An example is the cyclic learning rate (CLR) scheduler: torch.optim.lr_scheduler.CyclicLR(optimizer, 0.0, 0.2).
- scheduler_base: the base step of the learning rate scheduler. According to the value (epoch or step), the learning rate schedule will be advanced each epoch or each iteration step, respectively.
- batch_size: the batch size used to create the data loaders. Ignored if the Task returns data loaders already. Keyword-only argument. Defaults to 32.
- num_workers: the number of worker processes of the data loaders. Ignored if the Task returns data loaders already. Keyword-only argument. Defaults to 0 (data loaders run in the main process).
- amp: enable or disable automatic mixed precision. Various operations are converted into FP16 operations instead of the default FP32 to save memory, and also time if you have a moderately new CUDA GPU with tensor cores. In a few cases, e.g., empirically, precision-sensitive tasks like image super resolution, you may see a significant decrease in model performance metrics with amp. Otherwise, enabling amp is recommended. Keyword-only argument. Defaults to False. It is recommended to specify torch.bfloat16 or torch.float16 explicitly in this argument on supported GPUs.
Methods:
run(self, num_epochs, train=True, val=True, max_steps=None, quiet=False)
Runs the loop for num_epochs epochs.
You may skip training or validation by setting train or val to False, respectively.
You may limit the maximum steps (batches) to run by setting max_steps.
A progress bar will be displayed along with average statistics if quiet is False.
create_data_loader(self, data: Union[Dataset, list], is_train: bool)
Called to create data loaders. Override to customize the process. rst.collate_support_object_proxy is the default collating function.
By default, shuffle is enabled for training and disabled otherwise.
add_metric(self, metric: Metric)
Adds a Metric to track.
epoch( self, training: bool = False, epoch: Optional[int] = None, loader: Optional[DataLoader] = None, max_steps: Optional[int] = None, return_input: bool = False, return_pred: bool = False, quiet: bool = False ) -> ResultInterface
Runs an epoch.
You may specify the current status (training/validation) by setting training to True or False, respectively. This affects the mode of the model (e.g., the behavior of batch normalization and drop-out layers), as well as the data used (training or validation set).
epoch will be displayed in the progress bar if quiet is not set.
Override loader if you want to run the model on data other than the dataset provided by the Task.
You may limit the number of batches run with max_steps.
return_input and return_pred return the feeds and/or predictions in .inputs and .preds as concatenated numpy arrays (as if the inputs/preds were run in a single batch).
torch.redstone interfaces are mainly designed around the training loop implementation, rst.DefaultLoop.
A Task defines the data and metrics special to the task.
Interface requirements:
data(self) -> Tuple[Union[Dataset, DataLoader, list], Union[Dataset, DataLoader, list]]
You need to return a tuple of training and validation sets. Each may be a torch.utils.data.Dataset instance, a torch.utils.data.DataLoader instance, or a list of data entries.
metrics(self) -> Sequence[Metric]
You need to return a sequence (tuple, list, etc.) of metrics. It is recommended that you place the metrics related with the task here like accuracy and loss components. Model-specific metrics to monitor can be assigned during the construction of the training loop.
A Metric is a callable on feed inputs and model outputs. It also has a name, which defaults to the class name of the metric.
Interface requirements:
__call__(self, inputs, model_return) -> torch.Tensor
You need to return a single-element tensor as the metric result for inputs and the model return.
Optional overwrites:
@property name(self) -> str
Override the name of the metric. You may also implement this by setting a name class attribute, or by setting an instance attribute, e.g., self.name.
torch.redstone contains implementations of binary and categorical accuracy (BinaryAcc and CategoricalAcc). Each has a redstone() method that constructs a Metric, with the default name Acc.
torch.redstone also adds a redstone() method to all the torch.nn losses that conveniently converts it into an instance of Metric:
redstone(self, name: str = ..., pred_path: AttrPathType = 'logits', label_path: AttrPathType = 'y')
Parameters:
- name: the name of the metric. Defaults to Loss for torch.nn losses and Acc for BinaryAcc and CategoricalAcc.
- pred_path: the prediction attribute that we should visit to apply the loss. Defaults to logits. For more details about AttrPathType, please see the section on AttrPathType.
- label_path: the target attribute that we should visit to apply the loss. Defaults to y.
Example:
torch.nn.CrossEntropyLoss().redstone()

A Loss is similar to a Metric, but it can also access metric values so metrics can be easily aggregated into a loss function.
Interface requirements:
__call__(self, inputs, model_return, metrics) -> torch.Tensor
You need to return a single-element tensor. In most use cases, the loss should be differentiable.
The DefaultLoss returns metrics.loss, i.e. the Metric whose name is either LOSS, Loss or loss.
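Because a Loss receives the computed metrics, aggregating several of them into one value takes little code. The following is only a sketch of that idea under stated assumptions: it uses plain Python floats and types.SimpleNamespace as stand-ins for single-element tensors and the metrics ObjectProxy, and the class name WeightedSumLoss and the metric names ce and reg are hypothetical.

```python
from types import SimpleNamespace


class WeightedSumLoss:
    """Hypothetical loss following the __call__(inputs, model_return, metrics) interface."""

    def __init__(self, weights):
        self.weights = weights  # maps lower-case metric name -> weight

    def __call__(self, inputs, model_return, metrics):
        # Metrics are looked up by their lower-case displayed names.
        return sum(w * getattr(metrics, name) for name, w in self.weights.items())


# Stand-in for the metrics object; real values would be single-element tensors.
metrics = SimpleNamespace(ce=0.5, reg=2.0)
loss = WeightedSumLoss({"ce": 1.0, "reg": 0.25})
total = loss(None, None, metrics)
print(total)  # 1.0
```

In real use the class would presumably derive from rst.Loss and operate on tensors so that the result stays differentiable.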
You may also implement a class as both a Metric and a Loss by giving a default value to metrics, e.g., None.
Example:
import torch
import torch.nn.functional as F
import torch_redstone as rst

class MSELoss(rst.Loss, rst.Metric):
    def __call__(self, inputs, model_return, metrics=None) -> torch.Tensor:
        return F.mse_loss(model_return.logits, inputs.y)

torch.redstone adds a redstone() method to all the torch.nn losses. The return value of the method (described in detail in the Metric section) is both a Metric and a Loss.
A Processor receives callbacks from the training loop. It can access or modify various aspects of training. torch.redstone has several built-in processors: a training process logger Logger, a checkpoint saver BestSaver, as well as some training techniques like adversarial training AdvTrainingPGD.
Utility methods:
feed(self, model: nn.Module, inputs)
Feeds the inputs to the model, applying the Adapter. (Calling the model directly bypasses the adapter, which may cause trouble when the adapter changes the input/output protocol.)
Attributes:
- gscaler: torch.cuda.amp.GradScaler. If you need to run backward propagations (BP) in the processor, and you want to best support amp, you will need to scale the gradients by calling self.gscaler.scale on the target before you BP.
Available callbacks:
pre_forward(self, inputs, model: nn.Module) -> Any
Run before the forward of the model is called. inputs are transformed by Adapter before calling the method. If you return anything, the inputs will be replaced by the return value in the later processing chain as well as the model step.
post_forward(self, inputs, model: nn.Module, model_return) -> Any
Run after the forward pass of the model and before the backward pass. If you return anything, the model_return will be replaced by the return value in the later processing chain as well as the metric & loss computations and the backward pass.
pre_step(self, model: nn.Module, optimizer: torch.optim.Optimizer, metrics)
Run after the backward pass of the model before the optimizer has updated the model. New in 0.0.6.
post_step(self, model: nn.Module, optimizer: torch.optim.Optimizer, metrics)
Run after the backward pass of the model after the optimizer has updated the model.
pre_epoch(self, model: nn.Module, epoch: int)
Run before a new epoch begins.
post_epoch(self, model: nn.Module, epoch: int, epoch_result: EpochResultInterface)
Run after an epoch ends.
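The order in which these callbacks fire during one training epoch can be sketched in plain Python. This is an illustration inferred from the descriptions above, not the library's loop: RecordingProcessor and run_epoch_sketch are hypothetical names, the model is an ordinary function, and "returning anything" is modelled as returning a value other than None.

```python
class RecordingProcessor:
    """Hypothetical stand-in for a processor that records the callbacks it receives."""

    def __init__(self):
        self.calls = []

    def pre_epoch(self, model, epoch):
        self.calls.append("pre_epoch")

    def pre_forward(self, inputs, model):
        self.calls.append("pre_forward")
        return inputs + 1  # a returned value replaces the inputs

    def post_forward(self, inputs, model, model_return):
        self.calls.append("post_forward")  # returns None: model_return is kept

    def pre_step(self, model, optimizer, metrics):
        self.calls.append("pre_step")

    def post_step(self, model, optimizer, metrics):
        self.calls.append("post_step")

    def post_epoch(self, model, epoch, epoch_result):
        self.calls.append("post_epoch")


def run_epoch_sketch(model, batches, processor, epoch=0):
    """Assumed callback order for one training epoch (not the library's code)."""
    outputs = []
    processor.pre_epoch(model, epoch)
    for inputs in batches:
        replaced = processor.pre_forward(inputs, model)
        if replaced is not None:
            inputs = replaced
        model_return = model(inputs)
        replaced = processor.post_forward(inputs, model, model_return)
        if replaced is not None:
            model_return = replaced
        # Metrics, the loss and the backward pass would be computed here.
        processor.pre_step(model, None, {})
        # The optimizer update would happen here.
        processor.post_step(model, None, {})
        outputs.append(model_return)
    processor.post_epoch(model, epoch, None)
    return outputs


proc = RecordingProcessor()
outputs = run_epoch_sketch(lambda x: x * 10, [1, 2], proc)
print(outputs)  # [20, 30]
print(proc.calls)
```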
Anything that can be implemented with adapters can also be implemented via processors. However, it is quite frequent that you may need to adapt the input/output protocol between various sources of models in a project. Adapter is the most convenient way to go.
Available transforms:
transform(self, inputs) -> Any
Transforms the entries from the data loader. Defaults to an identity function. Processors see the result of this stage.
feed(self, net: nn.Module, inputs) -> Any
Feeds the inputs into the network and returns the outputs, possibly transforming them before the forward pass of the net or after you have got the return value. The inputs have already been transformed by the transform function and processed by processors. Defaults to simply calling the network and returning its output.
It is recommended that you pass around and work with ObjectProxy in torch.redstone. It is basically a dict with attribute access. It is recommended to visit it with attribute access. The metrics object returned and used in interfaces is a type of ObjectProxy.
Example:
inputs = rst.ObjectProxy(x=2)
inputs.y = 3
print(inputs)
# ObjectProxy(x=2, y=3)

ObjectProxy can be nested, and you can also nest it with other kinds of containers.
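To make "a dict with attribute access" and the nesting behaviour concrete, here is a small stand-alone sketch. AttrDict is a hypothetical stand-in written for illustration; it is not the library's ObjectProxy and omits whatever extra features the real class provides.

```python
class AttrDict(dict):
    """Hypothetical stand-in: a dict whose keys are also readable and writable as attributes."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key, value):
        self[key] = value


batch = AttrDict(x=2)
batch.y = 3
# Nesting with other containers: a proxy inside a list inside a proxy.
batch.meta = AttrDict(ids=[AttrDict(idx=0), AttrDict(idx=1)])

print(batch.x + batch.y)      # 5
print(batch["y"])             # 3
print(batch.meta.ids[1].idx)  # 1
```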
Utility methods:
@classmethod ObjectProxy.zip(cls, **kwargs) -> Generator[ObjectProxy]
Converts zip(k1=v1, k2=v2, ...), where v1, v2, ... are iterables, into a sequence of ObjectProxy instances. Each ObjectProxy has attributes k1, k2, ... whose values are the corresponding items in v1, v2, ...
Example:
[*rst.ObjectProxy.zip(x=[1,2], y=[2,3])]
# [ObjectProxy(x=1, y=2), ObjectProxy(x=2, y=3)]

A good thing with ObjectProxy over a dict is that you can add type annotation to it. You will then have auto-completion and other IDE features available with ObjectProxy.
Example:
import torch
import torch_redstone as rst

class ModelOutput:
    logits: torch.Tensor
    ...

class ActivationModule(torch.nn.Module):
    def forward(self, x) -> ModelOutput:
        return rst.ObjectProxy(logits=torch.softmax(x, dim=-1))

class Top5Acc(rst.Metric):
    name = "Acc@5"

    def __call__(self, inputs, model_return: ModelOutput) -> torch.Tensor:
        # Here you have auto-completion on `model_return`!
        ...

ResultInterface and EpochResultInterface are actually ObjectProxy instantiations. The definitions are:
class ResultInterface:
    metrics: ObjectProxy
    inputs: Optional[Any]
    preds: Optional[Any]

class EpochResultInterface:
    train: Optional[ResultInterface]
    val: Optional[ResultInterface]

They are return values of DefaultLoop.epoch and DefaultLoop.run, respectively. EpochResultInterface is also an input to the post_epoch callback of rst.Processor.
In many places in torch.redstone, you will see AttrPathType. It is a representation for the attribute path in an object. In torch.redstone, it is widely used together with ObjectProxy.
AttrPath.output.logits, "output.logits" and lambda x: x.output.logits are all valid AttrPathType instances, and they refer to the same thing: torch.redstone will visit the x.output.logits attribute on a given object x with this AttrPathType.
None is a special AttrPathType instance. It refers to the object itself, without visiting any attribute, equivalent to lambda x: x.
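How such a path might be resolved can be sketched in a few lines. The helper resolve_attr_path below is hypothetical and covers only the string, callable and None forms (the AttrPath.output.logits form is left out); in the library itself, rst.visit_attr is the function that performs the visit.

```python
from types import SimpleNamespace


def resolve_attr_path(obj, path):
    """Hypothetical resolver for three of the forms described above."""
    if path is None:
        return obj                    # None refers to the object itself
    if callable(path):
        return path(obj)              # e.g. lambda x: x.output.logits
    for name in path.split("."):      # e.g. "output.logits"
        obj = getattr(obj, name)
    return obj


x = SimpleNamespace(output=SimpleNamespace(logits=[0.1, 0.9]))
by_string = resolve_attr_path(x, "output.logits")
by_lambda = resolve_attr_path(x, lambda o: o.output.logits)
itself = resolve_attr_path(x, None)
print(by_string == by_lambda, itself is x)  # True True
```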
rst.Logger(exp_name: str = "training", directory: str = "./logs/")
Subclass of rst.Processor. Implements the post_epoch callback.
A logger that logs all training and validation metrics to a .csv file.
The file name would be exp_name_<timestamp>.csv
The directory (and its parents) will be automatically created if not already existing.
Utility methods:
get_file_path(self) -> str
Returns the current file path to write logs to.
write_log(self, *data) -> None
Appends a row of data to the .csv file.
rst.BestSaver( metric: AttrPathType = "acc", model_name: str = "model", directory: str = "./logs/", lower_better: bool = False, verbose=1 )
Subclass of rst.Processor. Implements the post_epoch callback.
Saves the best model state_dict according to metric. The file name would be model_name_<timestamp>.dat. The directory (and its parents) will be automatically created if not already existing.
A str metric refers to the validation metric with that name, lower-cased and sanitized (characters other than ASCII letters and digits are replaced by '_'). If validation is turned off (run with val=False), the training metric is used instead. Any other kind of AttrPathType is applied to the complete EpochResultInterface.
You may pass lambda _: time.time() to always save the latest results.
If verbose is larger than or equal to 1, a message is printed when a new best is seen.
A specialized BestSaver with "loss" as the name of metric to watch, and lower_better set to True.
rst.BestSaver(fmt = "model_{start_time}", cond = lambda epoch: True, directory = "./logs/")
Saves the latest checkpoint. The optional cond receives the epoch number (starting from 0) and decides whether to save.
Return False in cond(epoch) to skip a save.
fmt can have parameters {start_time} and {epoch} to specialize save path for different training sessions and epochs.
It supports python formatting, e.g., {epoch:04}.
rst.AdvTrainingPGD( loss_metric: Metric, no_perturb_attrs: List[AttrPathType]=[], eps=0.03, step_scale=0.5, n_steps=8, attack_at_test=False )
Subclass of rst.Processor. Implements the pre_forward callback.
Processor for L_inf PGD adversarial (robust) training.
Parameters:
- loss_metric: the Metric to apply PGD on. In the most frequent use case this would be a cross entropy loss.
- no_perturb_attrs: attributes to skip when applying perturbation. The path is relative to the inputs structure after Adapter transform.
- eps: L_inf distance limit of perturbation.
- step_scale: step size of PGD will be step_scale * eps.
- n_steps: number of steps to go in PGD iteration.
- attack_at_test: whether to perform the operation in validation passes.
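The roles of eps, step_scale and n_steps can be seen in a toy version of the L_inf PGD iteration. This sketch works on a plain list of floats with a hand-written gradient function; pgd_linf is a hypothetical name, and starting from a zero perturbation is an assumption (the processor may initialise differently).

```python
def pgd_linf(x, grad_fn, eps=0.03, step_scale=0.5, n_steps=8):
    """L_inf PGD on a list of floats; grad_fn returns d(loss)/d(x) per element."""
    step = step_scale * eps
    delta = [0.0] * len(x)  # assumption: start from zero perturbation
    for _ in range(n_steps):
        grads = grad_fn([xi + di for xi, di in zip(x, delta)])
        for i, g in enumerate(grads):
            sign = (g > 0) - (g < 0)
            # Ascend the loss, then project back into the eps-ball.
            delta[i] = max(-eps, min(eps, delta[i] + step * sign))
    return [xi + di for xi, di in zip(x, delta)]


# Toy loss: sum of squares, whose gradient is 2 * x.
x_adv = pgd_linf([1.0, -2.0], lambda xs: [2.0 * v for v in xs])
print(x_adv)  # each element moved away from zero by at most eps
```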
rst.GradientNormOperator( clip_norm: float, reject_norm: float )
Subclass of rst.Processor. Implements the pre_step callback. New in 0.0.6.
Processor for clipping the gradient norm and rejecting update if gradient norm for the current iteration is too large.
Parameters:
- clip_norm: the maximum scale of gradient to clip to.
- reject_norm: the maximum scale of gradient that is allowed before clipping; otherwise the update would be rejected.
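The interplay of the two thresholds can be sketched as follows. This is a guess at the logic from the parameter descriptions, not the processor's code: clip_or_reject is a hypothetical helper, it assumes an L2 norm over a flat list of gradient values, and it leaves it to the caller to skip the update when accepted is False.

```python
import math


def clip_or_reject(grads, clip_norm, reject_norm):
    """Returns (new_grads, accepted) for a flat list of gradient values."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm > reject_norm:
        return grads, False  # too large: the caller skips this update
    if norm > clip_norm:
        scale = clip_norm / norm
        return [g * scale for g in grads], True
    return grads, True


small, small_ok = clip_or_reject([0.3, 0.4], clip_norm=1.0, reject_norm=10.0)
clipped, clipped_ok = clip_or_reject([3.0, 4.0], clip_norm=1.0, reject_norm=10.0)
_, huge_ok = clip_or_reject([30.0, 40.0], clip_norm=1.0, reject_norm=10.0)
print(small_ok, clipped_ok, huge_ok)  # True True False
```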
rst.DirectPredictionAdapter()
Subclass of rst.Adapter.
Adapter for a common protocol in PyTorch. Transforms (x, y) pairs from the data loader into ObjectProxy(x=x, y=y) and wraps a single model output into ObjectProxy(logits=output).
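The two adapter stages for this protocol can be pictured with a stand-alone sketch. DirectPredictionSketch is a hypothetical class that uses types.SimpleNamespace in place of ObjectProxy; passing only inputs.x to the network is an assumption, not something the description above states.

```python
from types import SimpleNamespace


class DirectPredictionSketch:
    """Hypothetical adapter with the transform/feed shape described above."""

    def transform(self, inputs):
        x, y = inputs  # an (x, y) pair from the data loader
        return SimpleNamespace(x=x, y=y)

    def feed(self, net, inputs):
        # Assumption: only inputs.x is passed to the network.
        return SimpleNamespace(logits=net(inputs.x))


adapter = DirectPredictionSketch()
inputs = adapter.transform(([1.0, 2.0], 1))
outputs = adapter.feed(lambda x: [2 * v for v in x], inputs)
print(inputs.y, outputs.logits)  # 1 [2.0, 4.0]
```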
rst.TorchMetric( torch_module: nn.Module, name = 'Loss', pred_path: AttrPathType = 'logits', label_path: AttrPathType = 'y' )
Wraps a PyTorch module as a Loss and a Metric. See the documentation of .redstone() for parameters.
In many cases, especially in the initial exploration stage of your research, you may want to track the output of some layer in an existing model, either to add a regularization term or to analyze it for insight. You may also want to modify the input to some layer of that model. However, the layer may sit deep inside nested structures of the implementation, so you end up copying the whole model implementation, tweaking that layer, and propagating the output all the way up the call stack, which clutters the code base and costs time.
With the hooks provided by torch.redstone, you can perform these operations easily. One concern remains: as more and more code is tweaked by hooks stacked on top of each other, it may become harder to read. This may be mitigated by code generation that applies the hacks automatically, but that feature is still under development. For now, it is suggested to use hooks only while exploring, and to reorganize the code before publishing once you have settled on a final architecture.
The procedure for hooking usually consists of two stages: registration and execution. The registration is done during the initialization, and only once. The execution stage runs every time you call the model.
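The two stages can be pictured with a small pure-Python sketch. Layer and OutputCatch are hypothetical stand-ins written for illustration; the library's hooks operate on nn.Module instances and are not implemented this way.

```python
class Layer:
    """Stand-in for a module: callable, with removable forward hooks."""

    def __init__(self, fn):
        self.fn = fn
        self.hooks = []

    def __call__(self, x):
        out = self.fn(x)
        for hook in list(self.hooks):
            hook(x, out)
        return out


class OutputCatch:
    """Hypothetical hook structure that stores outputs in execution order."""

    def __init__(self, layer):
        self.layer = layer
        self.caught = []
        layer.hooks.append(self._on_forward)  # registration happens once

    def _on_forward(self, inputs, output):
        self.caught.append(output)

    def get(self):
        return self.caught.pop(0)

    def dispose(self):
        self.layer.hooks.remove(self._on_forward)


double = Layer(lambda x: x * 2)
catch = OutputCatch(double)    # registration stage
result = double(3) + 1         # execution stage: the hook fires inside the call
caught = catch.get()
catch.dispose()
double(5)                      # no longer caught after dispose()
print(result, caught, len(catch.caught))  # 7 6 0
```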
rst.catch_input_to(module: nn.Module, check_catch_count: Optional[int] = 1)
Registers a hook to catch inputs to the module. Returns a hook structure.
After a forward pass is run through the module, you may call .get() on the hook structure to obtain the caught input. If check_catch_count is a positive number, it is checked that exactly that many forward passes have been run before you call get(). You may call get() more than once if more than one forward pass is run. The results are returned in execution order.
Only positional arguments are caught. If the input to the module has more than 1 positional arguments, the result would be a tuple. Otherwise, it would be that single argument object.
You may call dispose() to release the catch. It will also release itself automatically once you no longer keep a reference to the hook structure.
rst.catch_output_from(module: nn.Module, check_catch_count: Optional[int] = 1, raise_stop_execution: bool = False)
Similar to catch_input_to, with an additional parameter raise_stop_execution to stop the execution by raising an exception (e.g., in a big model where you only want the outputs from some early layer). It is recommended to use it together with catching_scope().
Example:
tch = Net()  # teacher model
# registration stage
tch_catch = rst.catch_output_from(tch.feat.conv3, raise_stop_execution=True)
# execution stage in forward of another model
with rst.catching_scope():
    tch(x)  # stops once we have outputs from the tch.feat.conv3 layer
tch_feat = tch_catch.get()  # gets hooked output features from the conv3 layer

rst.modify_input_to(module: nn.Module, apply_func)
Modifies the input to the module by applying apply_func. apply_func receives only the positional arguments of the forward call. Return a single object if the forward pass takes one positional argument, or a tuple if it takes more than one.
You may call .dispose() on the returned structure to remove the modification.
rst.seed(seed: int)
Seeds Python random, numpy.random and PyTorch RNG like torch.rand and torch.randn.
Equivalent to:
torch.manual_seed(seed)
numpy.random.seed(seed)
random.seed(seed)

rst.torch_to_numpy(data: Any, strict = False)
Recursively fetch torch.Tensor in list, dict, ObjectProxy, tuple, set as numpy array. Different from Tensor.numpy(), the function also supports tensors on GPU and in gradient graphs.
Any element that is neither a recognized container nor an object with .detach method is returned as-is by default. An error will be raised in this case if strict is set to True.
rst.torch_to(data: Any, reference: Union[str, torch.device, torch.Tensor], strict = False)
Recursively send torch.Tensor in list, dict, ObjectProxy, tuple, set to reference (calls .to on each element).
Any element that is neither a recognized container nor an object with .to method is returned as-is by default. An error will be raised in this case if strict is set to True.
rst.supercat(tensors: Sequence[Tensor], dim: int = 0)
Similar to torch.cat, but supports broadcasting. If the tensors have different numbers of dimensions, singleton dimensions are prepended first to align the dimensions.
Example:
import torch
import torch_redstone as rst
a = torch.randn(8, 32)
b = torch.randn(16, 1, 64)
c = rst.supercat([a, b], 2)
print(c.shape)
# torch.Size([16, 8, 96])

The supercat executes as follows:
- A singleton dimension is prepended to a, giving shape [1, 8, 32].
- a and b are broadcast except for the dimension (2) to be concatenated. Now they are of shape [16, 8, 32] and [16, 8, 64].
- The tensors are concatenated on dimension 2.
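The same three steps can be traced on shapes alone, without any tensors. The helper supercat_shape below is hypothetical and only computes the resulting shape; it is meant as a way to check expectations, not as a description of how the library implements the operation.

```python
def supercat_shape(shapes, dim=0):
    """Output shape of a broadcasting concatenation, following the steps above."""
    ndim = max(len(s) for s in shapes)
    # Step 1: prepend singleton dimensions so all shapes have the same rank.
    aligned = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    dim = dim % ndim
    out = []
    for d in range(ndim):
        sizes = [s[d] for s in aligned]
        if d == dim:
            out.append(sum(sizes))  # Step 3: sizes add up along `dim`
        else:
            target = max(sizes)     # Step 2: broadcast the other dimensions
            if any(n not in (1, target) for n in sizes):
                raise ValueError(f"cannot broadcast sizes {sizes} at dim {d}")
            out.append(target)
    return tuple(out)


print(supercat_shape([(8, 32), (16, 1, 64)], dim=2))  # (16, 8, 96)
```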
rst.xcat(tensors: Sequence[Tensor], dim: int = 0)
An eXtended torch.cat that supports broadcasting. Alias of rst.supercat.
rst.xreshape(tensor: torch.Tensor, shape: Sequence[int], s: Optional[int], e: Optional[int], dim: Optional[int])
Similar to torch.reshape, but supports reshaping a section (s-th dim to e-th dim, both ends included) of the shape.
If dim is set, both s and e are set to dim. An error is raised if dim is set together with s or e.
If either s or e is None, the reshape will start from the beginning or go through the end of the shape, respectively.
Sample shapes:
[K, A * B, C] -- xreshape [A, B] dim 1 --> [K, A, B, C]
[K, A * B, C * D] -- xreshape [A, -1, D] s -2 --> [K, A, B * C, D]
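The sample shapes can be checked with a shape-only sketch. xreshape_shape is a hypothetical helper that mirrors the rules described above (inclusive s and e, dim as a shorthand, -1 inferred from the section size); it does not touch tensors and is not the library's implementation.

```python
from math import prod


def xreshape_shape(in_shape, shape, s=None, e=None, dim=None):
    """Shape obtained by reshaping dims s..e (inclusive) of in_shape to `shape`."""
    if dim is not None:
        if s is not None or e is not None:
            raise ValueError("set either dim, or s/e, but not both")
        s = e = dim
    n = len(in_shape)
    start = 0 if s is None else s % n   # negative indices count from the end
    end = n - 1 if e is None else e % n
    section = prod(in_shape[start:end + 1])
    new = list(shape)
    if -1 in new:
        known = prod(d for d in new if d != -1)
        new[new.index(-1)] = section // known
    if prod(new) != section:
        raise ValueError(f"cannot reshape a section of size {section} into {list(shape)}")
    return list(in_shape[:start]) + new + list(in_shape[end + 1:])


K, A, B, C, D = 2, 3, 4, 5, 6
print(xreshape_shape([K, A * B, C], [A, B], dim=1))           # [2, 3, 4, 5]
print(xreshape_shape([K, A * B, C * D], [A, -1, D], s=-2))    # [2, 3, 20, 6]
```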
rst.visit_attr(Q, attr: AttrPathType) -> Any
Visits the attribute specified by attr in object Q. See the section on AttrPathType for details about attr.
rst.sanitize_name(name: str) -> str
Replaces any character that is not ASCII letters or digits with '_'.
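The rule is easy to reproduce with a regular expression. sanitize_name_sketch is a hypothetical re-implementation of the behaviour described above, given only to show what typical metric names turn into.

```python
import re


def sanitize_name_sketch(name: str) -> str:
    """Replaces every character that is not an ASCII letter or digit with '_'."""
    return re.sub(r"[^A-Za-z0-9]", "_", name)


print(sanitize_name_sketch("Acc@5"))       # Acc_5
print(sanitize_name_sketch("top-1 acc"))   # top_1_acc
```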
rst.take_first(iterable, n: int) -> iter
Takes the first n elements of the iterable as a new iterator.
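The described behaviour matches what itertools.islice offers in the standard library. The sketch below (take_first_sketch is a hypothetical name) shows that such a helper stays lazy and therefore also works on endless iterators.

```python
from itertools import count, islice


def take_first_sketch(iterable, n):
    """Lazily yields at most the first n elements of iterable."""
    return islice(iterable, n)


print(list(take_first_sketch(count(), 3)))    # [0, 1, 2]
print(list(take_first_sketch("ab", 5)))       # ['a', 'b']
```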
rst.log1m_exp(arr_x: Tensor) -> Tensor
Computes x -> log(1 - exp(x)) in a numerically stable manner.
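To see why a naive evaluation is fragile, consider a scalar version. The sketch below uses a common two-branch scheme that switches at -log(2); whether the library uses this exact split is an assumption, and log1m_exp_scalar is a hypothetical name.

```python
import math


def log1m_exp_scalar(x: float) -> float:
    """log(1 - exp(x)) for x < 0, switching formulas at -log(2)."""
    if x > -math.log(2.0):
        # exp(x) is close to 1: expm1 avoids cancellation in 1 - exp(x).
        return math.log(-math.expm1(x))
    # exp(x) is small: log1p keeps the tiny correction that 1 - exp(x) would round away.
    return math.log1p(-math.exp(x))


print(log1m_exp_scalar(-1.0))
print(log1m_exp_scalar(-1e-20))  # finite, whereas log(1 - exp(-1e-20)) fails on log(0)
print(log1m_exp_scalar(-50.0))   # slightly below zero, whereas the naive formula gives 0.0
```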
rst.MLP(sizes: List[int], n_group_dims = 0, activation = F.relu, norm = 'batch')
Subclass of torch.nn.Module.
Basic multi-layer perceptron (MLP), also called dense or fully connected (FC) networks. Normalization and activations are configurable.
Note: The output layer is also normalized and activated.
Parameters:
- sizes: sizes of layers, including the input. [n_in, h_1, h_2, ..., n_out].
- n_group_dims: 0 if input is of shape [B, C], 1 if [B, C, N], 2 if [B, C, H, W], 3 if [B, C, D, H, W].
- activation: a Tensor -> Tensor function. Defaults to torch.relu.
- norm: 'batch', 'instance', or None. Normalization layers type.
rst.Lambda(lam: Callable)
Subclass of torch.nn.Module.
Executes a lambda function lam in forward.
Example:
import torch
import torch_redstone as rst
net = torch.nn.Sequential(rst.Lambda(lambda x: x * 2))
print(net(torch.ones(5)))
# tensor([2., 2., 2., 2., 2.])

rst.GetItem(index: int)
Subclass of torch.nn.Module.
Indexes the input to the module and returns the result.
rst.BinaryAcc(th_logits = 0.0, th_label = 0.5)
Subclass of torch.nn.Module.
Computes the accuracy of a binary classification in forward(self, inputs, targets). The classification threshold of the model output is th_logits. The threshold of the label is th_label. Has a .redstone() method, the same as the one described in the Metric section.
rst.CategoricalAcc(dim = 1)
Subclass of torch.nn.Module.
Computes the accuracy of a multi-class classification in forward(self, inputs, targets). The class logits lie along dim (the size of that dimension should be the number of classes). The targets should be integers specifying the class labels rather than a one-hot encoding. Has a .redstone() method, the same as the one described in the Metric section.
rst.Meter()
A meter utility to track the average statistics (loss, accuracy, etc.) during an epoch.
Example:
import torch
import torch_redstone as rst
meter = rst.Meter()
meter.u('acc', 1)
meter.u('acc', 0)
print(meter['acc'])
# 0.5

Polyfills are implemented in torch_redstone/polyfill.py in the class Polyfill to make available new functions in PyTorch for older PyTorch versions.
The available polyfills are:
| PyTorch Function | PyTorch Version |
|---|---|
| torch.cdist | 1.3 |
| torch.square | 1.5 |
| torch.absolute | 1.6 |
| torch.arccos | 1.7 |
| torch.arcsin | 1.7 |
| torch.arctan | 1.7 |
| torch.broadcast_to | 1.8 |
| torch.autocast | 1.10 |
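The general polyfill pattern is to install a replacement only when the running version lacks the function. The sketch below illustrates that pattern on a stand-in namespace; apply_polyfills and old_lib are hypothetical, and the actual Polyfill class in torch_redstone/polyfill.py may be organised differently.

```python
from types import SimpleNamespace


def apply_polyfills(module, polyfills):
    """Adds each function to `module` only if it is missing; returns the names added."""
    added = []
    for name, fn in polyfills.items():
        if not hasattr(module, name):
            setattr(module, name, fn)
            added.append(name)
    return added


# Stand-in for an old library version that has `abs` but lacks `square` and `absolute`.
old_lib = SimpleNamespace(abs=abs)
added = apply_polyfills(old_lib, {
    "square": lambda x: x * x,
    "absolute": old_lib.abs,
    "abs": lambda x: "never installed",  # skipped: the name already exists
})
print(added, old_lib.square(3), old_lib.absolute(-2))  # ['square', 'absolute'] 9 2
```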