Variable series length support for foundation models#3125
…gth inputs and pre-pad smaller inputs during inference for foundation models
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #3125      +/-   ##
==========================================
- Coverage   96.54%   96.15%   -0.39%
==========================================
  Files         160      160
  Lines       17261    17361     +100
==========================================
+ Hits        16664    16693      +29
- Misses        597      668      +71
Hi @Kurokabe, thank you for inviting me to review this PR. Adding variable-length support to foundation models would be a great feature, and I truly appreciate your efforts in exploring the design and implementation. To be completely honest, I found it a bit difficult to fully evaluate the PR at its current stage, mainly for two reasons:
What do you think?
Hi @dennisbader and @daidahao,
Here is my draft PR to support variable-length fine-tuning and inference on foundation models.
The main changes are:
- `VariableLengthTorchTrainingDataset` (new class in `training_dataset.py`): a `TorchTrainingDataset` subclass that accepts series shorter than input_chunk_length by left-padding the past window with NaN. This allows `fit_from_dataset()` to handle heterogeneous datasets without requiring per-window input_chunk_length tuning or silently dropping short series. Covariates and sample weights are intentionally not supported for now.
- `FoundationModel._build_inference_dataset` override (new method in `foundation_model.py`): transparently left-pads short series with NaN before passing them to `SequentialTorchInferenceDataset`, so that `predict()` works on short series without any manual pre-processing from callers. The padding logic mirrors what `VariableLengthTorchTrainingDataset` does during training.

Note that for now, only inference has been tested end-to-end.
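For readers skimming the thread, the left-padding idea described above can be sketched roughly as follows. This is a minimal NumPy illustration, not the code from the PR: `left_pad_with_nan` is a hypothetical helper name, and the `(time, components)` array layout is an assumption.

```python
import numpy as np


def left_pad_with_nan(values, input_chunk_length):
    """Left-pad a (time, components) array with NaN rows up to input_chunk_length.

    Hypothetical helper for illustration only; not taken from the PR.
    """
    values = np.asarray(values, dtype=float)
    if values.ndim == 1:
        values = values[:, None]  # treat a 1-D input as a single component
    n_missing = input_chunk_length - values.shape[0]
    if n_missing <= 0:
        # Already long enough: keep only the most recent window.
        return values[-input_chunk_length:]
    padding = np.full((n_missing, values.shape[1]), np.nan)
    # Padding goes in front so the observed values stay aligned with the forecast start.
    return np.concatenate([padding, values], axis=0)


short = np.array([1.0, 2.0, 3.0])
padded = left_pad_with_nan(short, input_chunk_length=5)
```

Here `padded` has five rows: two NaN rows followed by the three observed values. Padding on the left keeps the most recent observations adjacent to the prediction point, which is presumably why both the training dataset and the inference override pad that side.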
`dev_fev_tasks_mini_validation.ipynb` and `fev_tasks_mini.yaml` are development artifacts I've included in case you want to reproduce the validation runs; they will be removed before merging.

One thing I can't fully explain: the notebook compares three approaches.
- Step 1 uses an adaptive input_chunk_length, processed window-by-window.
- Step 2 uses a fixed input_chunk_length=32 with manual NaN pre-padding before fit(), processed window-by-window.
- Step 3 uses VariableLengthTorchTrainingDataset with the same fixed input_chunk_length=32 in a single pass over all series.

Steps 2 and 3 match each other exactly, which validates that VariableLengthTorchTrainingDataset is equivalent to manual pre-padding. But I can't explain why step 1 produces different results than steps 2 and 3. Since the only difference is the input_chunk_length value used per window, it's likely a context-length effect rather than a batching artefact, but I'm not certain. Do you have any insight on this?
One idea I had for a potential follow-up: instead of a dedicated VariableLengthTorchTrainingDataset, we could relax the short-series validation in ShiftedTorchTrainingDataset (currently a hard error in _get_end_of_output_idx) and handle the NaN padding in a `collate_fn` passed to the DataLoader.

Let me know what you think :)
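To make the `collate_fn` follow-up idea concrete, here is a rough sketch of what batch-level padding could look like. It is only an illustration under assumptions: `pad_collate` is a hypothetical name, each sample is assumed to be a single `(time, components)` NumPy array for the past window, and the real datasets return richer tuples than this.

```python
import numpy as np


def pad_collate(batch, input_chunk_length):
    """Stack variable-length past windows into one NaN-left-padded array.

    Hypothetical sketch; assumes every sample is a (time, components) array
    with the same number of components.
    """
    n_components = batch[0].shape[1]
    out = np.full((len(batch), input_chunk_length, n_components), np.nan)
    for i, window in enumerate(batch):
        window = window[-input_chunk_length:]  # truncate windows that are too long
        # Write the observed values at the right edge; the left edge stays NaN.
        out[i, input_chunk_length - len(window):] = window
    return out


batch = [np.ones((2, 1)), 2.0 * np.ones((4, 1))]
stacked = pad_collate(batch, input_chunk_length=4)
```

With a PyTorch `DataLoader`, a function like this could be bound with `functools.partial(pad_collate, input_chunk_length=32)` and passed as `collate_fn`, converting the result with `torch.from_numpy` before it reaches the model. Whether that fits the existing dataset classes is the open question raised above.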