Description:
The `InferenceWrapper` class and its subclasses would benefit from a refactoring aimed at improving readability and simplifying gradient computation over large input batches, which is currently not handled well.
Goals:
Refactor the simple `call_model` method to directly handle a wide range of options.
Other externally called methods should delegate to `call_model` directly instead of following the current call chain.
TODO:
- Transform `get_logits` into a `batch_call_model` method that only takes iterables (generators) as input.
  Input parameters:
  - `input`: (unchanged)
  - `target`: Optional. If `None`, return the full input logits; otherwise, return `targeted_logits`.
  - `return_grad`: Boolean, default `False`. If `True`, perform a backward pass; otherwise, run under a `no_grad` context.
  - ...additional options as needed
  Responsibilities:
  - Handle context management for different tasks (classification, generation, etc.)
  - Manage input slicing and batching internally instead of in `get_logits`
  The version of `get_logits` that takes a mapping as input is simply removed; this kind of usage will no longer be supported.
  Code division: the current implementation should be split into several small, clearly named functions, so that the batching logic reads almost like pseudocode.
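The requirements above can be sketched roughly as follows. This is a minimal illustration, not the library's actual API: the class name `InferenceWrapperSketch`, the `batch_size` attribute, and the internal `_batched` helper are all hypothetical, and the model is assumed to be a `torch.nn.Module` returning logits.

```python
import torch


class InferenceWrapperSketch:
    """Hypothetical sketch of a batched entry point (not the real API)."""

    def __init__(self, model: torch.nn.Module, batch_size: int = 4):
        self.model = model
        self.batch_size = batch_size

    def batch_call_model(self, inputs, target=None, return_grad=False):
        """Consume any iterable (incl. generators), yielding one result per batch.

        target=None     -> yield full logits
        target=<index>  -> yield targeted logits
        return_grad     -> backward pass and yield input gradients
        """
        for batch in self._batched(inputs):
            x = torch.stack(batch)
            if return_grad:
                # Gradient path: track gradients w.r.t. the inputs.
                x.requires_grad_(True)
                logits = self.model(x)
                selected = logits if target is None else logits[:, target]
                selected.sum().backward()
                yield x.grad
            else:
                # Inference-only path: cheap no_grad context.
                with torch.no_grad():
                    logits = self.model(x)
                    yield logits if target is None else logits[:, target]

    def _batched(self, inputs):
        """Slice any iterable into fixed-size chunks (batching lives here)."""
        batch = []
        for item in inputs:
            batch.append(item)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch:
            yield batch
```

Keeping `_batched` as its own generator is one way to make the batching management read like pseudocode, as suggested above.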
- Update other methods to delegate to `batch_call_model`:
  - `get_logits`: simple call to `batch_call_model` with `target=None`
  - `get_targeted_logits`: call `batch_call_model` with the appropriate target
  - `get_gradients`: call `batch_call_model` with `return_grad=True`
  - `get_inputs_to_explain_and_targets`: for now, this generation-specific method can be defined on its own, without handling large batches. We assume users will not want to run generation tasks on a large number of sentences.
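The delegation pattern described in the list above could look like this. Only the wiring is shown; `batch_call_model` is stubbed to record its arguments, and the class name `DelegatingWrapper` is illustrative.

```python
class DelegatingWrapper:
    """Sketch of the proposed delegation only; batch_call_model is a stub."""

    def batch_call_model(self, inputs, target=None, return_grad=False):
        # A real implementation would batch, run the model, and optionally
        # backpropagate; here we just echo the options for illustration.
        return {"target": target, "return_grad": return_grad}

    def get_logits(self, inputs):
        # Full logits: no target, no gradients.
        return self.batch_call_model(inputs, target=None)

    def get_targeted_logits(self, inputs, targets):
        # Targeted logits: forward the appropriate target.
        return self.batch_call_model(inputs, target=targets)

    def get_gradients(self, inputs, targets=None):
        # Gradients: same entry point, with the backward pass enabled.
        return self.batch_call_model(inputs, target=targets, return_grad=True)
```

With this shape, every public method becomes a one-line call into the single entry point, which is the maintainability goal stated above.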
- Code simplification:
  - Remove input-based dispatching; only support iterable inputs
  - Centralize sanity checks and input transformations (`reshape_inputs`, `process_inputs`, etc.)
Expected Outcome:
A cleaner, more maintainable `InferenceWrapper` with a flexible entry point (`call_model`) for all types of input processing, task handling, and gradient computation.
Collaboration & Feedback:
This issue is open for proposals. Developers and users are welcome to suggest improvements or related ideas in the comments. Contributions and discussions are encouraged to shape the best possible design.