How to set different patch size between training stage and inference stage #2639

Open
SwimmingLiu opened this issue Dec 8, 2024 · 0 comments

Hi, thanks for your wonderful work!
I want to use a different patch size for the training stage and the inference stage. How can I do this?
For example, here is the 2d configuration generated by nnUNetv2_preprocess:

 "configurations": {
        "2d": {
            "data_identifier": "nnUNetPlans_2d",
            "preprocessor_name": "DefaultPreprocessor",
            "batch_size": 2,
            "patch_size": [
                1024,
                1280
            ],

After training, I had folds 0 to 4, with no errors.
Then I changed the patch size in the plans to make it larger:

 "configurations": {
        "2d": {
            "data_identifier": "nnUNetPlans_2d",
            "preprocessor_name": "DefaultPreprocessor",
            "batch_size": 2,
            "patch_size": [
                1080,
                1280
            ],

Finally, I evaluated the validation set of each fold with this command:

 nnUNetv2_train 918 2d 0  --val --npz

But validation failed with the following error:

2024-12-08 12:19:38.169974: 01.0000000111019.20012.0019.09345900518, shape torch.Size([3, 1, 1080, 1265]), rank 0
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0] failed while attempting to run meta for aten.cat.default
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0] Traceback (most recent call last):
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0]   File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2013, in _dispatch_impl
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0]     r = func(*args, **kwargs)
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0]         ^^^^^^^^^^^^^^^^^^^^^
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0]   File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_ops.py", line 716, in __call__
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0]     return self._op(*args, **kwargs)
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0]   File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_prims_common/wrappers.py", line 273, in _fn
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0]     result = fn(*args, **kwargs)
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0]              ^^^^^^^^^^^^^^^^^^^
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0]   File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_prims_common/wrappers.py", line 141, in _fn
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0]     result = fn(**bound.arguments)
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0]              ^^^^^^^^^^^^^^^^^^^^^
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0]   File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_refs/__init__.py", line 2832, in cat
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0]     return prims.cat(filtered, dim).clone(memory_format=memory_format)
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^
E1208 12:19:43.123000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/0]   File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_ops.py", line 716, in __call__
.............................................................................
E1208 12:19:47.963000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/1]     raise error_type(message_evaluated)
E1208 12:19:47.963000 202746 site-packages/torch/_subclasses/fake_tensor.py:2017] [0/1] RuntimeError: Sizes of tensors must match except in dimension 1. Expected 10 but got 9 for tensor number 1 in the list
Traceback (most recent call last):
  File "/opt/miniconda3/envs/nnUNet/bin/nnUNetv2_train", line 8, in <module>
    sys.exit(run_training_entry())
             ^^^^^^^^^^^^^^^^^^^^
  File "/root/nnUNet/nnunetv2/run/run_training.py", line 275, in run_training_entry
    run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,
  File "/root/nnUNet/nnunetv2/run/run_training.py", line 215, in run_training
    nnunet_trainer.perform_actual_validation(export_validation_probabilities)
  File "/root/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py", line 1286, in perform_actual_validation
    prediction = predictor.predict_sliding_window_return_logits(data)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/nnUNet/nnunetv2/inference/predict_from_raw_data.py", line 651, in predict_sliding_window_return_logits
    predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/nnUNet/nnunetv2/inference/predict_from_raw_data.py", line 609, in _internal_predict_sliding_window_return_logits
    raise e
  File "/root/nnUNet/nnunetv2/inference/predict_from_raw_data.py", line 592, in _internal_predict_sliding_window_return_logits
    prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)
   ...............................
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2017, in get_fake_value
    ret_val = wrap_fake_exception(
              ^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1574, in wrap_fake_exception
    return fn()
           ^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2018, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2150, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2132, in run_node
    return node.target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1238, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1692, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1339, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2013, in _dispatch_impl
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_ops.py", line 716, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_prims_common/wrappers.py", line 273, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_prims_common/wrappers.py", line 141, in _fn
    result = fn(**bound.arguments)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_refs/__init__.py", line 2832, in cat
    return prims.cat(filtered, dim).clone(memory_format=memory_format)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_ops.py", line 716, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_library/fake_impl.py", line 93, in meta_kernel
    return fake_impl_holder.kernel(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_library/utils.py", line 20, in __call__
    return self.func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/library.py", line 1151, in inner
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_library/custom_ops.py", line 614, in fake_impl
    return self._abstract_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/_prims/__init__.py", line 1963, in _cat_meta
    torch._check(
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/__init__.py", line 1564, in _check
    _check_with(RuntimeError, cond, message)
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/torch/__init__.py", line 1546, in _check_with
    raise error_type(message_evaluated)
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method cat of type object at 0x7f27acf27240>(*((FakeTensor(..., device='cuda:0', size=(1, 512, 10, 10), dtype=torch.float16), FakeTensor(..., device='cuda:0', size=(1, 512, 9, 10), dtype=torch.float16)), 1), **{}):
Sizes of tensors must match except in dimension 1. Expected 10 but got 9 for tensor number 1 in the list

from user code:
   File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/dynamic_network_architectures/architectures/unet.py", line 62, in forward
    return self.decoder(skips)
  File "/opt/miniconda3/envs/nnUNet/lib/python3.11/site-packages/dynamic_network_architectures/building_blocks/unet_decoder.py", line 110, in forward
    x = torch.cat((x, skips[-(s+2)]), 1)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
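
A possible reading of the shapes in the error, as a rough sketch rather than a confirmed diagnosis: the decoder concatenates an upsampled tensor of height 10 with a skip tensor of height 9. That is what one would expect if each encoder stage halves the size with rounding up and each decoder stage doubles it exactly. The helper names below and the value of 8 halvings are assumptions for illustration; the 8 is only inferred from the traceback (width 1280 shows up as 10 at the skip, so the bottleneck would be 5 = 1280 / 256), not read from the plans file.

```python
import math

# ASSUMPTION: 8 stride-2 stages per axis, inferred from the tensor shapes in the
# traceback above, not taken from the actual plans file.
NUM_HALVINGS = 8


def downsampled_sizes(size, num_halvings):
    """Sizes along one axis after each stride-2 stage.

    Assumes kernel 3, padding 1, stride 2, which gives ceil(n / 2) per stage.
    """
    sizes = [size]
    for _ in range(num_halvings):
        sizes.append(math.ceil(sizes[-1] / 2))
    return sizes


def first_skip_mismatch(size, num_halvings):
    """Return (upsampled, skip) at the deepest decoder stage where they differ.

    Returns None if every 2x upsampling matches its skip connection.
    """
    sizes = downsampled_sizes(size, num_halvings)
    for level in range(num_halvings, 0, -1):
        upsampled = sizes[level] * 2  # assumed 2x transposed convolution
        skip = sizes[level - 1]
        if upsampled != skip:
            return upsampled, skip
    return None


def nearest_valid_sizes(size, num_halvings):
    """Closest sizes at or below and at or above `size` divisible by 2**num_halvings."""
    factor = 2 ** num_halvings
    lower = (size // factor) * factor
    upper = math.ceil(size / factor) * factor
    return lower, upper


for s in (1024, 1080, 1280):
    print(s, first_skip_mismatch(s, NUM_HALVINGS), nearest_valid_sizes(s, NUM_HALVINGS))
```

Under these assumptions, 1024 and 1280 pass through cleanly, while 1080 produces the same 10 versus 9 mismatch reported by torch.cat, and the nearest sizes that would avoid it along that axis are 1024 and 1280. Whether nnU-Net is meant to support editing patch_size in the plans after training is a separate question that this sketch does not answer.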