Fix sharded rms norm in MiniMax M2.5#898
Merged
Merged
Conversation
awni
reviewed
Feb 16, 2026
| return f"{self.weight.shape[0] * self.group.size()}, eps={self.eps}" | ||
|
|
||
| def __call__(self, x): | ||
| return sharded_rms_norm(self.group)(x, self["weight"], self.eps) |
Member
There was a problem hiding this comment.
I think that will make a new function every time so each call needs to be recompiled.
Member
Author
There was a problem hiding this comment.
Oops yeah. Imported the lru cache but forgot to add it.
awni
reviewed
Feb 16, 2026
Comment on lines
+41
to
+43
| norm2 = x.square().sum(-1, keepdims=True) | ||
| norm2 = mx.distributed.all_sum(norm2, group=group) | ||
| norm = mx.rsqrt(norm2 / (x.shape[-1] * group.size()) + eps) |
Member
There was a problem hiding this comment.
Oh also x should probably be up cast prior to the sum for parity with mx.fast.rms_norm
358cdca to
d3060bd
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Minimax M2.5 sharding is currently broken because the qk norm is done on the whole vector.
This PR fixes it but introduces 2 extra communications per attention layer which pretty much destroys decoding perf improvement. I have a more complicated version that does one communication and overlaps it with some computation which is better but not by much, we may need something more fundamental here.
Otoh this is now correct and the prompt processing scales quite nicely at 2.7x speedup across 4 nodes at 8k tokens. Adding to that the fact that the KV cache is also 1/4th the size it still makes sense to shard it for long agentic tasks or using multiple subagents.