
Refactor InferenceWrapper and subclasses for better readability and gradient handling #108

@thomas-mullor

Description

The InferenceWrapper class and its subclasses would benefit from a refactoring aimed at improving readability and simplifying gradient computation over large input batches. The refactoring should also provide batched gradient computation on large inputs, which is currently not well handled.

Goals:
Refactor the simple call_model method to directly handle a wide range of options.
Other externally called methods should delegate to call_model directly instead of following the current call chain.

TODO:

  1. Transform get_logits into a batch_call_model method that only takes iterables (generators) as input:
    Input parameters:
  • input: (unchanged)
  • target: Optional. If None, return full input logits; otherwise, return targeted_logits.
  • return_grad: Boolean, default False. If True, perform backward pass; otherwise, use a no_grad context.
  • ...additional options as needed
    Responsibilities:
  • Handle context management for different tasks (classification, generation, etc.)
  • Manage input slicing and batching internally instead of in get_logits
    The version of get_logits that accepts a mapping as input is simply removed; this kind of usage will no longer be supported.
    Code division: the current implementation should be split into several specific, clearly named functions, so that the batching management reads more like pseudocode.
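The structure proposed above could look roughly like the following sketch. All names here (`_slice_batches`, `_forward_one_batch`) are hypothetical placeholders, not existing APIs; the model call is stubbed out, and in the real implementation a `torch.no_grad()` / `torch.enable_grad()` context would be selected based on `return_grad`.

```python
from itertools import islice
from typing import Iterable, Iterator, Optional


def _slice_batches(inputs: Iterable, batch_size: int) -> Iterator[list]:
    """Yield successive batches from any iterable, including generators."""
    iterator = iter(inputs)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


def batch_call_model(
    inputs: Iterable,
    targets: Optional[Iterable] = None,
    return_grad: bool = False,
    batch_size: int = 4,
) -> Iterator[dict]:
    """Single entry point: slice inputs into batches, call the model on
    each batch, and yield per-batch results."""
    # Targets, when given, are sliced in lockstep with the inputs
    # (the sketch assumes both iterables have the same length).
    target_batches = (
        _slice_batches(targets, batch_size) if targets is not None else None
    )
    for batch in _slice_batches(inputs, batch_size):
        target_batch = next(target_batches) if target_batches is not None else None
        # Real implementation: wrap this call in a no_grad()/enable_grad()
        # context depending on return_grad, and run backward() if needed.
        yield _forward_one_batch(batch, target_batch, return_grad)


def _forward_one_batch(batch, target_batch, return_grad):
    # Stub standing in for the actual model forward (and backward) pass.
    return {
        "n_inputs": len(batch),
        "targeted": target_batch is not None,
        "grad": return_grad,
    }
```

The point of the sketch is the shape: batching lives in one small generator helper, and `batch_call_model` reads like pseudocode as requested above.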
  2. Update other methods to delegate to batch_call_model:
  • get_logits: simple call to batch_call_model with target=None
  • get_targeted_logits: call batch_call_model with the appropriate target
  • get_gradients: call batch_call_model with return_grad=True
  • get_inputs_to_explain_and_targets: for now, this method, which is specific to the generation context, can be defined on its own, without handling large batches. We assume users will not want to run generation tasks on a large number of sentences.
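The delegation described above reduces the public methods to thin wrappers. In this hedged sketch, batch_call_model is a stub that merely records its arguments so the delegation is visible; the real version would be the batched model call.

```python
from typing import Iterable, Optional


class InferenceWrapper:
    def batch_call_model(
        self,
        inputs: Iterable,
        targets: Optional[Iterable] = None,
        return_grad: bool = False,
    ) -> dict:
        # Stub standing in for the real batched model call; it just
        # echoes the options it received.
        return {"targets": targets, "return_grad": return_grad}

    def get_logits(self, inputs: Iterable) -> dict:
        # Full input logits: no target, no gradient.
        return self.batch_call_model(inputs, targets=None)

    def get_targeted_logits(self, inputs: Iterable, targets: Iterable) -> dict:
        # Targeted logits: same entry point, with targets supplied.
        return self.batch_call_model(inputs, targets=targets)

    def get_gradients(self, inputs: Iterable, targets: Iterable) -> dict:
        # Gradients: same entry point, backward pass enabled.
        return self.batch_call_model(inputs, targets=targets, return_grad=True)
```

Each public method then differs only in which options it passes, which is exactly the "flexible entry point" outcome stated below.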
  3. Code simplification:
    Remove input-based dispatching; only support iterable inputs
    Centralize sanity checks and input transformations (reshape_inputs, process_inputs, etc.)

Expected Outcome:
A cleaner, more maintainable InferenceWrapper with a flexible entry point (call_model) for all types of input processing, task handling, and gradient computation.

Collaboration & Feedback:
This issue is open for proposals. Developers and users are welcome to suggest improvements or related ideas in the comments. Contributions and discussions are encouraged to shape the best possible design.

Metadata


Labels

refactoring: Needs some redefinition of code structure, no new feature
