This repository was archived by the owner on Aug 7, 2024. It is now read-only.

upcoming feature tracker #187

Closed
@vkuzo

Description

configurability

  • [done] support delayed vs. dynamic scaling types, configurable separately for activations, weights, and gradients (see the config sketch after this list)
  • [planned] support rowwise/blockwise scaling granularity, configurable separately for each gemm
  • [planned] configure settings for each of the three gemms in linear fwd/bwd separately
  • [planned] support more fine-grained configuration of how Float8Linear is applied to individual modules
  • [planned] inference support (see [RFC] Float8 Inference #314)
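
As a rough illustration of the per-tensor-role configurability above, here is a minimal sketch. It assumes `Float8LinearConfig` / `CastConfig` / `ScalingType` names and the import path below match this repo's config module; treat them as illustrative rather than canonical:

```python
# Minimal sketch of per-tensor-role scaling configuration.
# Assumption: the class names and import path below match this repo's
# config module; verify against the source before use.
from float8_experimental.config import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
)

config = Float8LinearConfig(
    # activations: dynamic scaling (scale recomputed from the current tensor)
    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
    # weights: delayed scaling (scale derived from a history of past amaxes)
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    # gradients: dynamic scaling
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
)
```

Delayed scaling amortizes the amax computation by reusing scales derived from previous iterations, while dynamic scaling recomputes the scale from the current tensor on every cast.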

performance

  • [done] torch._scaled_mm support for per-tensor scaled float8 gemm (see the gemm sketch after this list)
  • [in progress] torch._scaled_mm support for rowwise scaled float8 gemm
    • [done] eager mode support
    • [planned] torch.compile support, backed by triton/cutlass
  • [in progress] optimize torch.compile performance for float8 scaling/casting kernels
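
For reference, a per-tensor scaled float8 gemm via `torch._scaled_mm` can be sketched as below. Note that `torch._scaled_mm` is a private API whose signature and return type have changed across PyTorch versions, and it requires float8-capable hardware (e.g. H100); treat this as illustrative:

```python
import torch

# Per-tensor scaling: one fp32 scale per operand maps the fp8 payload
# back to the original value range (value ≈ fp8_payload * scale).
a = torch.randn(16, 32, device="cuda")
b = torch.randn(32, 64, device="cuda")

fp8_dtype = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8_dtype).max  # 448.0 for e4m3

# choose scales so each tensor's absolute max saturates the fp8 range
scale_a = a.abs().max() / fp8_max
scale_b = b.abs().max() / fp8_max

a_fp8 = (a / scale_a).to(fp8_dtype)
# cuBLAS expects the second operand in column-major layout
b_fp8 = (b / scale_b).to(fp8_dtype).t().contiguous().t()

# NOTE: on some PyTorch versions _scaled_mm returns an (out, amax) tuple
# instead of a single tensor, and the scales are required positional args.
out = torch._scaled_mm(
    a_fp8, b_fp8, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16
)
```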

distributed

  • [done] integrate with tensor parallelism / sequence parallelism (TP/SP) via DTensor APIs
  • [done] integrate with FSDP1 with 16-bit all-gather
  • [done] integrate with FSDP2, with either 16-bit or 8-bit all-gather, using dynamic scaling for weights (see the FSDP2 sketch after this list)
    • performance optimizations are ongoing
  • [in progress] integrate with FSDP2, with either 16-bit or 8-bit all-gather, using delayed scaling for weights
    • POC is done, performance optimizations are ongoing
  • [planned] verify integration with pipeline parallelism (PP)
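
Below is a sketch of what wiring float8 into FSDP2 with 8-bit all-gather could look like. The flag and helper names (`enable_fsdp_fp8_all_gather`, `swap_linear_with_float8_linear`) are assumptions for illustration, and the snippet presumes an initialized distributed environment (e.g. launched via torchrun):

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

# Assumed import paths and names; verify against the repo before use.
from float8_experimental.config import Float8LinearConfig
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

# With 8-bit all-gather, weights are cast to float8 (with dynamic scaling)
# *before* the all-gather, halving weight communication vs a 16-bit gather.
config = Float8LinearConfig(enable_fsdp_fp8_all_gather=True)  # assumed flag name
swap_linear_with_float8_linear(model, config=config)

# FSDP2 sharding; assumes torch.distributed is already initialized
fully_shard(model)
```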

other
