Fix sharded rms norm in MiniMax M2.5 #898

Merged
angeloskath merged 3 commits into main from fix-m2.5-dist on Feb 17, 2026

@angeloskath
Member

MiniMax M2.5 sharding is currently broken because the QK norm is computed over the whole vector, while each shard only holds a slice of it.
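
To illustrate why a whole-vector norm breaks under sharding, here is a minimal pure-Python sketch. It assumes the normalized dimension is split evenly across ranks, uses a unit weight, and simulates ranks with plain lists. The helper names (`rms_norm`, `shard`) are hypothetical and not part of mlx-lm. A rank that normalizes only its local slice uses a different mean square than the full-vector norm, so the concatenated result does not match.

```python
import math

def rms_norm(x, eps=1e-6):
    # Plain RMS norm over the whole list (unit weight for simplicity).
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale for v in x]

def shard(x, n):
    # Split x into n contiguous, equally sized slices (one per simulated rank).
    step = len(x) // n
    return [x[i * step:(i + 1) * step] for i in range(n)]

x = [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0]

full = rms_norm(x)
# Each simulated rank normalizes only the slice it holds locally.
local = [v for part in shard(x, 2) for v in rms_norm(part)]

max_diff = max(abs(a - b) for a, b in zip(full, local))
print(f"max difference between full and per-shard norm: {max_diff:.3f}")
```

The difference is large here because the two slices have very different magnitudes, which is exactly the case where a local-only norm goes wrong.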

This PR fixes it but introduces 2 extra communications per attention layer, which pretty much destroys the decoding perf improvement. I have a more complicated version that does one communication and overlaps it with some computation, which is better but not by much. We may need something more fundamental here.

OTOH this is now correct, and prompt processing scales quite nicely: a 2.7x speedup across 4 nodes at 8k tokens. Add to that the fact that the KV cache is also 1/4th the size, and it still makes sense to shard for long agentic tasks or when using multiple subagents.

@angeloskath angeloskath requested a review from awni February 16, 2026 04:12
Comment thread mlx_lm/models/minimax.py
        return f"{self.weight.shape[0] * self.group.size()}, eps={self.eps}"

    def __call__(self, x):
        return sharded_rms_norm(self.group)(x, self["weight"], self.eps)
Member

I think that will make a new function every time, so each call needs to be recompiled.

Member Author

Oops, yeah. Imported lru_cache but forgot to add it.
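
The recompilation concern and the lru_cache fix can be sketched with the standard library alone. The factories below are hypothetical stand-ins for `sharded_rms_norm`; the real function presumably returns a compiled MLX function, which this sketch does not attempt to reproduce. It also assumes the factory argument is hashable, which `functools.lru_cache` requires.

```python
from functools import lru_cache

def make_norm_uncached(group_size):
    # Hypothetical factory: builds a brand-new closure on every call.
    def norm(x):
        return [v / group_size for v in x]
    return norm

@lru_cache(maxsize=None)
def make_norm_cached(group_size):
    # Same factory, but memoized on its (hashable) argument.
    def norm(x):
        return [v / group_size for v in x]
    return norm

# Without caching, two calls give two distinct function objects.
fresh_each_time = make_norm_uncached(4) is not make_norm_uncached(4)
# With caching, the second call returns the object built by the first.
reused = make_norm_cached(4) is make_norm_cached(4)
print(fresh_each_time, reused)
```

Anything keyed on function identity, such as a compile cache, would miss on every call in the uncached case and hit in the cached one.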

Comment thread mlx_lm/models/minimax.py Outdated
Comment on lines +41 to +43
norm2 = x.square().sum(-1, keepdims=True)
norm2 = mx.distributed.all_sum(norm2, group=group)
norm = mx.rsqrt(norm2 / (x.shape[-1] * group.size()) + eps)
Member

Oh, also x should probably be upcast prior to the sum for parity with mx.fast.rms_norm.
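
The arithmetic of the three lines under review can be checked without MLX. In this sketch the all-reduce is simulated by summing per-rank partial sums in plain Python, and `sharded_rms_norm_sim` is a hypothetical name. Python floats are already double precision, so the upcast suggested above has no analogue here; in MLX it would presumably mean casting x to float32 before squaring.

```python
import math

def rms_norm(x, eps=1e-6):
    # Reference: RMS norm over the full, unsharded vector (unit weight).
    mean_sq = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(mean_sq + eps) for v in x]

def sharded_rms_norm_sim(shards, eps=1e-6):
    # Each simulated rank computes the sum of squares of its local slice.
    partial = [sum(v * v for v in part) for part in shards]
    # Stand-in for an all-reduce sum: every rank ends up with the total.
    norm2 = sum(partial)
    # Divide by the global dimension: local size times number of ranks.
    local_dim, group_size = len(shards[0]), len(shards)
    scale = 1.0 / math.sqrt(norm2 / (local_dim * group_size) + eps)
    return [[v * scale for v in part] for part in shards]

x = [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0]
shards = [x[:4], x[4:]]

expected = rms_norm(x)
got = [v for part in sharded_rms_norm_sim(shards) for v in part]
matches = all(math.isclose(a, b, rel_tol=1e-12) for a, b in zip(expected, got))
print(matches)
```

Because the sum of squares is reduced across ranks before the division, every rank applies the same scale as the unsharded norm.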

Member

@awni awni left a comment

Thanks for fixing that!

@angeloskath angeloskath merged commit d7b91e8 into main Feb 17, 2026
2 checks passed
@angeloskath angeloskath deleted the fix-m2.5-dist branch February 17, 2026 01:20