Closed
Labels
- relay:op (src/relay/op)
- topi (python/tvm/topi)
- type:rfc-tracking (RFC progress tracking. Ref: https://github.com/apache/tvm-rfcs)
Description
This issue tracks work on supporting mixed precision within TVM.
RFC: apache/tvm-rfcs#6
- Write initial mixed precision pass ([Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass #8069)
- Audit schedules for support of accumulation dtypes (should be fixed with [TIR] cast disparate floating point types for binary ops #8517)
- Benchmark accuracy loss on select models (https://gist.github.com/masahi/e4c611694e3dfd307a8b6bba45eb1658, thanks @masahi)
- Ensure that tuning models under mixed precision does not reveal additional errors
- Extend support for CUDA backend ([AMP] CUDA support for mixed precision pass #8294)
- Extend support for Vulkan backend ([AMP] Vulkan Support for Mixed Precision Pass #8295)
- Extend pass to support Algebraic Data Types in Relay
- Ensure bfloat16 works with pass on select GPUs ([TIR, Relay] improve bfloat16 support #10112)
Edge case ops:
- Extend sorting implemented in C++ to allow FP16 types https://github.com/apache/tvm/blob/main/src/runtime/contrib/sort/sort.cc#L436 (thanks @masahi)
- Add accumulation datatypes (e.g. `out_dtypes`) for the following types of ops:
  - Pooling
  - Sum
  - Mean
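
To illustrate why reductions such as sum, mean, and pooling benefit from a wider accumulation dtype, here is a minimal sketch in plain Python. It does not use TVM. The helper names (`to_fp16`, `to_fp32`, `reduce_sum`) are made up for this example, and FP16/FP32 rounding is emulated with the standard library `struct` module.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def reduce_sum(values, round_acc):
    """Sum FP16 inputs, rounding the running total with `round_acc` after each add.

    `round_acc` stands in for the accumulation dtype (hypothetical helper,
    not a TVM API).
    """
    acc = 0.0
    for v in values:
        acc = round_acc(acc + to_fp16(v))
    return acc


ones = [1.0] * 4096

# FP16 accumulator: above 2048 the spacing between FP16 values is 2,
# so adding 1.0 no longer changes the running total.
fp16_total = reduce_sum(ones, to_fp16)

# FP32 accumulator: every partial sum is exactly representable.
fp32_total = reduce_sum(ones, to_fp32)

print(fp16_total, fp32_total)  # 2048.0 4096.0
```

A mean computed from the FP16 total would be 0.5 instead of 1.0, which is the kind of error a separate accumulation dtype is meant to avoid.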
Other discussions:
- Creating default ALLOW, FOLLOW, NEVER lists for ops
- Move certain pooling / average operations that are global into NEVER list (or use FP32 accumulation).
- Write a tutorial
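
The ALLOW, FOLLOW, NEVER lists discussed above can be sketched as a toy classifier in plain Python. This is an illustration only, not the TVM pass. The op assignments in each list are assumptions chosen for the example, and `assign_dtypes` is a hypothetical helper. The FOLLOW rule assumed here is: use FP16 only when every input is already FP16.

```python
# Example lists; the real defaults are still under discussion in this issue.
ALLOW = {"nn.conv2d", "nn.dense"}      # always convert to FP16
FOLLOW = {"nn.relu", "add"}            # follow the dtype of the inputs
NEVER = {"nn.softmax", "exp"}          # always keep in FP32


def assign_dtypes(graph, low="float16", high="float32"):
    """Assign an output dtype to each node of a topologically sorted graph.

    `graph` is a list of (node_name, op_name, input_names) tuples. Inputs
    that are not produced by an earlier node (e.g. model inputs) are
    treated as `high`. Ops that appear in no list are conservatively
    kept in `high`.
    """
    dtypes = {}
    for name, op, inputs in graph:
        if op in ALLOW:
            dtypes[name] = low
        elif op in FOLLOW:
            in_dtypes = [dtypes.get(i, high) for i in inputs]
            all_low = bool(in_dtypes) and all(d == low for d in in_dtypes)
            dtypes[name] = low if all_low else high
        else:
            # NEVER ops and unknown ops both stay in the high-precision dtype.
            dtypes[name] = high
    return dtypes


graph = [
    ("conv", "nn.conv2d", ["data"]),
    ("relu", "nn.relu", ["conv"]),
    ("prob", "nn.softmax", ["relu"]),
    ("scaled", "add", ["prob", "relu"]),
]

result = assign_dtypes(graph)
for node, dtype in result.items():
    print(node, dtype)
```

In this sketch `relu` follows `conv` into FP16, while `scaled` falls back to FP32 because one of its inputs comes from the NEVER op `nn.softmax`.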
Tasks which may help:
- Complete the pass that deduplicates extraneous casts and refactor the existing pass ([Pass] Simplify consecutive casts in Relay #8081)
- Write a pass that changes function signatures so that a function takes FP16 inputs and returns FP16 outputs, e.g. Add pass to fold type transformations into function signature #9357
Benchmarking improvements from pass: https://docs.google.com/spreadsheets/d/12lgyfuHaRS-X4uG-1iQOV8oAuPpuVAbspcmkOSPRFHQ/edit?usp=sharing