After you create a rewritten model using our rewriter, it's good practice to write a unit test that validates whether the model rewrite takes effect. In general, we get the outputs of both the original model and the rewritten model, then compare them. The outputs of the original model can be obtained directly by calling the model's `forward` function. How you generate the outputs of the rewritten model depends on the complexity of the rewrite.
If the changes to the model are small (e.g., they only change the behavior of one or two variables and don't introduce side effects), you can construct the input arguments for the rewritten functions/modules, run the model's inference in `RewriterContext`, and check the results.
```python
# mmpretrain/models/classifiers/base.py
class BaseClassifier(BaseModule, metaclass=ABCMeta):

    def forward(self, img, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(img, **kwargs)
        else:
            return self.forward_test(img, **kwargs)


# Custom rewritten function
from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    'mmpretrain.models.classifiers.BaseClassifier.forward', backend='default')
def forward_of_base_classifier(self, img, *args, **kwargs):
    """Rewrite `forward` for default backend."""
    return self.simple_test(img, {})
```
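In this example, we only change the function that `forward` calls. The rewritten `forward` always calls `simple_test` instead of dispatching to `forward_train` or `forward_test`. We can test this rewritten function by writing the following test function: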
```python
import torch
from mmdeploy.core import RewriterContext


def test_baseclassifier_forward():
    input = torch.rand(1)
    from mmpretrain.models.classifiers import BaseClassifier

    class DummyClassifier(BaseClassifier):

        def __init__(self, init_cfg=None):
            super().__init__(init_cfg=init_cfg)

        def extract_feat(self, imgs):
            pass

        def forward_train(self, imgs):
            return 'train'

        def simple_test(self, img, tmp, **kwargs):
            return 'simple_test'

    model = DummyClassifier().eval()
    # Original model: return_loss defaults to True, so forward_train runs
    model_output = model(input)
    # Rewritten model: the registered rewrite replaces forward in this context
    with RewriterContext(cfg=dict()), torch.no_grad():
        backend_output = model(input)

    assert model_output == 'train'
    assert backend_output == 'simple_test'
```
In this test function, we construct a derived class of `BaseClassifier` to test whether the rewritten model works in the rewrite context. We get the outputs of the original model by calling `model(input)` directly. We get the outputs of the rewritten model by calling `model(input)` inside `RewriterContext`. Finally, we check the outputs by asserting their values.
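The following sketch makes the mechanism concrete without installing MMDeploy or MMPretrain. It is a minimal, dependency-free version of the same test pattern. `ToyRewriter`, `ToyClassifierBase`, and `ToyDummyClassifier` are hypothetical stand-ins written for illustration. They are not MMDeploy's `FUNCTION_REWRITER` or `RewriterContext`, which support more options, such as selecting rewrites per backend. The sketch relies on one assumption only: a registered replacement is patched in while the context is active and restored when the context exits.

```python
import contextlib


class ToyRewriter:
    """Hypothetical stand-in for a function-rewriter registry (not mmdeploy's)."""

    def __init__(self):
        self._registry = {}  # (owner class, attribute name) -> replacement

    def register(self, owner, name):
        """Decorator that records `func` as the replacement for owner.name."""
        def decorator(func):
            self._registry[(owner, name)] = func
            return func
        return decorator

    @contextlib.contextmanager
    def context(self):
        """Patch in every registered replacement; restore originals on exit."""
        originals = {}
        try:
            for (owner, name), func in self._registry.items():
                # Assumes `name` is defined directly on `owner`.
                originals[(owner, name)] = owner.__dict__[name]
                setattr(owner, name, func)
            yield
        finally:
            for (owner, name), original in originals.items():
                setattr(owner, name, original)


REWRITER = ToyRewriter()


class ToyClassifierBase:
    def forward(self, img, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(img, **kwargs)
        return self.forward_test(img, **kwargs)


@REWRITER.register(ToyClassifierBase, 'forward')
def forward_of_toy_classifier(self, img, *args, **kwargs):
    # Rewritten forward: always take the simple_test path.
    return self.simple_test(img, {})


class ToyDummyClassifier(ToyClassifierBase):
    def forward_train(self, img):
        return 'train'

    def forward_test(self, img):
        return 'test'

    def simple_test(self, img, img_metas):
        return 'simple_test'


def test_toy_forward_rewrite():
    model = ToyDummyClassifier()
    model_output = model.forward(1.0)           # original: forward_train
    with REWRITER.context():
        rewrite_output = model.forward(1.0)     # rewritten: simple_test
    assert model_output == 'train'
    assert rewrite_output == 'simple_test'
    assert model.forward(1.0) == 'train'        # rewrite does not leak
```

The final assertion in `test_toy_forward_rewrite` checks that the original `forward` is back once the context exits. That is a cheap extra guarantee when the patching is context-scoped.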
In the first example, the output is generated in Python. Sometimes, however, we make big changes to the original model functions (e.g., eliminating branch statements to generate a correct computation graph). Even if such a rewritten model produces correct outputs in Python, we cannot be sure that it works as expected in the backend. Therefore, we also need to test the rewritten model in the backend.
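Correct Python outputs are not enough, and a tracing-based exporter shows why: it sees a data-dependent branch differently from eager Python. The sketch below is a toy, not PyTorch. The hypothetical `TracedValue` records arithmetic but evaluates comparisons as plain Python booleans. As a result, the `if` runs once at trace time and only the branch actually taken is recorded. This is similar in spirit to how tracing exporters capture a single path through Python control flow.

```python
class TracedValue:
    """Toy traced value (hypothetical): records arithmetic applied to it."""

    def __init__(self, value, ops=()):
        self.value = value
        self.ops = tuple(ops)  # the recorded "computing graph"

    def __add__(self, other):
        return TracedValue(self.value + other, self.ops + (('add', other),))

    def __mul__(self, other):
        return TracedValue(self.value * other, self.ops + (('mul', other),))

    def __gt__(self, other):
        # Comparison yields a plain bool: the branch is decided now and
        # is never recorded in self.ops.
        return self.value > other


def replay(ops, x):
    """Run a recorded op sequence on a new input, as a backend would."""
    for name, arg in ops:
        x = x + arg if name == 'add' else x * arg
    return x


def branchy_forward(x):
    # Data-dependent branch: the path depends on the input value.
    if x > 0:
        return x * 2
    return x + 1


graph = branchy_forward(TracedValue(3)).ops  # "export" with a positive input
print(graph)                 # (('mul', 2),): only one branch was recorded
print(branchy_forward(-5))   # -4 in eager Python
print(replay(graph, -5))     # -10 from the recorded graph
```

The eager Python run looks correct, yet the recorded graph gives the wrong answer for negative inputs. A backend test catches this kind of discrepancy; a Python-only test cannot.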
```python
import torch

from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import is_dynamic_shape


# Custom rewritten function
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmseg.models.segmentors.BaseSegmentor.forward')
def base_segmentor__forward(self, img, img_metas=None, **kwargs):
    ctx = FUNCTION_REWRITER.get_context()
    if img_metas is None:
        img_metas = {}
    assert isinstance(img_metas, dict)
    assert isinstance(img, torch.Tensor)

    deploy_cfg = ctx.cfg
    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
    # Keep only the spatial dims (H, W); with static shapes they are
    # converted to plain Python ints.
    img_shape = img.shape[2:]
    if not is_dynamic_flag:
        img_shape = [int(val) for val in img_shape]
    img_metas['img_shape'] = img_shape
    return self.simple_test(img, img_metas, **kwargs)
```
The behavior of this rewritten function is complex, so we should also test it in the backend, as follows:
```python
import torch


def test_basesegmentor_forward():
    from mmdeploy.utils.test import (WrapModel, get_model_outputs,
                                     get_rewrite_outputs)

    segmentor = get_model()
    segmentor.cpu().eval()

    # Prepare data
    # ...

    # Get the outputs of original model
    model_inputs = {
        'img': [imgs],
        'img_metas': [img_metas],
        'return_loss': False
    }
    model_outputs = get_model_outputs(segmentor, 'forward', model_inputs)

    # Get the outputs of rewritten model
    wrapped_model = WrapModel(
        segmentor, 'forward', img_metas=None, return_loss=False)
    rewrite_inputs = {'img': imgs}
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)
    if is_backend_output:
        # If the backend plugins have been installed, the rewrite outputs are
        # generated by backend.
        rewrite_outputs = torch.tensor(rewrite_outputs)
        model_outputs = torch.tensor(model_outputs)
        model_outputs = model_outputs.unsqueeze(0).unsqueeze(0)
        assert torch.allclose(rewrite_outputs, model_outputs)
    else:
        # Otherwise, the outputs are generated by python.
        assert rewrite_outputs is not None
```
We provide some utilities to test rewritten functions. First, you can construct a model and call `get_model_outputs` to get the outputs of the original model. Then you can wrap the rewritten function with `WrapModel`, which serves as a partial function, and get the results with `get_rewrite_outputs`. `get_rewrite_outputs` returns two values: the outputs, and a flag indicating whether they come from the backend. We cannot assume that everyone has installed the backend, so we should check whether the results were generated by Python or by a backend engine. The unit test must cover both conditions. Finally, we compare the original and rewritten outputs, which can be done simply by calling `torch.allclose`.
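The branching logic of that test can be sketched without PyTorch or MMDeploy. The helper names here (`flatten`, `all_close`, `check_rewrite_outputs`, `toy_forward`) are hypothetical. `all_close` is a pure-Python analogue of `torch.allclose`. It uses the same tolerance rule, `|a - b| <= atol + rtol * |b|`, with the same default `rtol`/`atol`. Unlike `torch.allclose`, it compares flattened nested lists of equal length instead of broadcasting tensors. `functools.partial` plays the role the text attributes to `WrapModel`: it binds the non-tensor arguments once, so only `img` varies between calls.

```python
import functools


def flatten(values):
    """Flatten arbitrarily nested lists/tuples of numbers into one list."""
    if isinstance(values, (list, tuple)):
        return [v for item in values for v in flatten(item)]
    return [values]


def all_close(a, b, rtol=1e-5, atol=1e-8):
    """Pure-Python analogue of torch.allclose (no broadcasting)."""
    fa, fb = flatten(a), flatten(b)
    return len(fa) == len(fb) and all(
        abs(x - y) <= atol + rtol * abs(y) for x, y in zip(fa, fb))


def check_rewrite_outputs(rewrite_outputs, is_backend_output, model_outputs):
    """Hypothetical helper covering both outcomes of a rewrite test."""
    if is_backend_output:
        # Backend was available: compare numerically with the original model.
        assert all_close(rewrite_outputs, model_outputs), 'outputs differ'
    else:
        # Outputs came from Python: at least check that something was produced.
        assert rewrite_outputs is not None


def toy_forward(img, img_metas=None, return_loss=True):
    """Hypothetical forward; only `img` should vary between test calls."""
    assert not return_loss, 'expected the inference path'
    return [[2.0 * v for v in row] for row in img]


# Bind the non-tensor arguments once, as the text describes for WrapModel.
wrapped_forward = functools.partial(
    toy_forward, img_metas=None, return_loss=False)

img = [[0.5, 1.0], [1.5, 2.0]]
model_outputs = toy_forward(img, return_loss=False)
check_rewrite_outputs(wrapped_forward(img), True, model_outputs)   # backend case
check_rewrite_outputs(wrapped_forward(img), False, model_outputs)  # Python case
```

Keeping both branches in a single helper means a test run on a machine without the backend still exercises the rewrite. A machine with the backend installed additionally gets the numeric comparison.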
For the complete usage of the test utilities, please refer to our API documentation.