Skip to content

Conversation

invoker-bot
Copy link
Contributor

@invoker-bot invoker-bot commented Sep 27, 2023

Fixes #3055

Description:

Now we can define our custom model_fn in create_supervised_trainer and create_supervised_evaluator.

Check list:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

@github-actions github-actions bot added the module: engine Engine module label Sep 27, 2023
@vfdev-5
Copy link
Collaborator

vfdev-5 commented Sep 27, 2023

@invoker-bot thanks for the PR, please add also a test for this feature into https://github.com/pytorch/ignite/blob/master/tests/ignite/engine/test_create_supervised.py

@invoker-bot
Copy link
Contributor Author

@invoker-bot thanks for the PR, please add also a test for this feature into https://github.com/pytorch/ignite/blob/master/tests/ignite/engine/test_create_supervised.py

I have made these changes, please check it.

Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update @invoker-bot
Few improvements to add and it can be good to be merged


loss[0] = mse_loss(_y_pred, _y).item()

# loss[0] = mse_loss(model(_x), _y).item()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove commented code

theta[0] -= accumulation[0] / gradient_accumulation_steps
assert pytest.approx(model.fc.weight.data[0, 0].item(), abs=1.0e-5) == theta[0]
assert pytest.approx(trainer.state.output[-1], abs=1e-5) == loss[0]
print("loss:", loss[0], "theta:", theta[0])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove this print

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Sep 29, 2023

@invoker-bot please run code style formatting script to fix CI issues:

bash ./tests/run_code_style.sh install
bash ./tests/run_code_style.sh fmt

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Oct 3, 2023

@invoker-bot can you please address the comment such that the PR can be merged and will be included to the next release?

@invoker-bot
Copy link
Contributor Author

@invoker-bot can you please address the comment such that the PR can be merged and will be included to the next release?

I have fixed this issue now, please check it.

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Oct 3, 2023

Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @invoker-bot

@vfdev-5 vfdev-5 merged commit b8751f2 into pytorch:master Oct 3, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: engine Engine module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add a feature "support multi params to call forward method"?
2 participants