Chore/metal marlin sync #3
- Implemented `test_trellis_integration.py` to validate the complete inference pipeline of the Trellis model, including model loading, memory usage, numerical stability, token decoding, and performance metrics.
- Created `test_trellis_moe_accuracy.py` to verify the accuracy of the fast Metal MoE kernel against a slow Python implementation, focusing on dequantization accuracy and output correlation.
- Introduced `test_trellis_nan.py` to reproduce and validate the NaN bug in expansion layers of the Trellis model, ensuring outputs are finite and comparing expansion vs contraction layers.
- Added `test_zmq.py` to demonstrate basic ZeroMQ communication between a coordinator and an agent using asynchronous messaging.
- Enable fast MoE kernel by fixing NaN issues in the implementation.
- Update weight handling to conform to GEMM conventions in Trellis models.
- Introduce new scripts for checking buffer sizes, isolating NaN issues, and comparing outputs between slow and fast paths.
- Add comprehensive tests to validate configurations and ensure stability across various dimensions.
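The fast-vs-slow comparison described above can be illustrated with a small, dependency-free sketch. It is not the code in `test_trellis_moe_accuracy.py` or `scripts/compare_moe_paths.py`; the function name `compare_outputs` and the thresholds `atol` and `min_corr` are hypothetical, chosen only to show the three checks the PR mentions (finite outputs, elementwise error, output correlation).

```python
import math


def compare_outputs(fast, slow, atol=1e-2, min_corr=0.999):
    """Compare a fast-path output against a reference output.

    Both inputs are flat lists of floats. The thresholds are illustrative
    assumptions, not values taken from the repository.
    """
    if len(fast) != len(slow) or not fast:
        raise ValueError("outputs must be non-empty and equal in length")

    # A NaN or inf anywhere fails the comparison outright.
    finite = all(math.isfinite(v) for v in fast) and all(
        math.isfinite(v) for v in slow
    )
    if not finite:
        return {"finite": False, "max_abs_err": math.inf,
                "corr": math.nan, "ok": False}

    n = len(fast)
    mean_f = sum(fast) / n
    mean_s = sum(slow) / n
    cov = sum((f - mean_f) * (s - mean_s) for f, s in zip(fast, slow))
    var_f = sum((f - mean_f) ** 2 for f in fast)
    var_s = sum((s - mean_s) ** 2 for s in slow)
    denom = math.sqrt(var_f * var_s)
    # Pearson correlation is undefined for constant vectors; report NaN then,
    # which also makes the `ok` check below fail.
    corr = cov / denom if denom > 0 else math.nan

    max_abs_err = max(abs(f - s) for f, s in zip(fast, slow))
    ok = max_abs_err <= atol and corr >= min_corr
    return {"finite": True, "max_abs_err": max_abs_err, "corr": corr, "ok": ok}


if __name__ == "__main__":
    print(compare_outputs([1.0, 2.0, 3.001], [1.0, 2.0, 3.0]))
```

Checking finiteness first keeps a single NaN from silently poisoning the error and correlation statistics.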
@codex Review
Codex review is not enabled for this repo. Please contact the admins of this repo to enable Codex.
Pull request overview
This PR synchronizes Trellis/Metal kernel behavior with updated memory-optimized weight handling, and adds extensive diagnostics, benchmarks, tests, and documentation around MoE/Trellis correctness, stability (NaNs/overflow), and serving metrics.
Changes:
- Update Trellis GEMM/dequant Metal kernels and dispatch/debug tooling to investigate/fix indexing/sign/layout issues and improve observability.
- Add memory-optimization plumbing (CPU→Metal buffers), MoE/Trellis metrics, and NaN-guard utilities.
- Add new tests, benchmarks, and docs covering Trellis accuracy/regressions, MLA KV cache changes, MoE kernel behavior, and operational guidance.
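As a rough illustration of the NaN-guard idea listed above (run the fast path, detect non-finite output, fall back to a slower path, and count how often that happens), here is a minimal sketch. The class names `NaNGuard` and `GuardStats` and the callables passed in are hypothetical and do not reflect the actual interface of `metal_marlin/trellis/nan_guard.py`.

```python
import math
from dataclasses import dataclass, field


@dataclass
class GuardStats:
    """Counters for how often the guard had to fall back (hypothetical)."""
    calls: int = 0
    fallbacks: int = 0


@dataclass
class NaNGuard:
    """Wrap a fast function with a finite-output check and a slow fallback."""
    fast_fn: callable
    slow_fn: callable
    stats: GuardStats = field(default_factory=GuardStats)

    def __call__(self, x):
        self.stats.calls += 1
        out = self.fast_fn(x)
        if all(math.isfinite(v) for v in out):
            return out
        # Fast path produced NaN/inf: record it and recompute on the slow path.
        self.stats.fallbacks += 1
        return self.slow_fn(x)


def flaky_fast(x):
    # Stand-in for a kernel that misbehaves on large inputs.
    return [math.nan if abs(v) > 100 else 2 * v for v in x]


def reference_slow(x):
    return [2 * v for v in x]


if __name__ == "__main__":
    guard = NaNGuard(flaky_fast, reference_slow)
    print(guard([1.0, 2.0]), guard([200.0]), guard.stats)
```

Keeping the counters on the guard itself makes the fallback rate easy to export alongside other metrics.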
Reviewed changes
Copilot reviewed 52 out of 54 changed files in this pull request and generated 27 comments.
| File | Description |
|---|---|
| tests/test_trellis_nan.py | Adds a model-backed regression test for TrellisLinear NaN reproduction (expansion layer). |
| tests/test_trellis_moe_accuracy.py | Adds fast-vs-slow MoE accuracy/correlation/determinism tests using mock weights. |
| tests/test_trellis_linear_accuracy.py | Adds parametric regression tests comparing fused TrellisLinear vs explicit dequant+matmul. |
| tests/test_trellis_kv_cache.py | Updates KV cache tests to use qk_rope_head_dim instead of num_kv_heads * head_dim. |
| tests/test_trellis_integration.py | Adds slow integration tests for full-model load/infer/memory/TPS sanity checks. |
| tests/test_rmsnorm_overflow.py | Adds RMSNorm FP16 overflow regression tests. |
| tests/test_moe_buffer_creation.py | Adds tests for creating cached MoE Metal buffers from CPU tensors. |
| tests/test_memory_optimization.py | Adds tests around eager buffer creation and memory behavior. |
| tests/test_cpu_to_metal_buffer.py | Adds tests for CPU tensor → Metal buffer conversion utilities. |
| src/gemm_trellis.metal | Adjusts Trellis fused GEMM indexing/sign handling and tile element addressing. |
| src/dequant_trellis.metal | Adjusts standalone Trellis dequant indexing within tiles. |
| scripts/test_dequant.py | Adds a local script to compare fused forward vs explicit dequant+matmul. |
| scripts/profile_memory.py | Adds an RSS-based memory profiler for forward pass + MoE detailed profiling. |
| scripts/isolate_nan_config.py | Adds a script to sweep MoE configs to isolate NaN triggers. |
| scripts/debug_tile_boundaries.py | Adds a script to probe NaNs around tile-boundary shapes. |
| scripts/debug_moe_scaling.py | Adds a scaling/debug harness for MoE kernel configurations and timeouts. |
| scripts/debug_moe_buffers.py | Adds scripts to validate buffer layouts/contiguity and trace NaN sources. |
| scripts/compare_moe_paths.py | Adds a script to compare slow vs fast MoE outputs on real dims. |
| scripts/check_expert_nan.py | Adds a script to test individual experts for NaNs. |
| scripts/check_buffer_sizes.py | Adds a script to sanity-check stacked buffer shapes/offsets. |
| metal_marlin/trellis/testing.py | Adds mock Trellis model/layer factories for unit tests and CI usage. |
| metal_marlin/trellis/nan_guard.py | Adds runtime NaN detection + fallback orchestration with statistics tracking. |
| metal_marlin/trellis/metrics.py | Adds Prometheus-style metrics primitives and MoE metrics export. |
| metal_marlin/trellis/loader.py | Clarifies packed weight loading comments and error strings. |
| metal_marlin/trellis/linear.py | Adds Metal buffer caching hooks, group_size plumbing, and input dtype coercion. |
| metal_marlin/trellis/kv_cache.py | Updates MLA cache dimensioning to qk_rope_head_dim and adds seq_len property. |
| metal_marlin/trellis/dispatch.py | Adds METAL_DEBUG-gated pre/post dispatch logging for decode GEMM. |
| metal_marlin/trellis/attention.py | Implements TrellisKVCache compressed-KV caching path alongside legacy KV cache path. |
| metal_marlin/transformer.py | Updates RMSNorm to use FP32 accumulation and cast behavior to avoid FP16 overflow. |
| metal_marlin/serving/server.py | Extends /metrics endpoint to include MoE metrics output. |
| metal_marlin/metal_dispatch.py | Improves dtype handling for MPS tensor → Metal buffer conversion; adds CPU/numpy buffer helpers. |
| docs/trellis_kernels.md | Adds detailed Trellis kernel architecture documentation. |
| docs/moe_architecture.md | Adds detailed MoE kernel architecture and buffer layout documentation. |
| docs/memory_optimization.md | Documents CPU→Metal memory optimization approach and caveats. |
| docs/kernel_profile.md | Adds guide for profiling Metal kernels and interpreting results. |
| docs/deployment_checklist.md | Adds production deployment checklist (tests/monitoring/rollback). |
| docs/audits/gemm_trellis_correctness_audit.md | Adds an audit write-up for fused Trellis GEMM correctness. |
| docs/audits/dequant_audit.md | Adds an audit write-up for Trellis dequant shaders and layout discrepancies. |
| docs/async_dispatch.md | Adds investigation results for async MPS dispatch/pipelining. |
| benchmarks/bench_moe_multicase.py | Adds a multi-case MoE benchmark runner with RSS checks. |
| benchmarks/bench_moe_kernel.py | Simplifies benchmark script to time end-to-end forward across context lengths. |
| STATUS.md | Updates status and adds MoE kernel performance/status reporting. |
| README.md | Updates serving docs and CLI options (batched/paged attention, health endpoint, metrics). |
| .gitignore | Expands ignores for local dev/debug scripts and artifacts. |
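The RMSNorm change listed in the table (FP32 accumulation to avoid FP16 overflow) can be shown with a small emulation. This is a sketch under stated assumptions, not the code in `metal_marlin/transformer.py`: the helper `to_fp16` emulates half-precision rounding with the standard `struct` module, Python's double-precision floats stand in for FP32, and both function names are hypothetical.

```python
import math
import struct


def to_fp16(x):
    """Round a float to the nearest IEEE half-precision value.

    struct raises OverflowError for finite values beyond the FP16 range,
    which is mapped to +/-inf here to mimic hardware overflow.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def rmsnorm_fp16_accum(x, eps=1e-6):
    """RMSNorm with every intermediate rounded to FP16 (overflow-prone)."""
    acc = 0.0
    for v in x:
        # v * v exceeds the FP16 maximum (65504) once |v| is above about 255.
        acc = to_fp16(acc + to_fp16(v * v))
    mean_sq = to_fp16(acc / len(x))
    rms = to_fp16(math.sqrt(mean_sq + eps))
    return [to_fp16(v / rms) for v in x]


def rmsnorm_wide_accum(x, eps=1e-6):
    """RMSNorm that accumulates in wider precision and casts only the result."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [to_fp16(v * scale) for v in x]


if __name__ == "__main__":
    x = [300.0] * 4
    print("fp16 accumulation:", rmsnorm_fp16_accum(x))
    print("wide accumulation:", rmsnorm_wide_accum(x))
```

In this emulation the squared terms overflow to infinity in the FP16 path, so the normalized output collapses to zeros, while the wide-accumulation path returns values of 1.0 for the same input.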
Comments suppressed due to low confidence (1)
metal_marlin/trellis/kv_cache.py:190
- These comments still reference the old `num_kv_heads * head_dim` layout for k_pe, but the cache now stores `qk_rope_head_dim` instead. Please update the comment(s) here to reflect that k_pe has shape `[batch, seq_len, qk_rope_head_dim]`.
# Split compressed_kv into c_kv and k_pe components
# c_kv: [batch, seq_len, kv_lora_rank]
# k_pe: [batch, seq_len, num_kv_heads * head_dim]
c_kv_new = compressed_kv[..., : self.kv_lora_rank]
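To make the shape the reviewer asks for concrete, the sketch below splits a compressed-KV tensor along its last dimension using plain nested lists. It assumes that k_pe is simply the slice following the first `kv_lora_rank` entries, which the quoted snippet does not show; the function name `split_compressed_kv` is hypothetical.

```python
def split_compressed_kv(compressed_kv, kv_lora_rank, qk_rope_head_dim):
    """Split [batch][seq_len][kv_lora_rank + qk_rope_head_dim] nested lists.

    Returns (c_kv, k_pe) with last-dimension sizes kv_lora_rank and
    qk_rope_head_dim. Assumption: k_pe occupies the trailing slice.
    """
    expected = kv_lora_rank + qk_rope_head_dim
    c_kv, k_pe = [], []
    for seq in compressed_kv:
        c_rows, pe_rows = [], []
        for row in seq:
            if len(row) != expected:
                raise ValueError(
                    f"last dim is {len(row)}, expected {expected}"
                )
            # c_kv: [batch, seq_len, kv_lora_rank]
            c_rows.append(row[:kv_lora_rank])
            # k_pe: [batch, seq_len, qk_rope_head_dim]
            pe_rows.append(row[kv_lora_rank:])
        c_kv.append(c_rows)
        k_pe.append(pe_rows)
    return c_kv, k_pe


if __name__ == "__main__":
    # batch=1, seq_len=2, kv_lora_rank=3, qk_rope_head_dim=2
    kv = [[[1, 2, 3, 10, 11], [4, 5, 6, 12, 13]]]
    c_kv, k_pe = split_compressed_kv(kv, kv_lora_rank=3, qk_rope_head_dim=2)
    print(c_kv, k_pe)
```

Validating the last dimension up front surfaces a stale `num_kv_heads * head_dim` layout as an explicit error rather than a silently mis-sized slice.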
…yntax Fix uv dependency installation syntax in cd.yaml workflow
This pull request introduces several documentation updates and a new benchmarking script, primarily focused on MoE (Mixture-of-Experts) kernel performance, async dispatch investigation, and a detailed audit of Trellis dequantization shaders. It also enhances the OpenAI-compatible server documentation and status reporting, reflecting new features and improved test coverage.
Documentation improvements and audits:
Benchmarking and performance:
- Added `benchmarks/bench_moe_multicase.py`, a new benchmarking script for evaluating MoE kernel performance across multiple batch sizes and use cases, including memory usage checks.
OpenAI-compatible server enhancements:
- Updated `README.md` to document new server features, including paged attention, batching options, KV cache tuning, and a health check endpoint. Also added descriptions of new CLI options for these features.
- Updated `STATUS.md` to reflect increased test coverage (from 28 to 30 tests), explicit paged attention testing, and new endpoints.
MoE kernel status and performance reporting:
- Added a section to `STATUS.md` detailing the status and performance of Metal MoE kernels, including a breakdown of fast and slow paths and memory usage.
These changes collectively improve documentation clarity and testing transparency, and provide new tools for performance analysis and debugging.