Dice loss module causes TypeError: forward() got an unexpected keyword argument 'ignore_index' in v1.1.0 #3172

@ff98li

Description

Checklist

  1. I have searched related issues but cannot get the expected help. ☑️
  2. The bug has not been fixed in the latest version. ☑️

Describe the bug
Attempted to train SegFormer using Dice loss + CE loss; the following error occurred after updating to mmsegmentation v1.1.0:

Traceback (most recent call last):
  File "/home/fei/pkgs/mmsegmentation/tools/train.py", line 104, in <module>
    main()
  File "/home/fei/pkgs/mmsegmentation/tools/train.py", line 100, in main
    runner.train()
  File "/home/fei/software/mambaforge/envs/openmmlab/lib/python3.8/site-packages/mmengine/runner/runner.py", line 1701, in train
    model = self.train_loop.run()  # type: ignore
  File "/home/fei/software/mambaforge/envs/openmmlab/lib/python3.8/site-packages/mmengine/runner/loops.py", line 278, in run
    self.run_iter(data_batch)
  File "/home/fei/software/mambaforge/envs/openmmlab/lib/python3.8/site-packages/mmengine/runner/loops.py", line 301, in run_iter
    outputs = self.runner.model.train_step(
  File "/home/fei/software/mambaforge/envs/openmmlab/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 114, in train_step
    losses = self._run_forward(data, mode='loss')  # type: ignore
  File "/home/fei/software/mambaforge/envs/openmmlab/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 326, in _run_forward
    results = self(**data, mode=mode)
  File "/home/fei/software/mambaforge/envs/openmmlab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fei/pkgs/mmsegmentation/mmseg/models/segmentors/base.py", line 94, in forward
    return self.loss(inputs, data_samples)
  File "/home/fei/pkgs/mmsegmentation/mmseg/models/segmentors/encoder_decoder.py", line 176, in loss
    loss_decode = self._decode_head_forward_train(x, data_samples)
  File "/home/fei/pkgs/mmsegmentation/mmseg/models/segmentors/encoder_decoder.py", line 137, in _decode_head_forward_train
    loss_decode = self.decode_head.loss(inputs, data_samples,
  File "/home/fei/pkgs/mmsegmentation/mmseg/models/decode_heads/decode_head.py", line 262, in loss
    losses = self.loss_by_feat(seg_logits, batch_data_samples)
  File "/home/fei/pkgs/mmsegmentation/mmseg/models/decode_heads/decode_head.py", line 324, in loss_by_feat
    loss[loss_decode.loss_name] = loss_decode(
  File "/home/fei/software/mambaforge/envs/openmmlab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'ignore_index'

Reproduction

  1. What command or script did you run?

Added Dice loss to my config for SegFormer as below:

model = dict(
    # Don't use data_preprocessor in predefined config
    data_preprocessor=data_preprocessor,
    pretrained=None,
    decode_head=dict(
        num_classes=13,
        loss_decode=[
            dict(
                type='CrossEntropyLoss',
                loss_name='loss_ce',
                loss_weight=2.0
            ),
            dict(
                type='DiceLoss',
                loss_name='loss_dice',
                loss_weight=1.0
            )
        ]
    )
)

I believe the Dice loss is the direct cause: the error no longer occurs once the DiceLoss entry is removed from loss_decode.
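To illustrate what the traceback seems to point at, here is a minimal sketch in plain Python. Every class and function in it is a hypothetical stand-in, not mmsegmentation's actual code, and the call arguments other than `ignore_index` (including the value 255) are assumptions. It shows that a loss whose `forward()` does not declare `ignore_index` raises this kind of TypeError when a caller passes it as a keyword, while a `forward()` that accepts `**kwargs` does not.

```python
import inspect


class LossWithoutIgnoreIndex:
    """Hypothetical loss whose forward() has no `ignore_index` parameter."""

    def forward(self, pred, target, weight=None):
        return 0.0

    def __call__(self, *args, **kwargs):
        # Stand-in for torch.nn.Module, which also passes all
        # positional and keyword arguments through to forward().
        return self.forward(*args, **kwargs)


class LossToleratingExtraKwargs(LossWithoutIgnoreIndex):
    """Same hypothetical loss, but forward() swallows unused keywords."""

    def forward(self, pred, target, weight=None, **kwargs):
        return 0.0


def call_like_decode_head(loss_fn):
    # Assumption: the head passes ignore_index as a keyword argument,
    # as the error message implies. 255 is an arbitrary placeholder.
    return loss_fn([0.1, 0.9], [0, 1], weight=None, ignore_index=255)


def accepts_keyword(func, name):
    """Return True if `func` can be called with `name` as a keyword."""
    params = inspect.signature(func).parameters
    param = params.get(name)
    if param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


def reproduce():
    """Return the TypeError message, or None if no error is raised."""
    try:
        call_like_decode_head(LossWithoutIgnoreIndex())
    except TypeError as err:
        return str(err)
    return None
```

Calling `reproduce()` returns a message ending in `got an unexpected keyword argument 'ignore_index'`, matching the last line of the traceback. `accepts_keyword` could presumably be pointed at the `forward` method of the installed loss class to check whether its signature accepts `ignore_index`.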

  2. Did you make any modifications on the code or config? Did you understand what you have modified?

Yes, as described above. It used to work with dev-1.* before I updated mmseg to the main branch version.

  3. What dataset did you use?

🤐 Custom datasets, but used to work with dev-1.*

Environment

sys.platform: linux
Python: 3.8.16 | packaged by conda-forge | (default, Feb  1 2023, 16:01:55) [GCC 11.3.0]
CUDA available: True
numpy_random_seed: 2147483648
GPU 0,1: NVIDIA RTX A6000
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 11.6, V11.6.124
GCC: gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
PyTorch: 2.0.0
PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.1-Product Build 20220311 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.7.3 (Git Hash 6dbeffbae1f23cbbeae17adb7b5b13f1f37c080e)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.7
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37
  - CuDNN 8.5
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.7, CUDNN_VERSION=8.5.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.0.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, 

TorchVision: 0.15.0
OpenCV: 4.7.0
MMEngine: 0.7.0
MMSegmentation: 1.1.0+0079076
