Skip to content

non-integer condition allowed for image generation with Unet #164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

psteinb
Copy link

@psteinb psteinb commented May 8, 2025

  • added more tests for running forward functions of models
  • test which expands Unet to take floating point conditions (or anything else)
  • edited Unet to accept custom embedding net (this allows for non-integer labels to condition the generation)
  • added demo notebook to demonstrate behavior

What does this PR do?

Fixes #163

No breaking changes, all tests pass.

Before submitting

  • Did you make sure title is self-explanatory and the description concisely explains the PR?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you test your PR locally with pytest command?
  • Did you run pre-commit hooks with pre-commit run -a command?

Did you have fun?

Make sure you had fun coding 🙃

Summary by Sourcery

Enhance Unet model to support non-integer and custom embedding conditions for image generation

New Features:

  • Allow Unet to accept custom embedding networks for conditioning
  • Support non-integer labels for model conditioning

Bug Fixes:

  • Fix label embedding for non-class conditional networks

Enhancements:

  • Modify Unet forward method to be more flexible with label embeddings
  • Remove strict type checking for conditional labels

Tests:

  • Add test for Unet initialization
  • Add test for MLP model
  • Add test for conditional model with non-integer labels

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link

sourcery-ai bot commented May 8, 2025

Reviewer's Guide

This pull request enables UNetModel to use non-integer conditional inputs by introducing a customizable embedding_net. This network processes the conditional input y in the model's forward pass, removing the previous restriction that tied y to integer-based num_classes. The changes are validated by new tests, including a specific test for float conditions using a mock embedding layer.

File-Level Changes

Change Details Files
Introduced customizable conditional input embedding in UNetModel.
  • Added embedding_net parameter to UNetModel and UNetModelWrapper for user-defined condition embedding.
  • Modified UNetModel.forward to process conditional input y through self.label_emb (instantiated from embedding_net), allowing diverse condition types.
  • Removed an assertion that previously coupled the presence of conditional input y exclusively with num_classes.
torchcfm/models/unet/unet.py
Expanded test coverage for model conditioning and forward pass integrity.
  • Added test_conditional_model_without_integer_labels using a mock_embedding class to verify float-based conditioning.
  • Strengthened model testing with new/updated tests for UNetModel and MLP forward operations and initialization.
tests/test_models.py

Assessment against linked issues

Issue Objective Addressed Explanation
#163 Enable the UnetModel to accept non-integer (e.g., floating point) conditions for image generation.
#163 Allow users to provide custom embedding networks for conditioning the generation process.
#163 Provide a demonstration of the new functionality with a notebook.

Possibly linked issues


Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

Copy link

@sourcery-ai sourcery-ai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We encountered an error and are unable to review this PR. We have been notified and are working to fix it.

You can try again by commenting this pull request with @sourcery-ai review, or contact us for help.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Non-integer condition to generate with
1 participant