[BugFix] Add strict_shape parameter to QValueModule for action shape enforcement #3593
…enforcement

When using Categorical specs with singleton dimensions, argmax drops the trailing dim, causing the action shape to not match the spec.

Add strict_shape parameter to QValueModule and QValueActor:
- None (default): FutureWarning on shape mismatch
- 'auto': automatically reshape action to match spec
- True: raise RuntimeError on mismatch
- False: silently allow mismatch

Fixes pytorch#3059

Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
vmoens left a comment:
Thanks for this!
The linter needs to be fixed, and I would mention the version where the change of behavior will occur (v0.14).
    warnings.warn(
        f"Action shape {action.shape} does not match expected shape {target_shape} "
        f"(per-sample spec shape: {per_sample_shape}). "
        f"In a future version, this will raise an error. "
2 versions from now, which is in v0.14
Fixed in e6fcc31. Updated the warning message to specify v0.14 as the deprecation version. Also fixed linter formatting issues.
- Fix linter formatting (line length) in actors.py and test_actors.py
- Specify deprecation version (v0.14) in FutureWarning message

Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
Summary
Fixes #3059: `QValueActor` now respects the `action_spec` shape for singleton dimensions.

When using `Categorical` specs with singleton dimensions (e.g., `shape=(1, 1)`), `argmax(dim=-1)` drops the trailing dimension, so the action shape no longer matches the spec. This adds a `strict_shape` parameter to `QValueModule` and `QValueActor`, following the approach suggested by @vmoens in #3059.

| `strict_shape` | Behavior |
| --- | --- |
| `None` (default) | `FutureWarning` on shape mismatch (backward compatible) |
| `"auto"` | Automatically reshape the action to match the spec |
| `True` | `RuntimeError` on shape mismatch |
| `False` | Silently allow the mismatch |

Before
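As a minimal sketch of the mismatch, the snippet below uses plain PyTorch tensors rather than the TorchRL classes. The batch size, the number of actions, and the `(batch, 1, n_actions)` layout of the action values are assumptions made for illustration only.

```python
import torch

# Per-sample shape of the Categorical spec, as in the shape=(1, 1) example.
spec_shape = torch.Size([1, 1])

# Assumed layout for illustration: (batch, 1, n_actions).
action_values = torch.randn(2, 1, 4)

# argmax reduces the last dimension, so the trailing singleton is lost.
action = action_values.argmax(dim=-1)

# Shape the spec expects once the batch dimension is prepended.
target_shape = action_values.shape[:1] + spec_shape

mismatch = tuple(action.shape) != tuple(target_shape)
print(tuple(action.shape), tuple(target_shape), mismatch)
```

Reshaping `action` to `target_shape` restores the missing dimension without changing any values, since both shapes hold the same number of elements.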
After
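The four `strict_shape` modes can be sketched as a standalone function. `enforce_action_shape` is a hypothetical helper written for illustration, not the code added to `QValueModule`, and its warning and error text only approximates the message quoted in the review.

```python
import warnings

import torch


def enforce_action_shape(action, target_shape, strict_shape=None):
    """Hypothetical helper mirroring the four strict_shape modes."""
    target_shape = torch.Size(target_shape)
    # Nothing to do when the shapes agree or checking is disabled.
    if action.shape == target_shape or strict_shape is False:
        return action
    if strict_shape == "auto":
        # Restore the dimensions dropped by argmax.
        return action.reshape(target_shape)
    message = (
        f"Action shape {tuple(action.shape)} does not match "
        f"expected shape {tuple(target_shape)}."
    )
    if strict_shape is True:
        raise RuntimeError(message)
    if strict_shape is None:
        # Default: warn only, keeping the old behaviour for now.
        warnings.warn(
            message + " This will raise an error in v0.14.", FutureWarning
        )
        return action
    raise ValueError(f"Unsupported strict_shape value: {strict_shape!r}")


action = torch.zeros(2, 1, dtype=torch.long)
reshaped = enforce_action_shape(action, (2, 1, 1), strict_shape="auto")
print(tuple(reshaped.shape))
```

Keeping `None` as the default means existing callers see a warning instead of a behavior change, while `"auto"`, `True`, and `False` let users opt in to reshaping, strict checking, or silence.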
Test results (55 passed)
Test plan