Nemotron super support#992
Conversation
3f1de93 to
406d9df
Compare
|
Tested alongside PR #988 on Mac Studio M2 Ultra (128GB) with Combined branch (this PR + #988) loads and runs correctly:
Note: This PR requires #988 (SSM precision fix) to produce coherent output. Without it, the Metal decode kernel degenerates after ~15 tokens. Minor note: the |
|
Here are unit tests for the LatentMoE additions. 11 tests, all passing on mlx 0.31.1 / Python 3.12. Tests cover:
Test file: tests/test_nemotron_latentmoe.py"""Tests for Nemotron-H LatentMoE support (PR #992).
Tests the additions to nemotron_h.py:
- ModelArgs: moe_latent_size, layers_block_type normalization, time_step_limit defaults
- NemotronHMoE: latent projection forward pass
- Model.sanitize: MTP weight stripping
"""
import unittest
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.nemotron_h import Model, ModelArgs, NemotronHMoE
class TestModelArgsLatentMoE(unittest.TestCase):
"""Test ModelArgs parsing for Nemotron Super config fields."""
def _base_args(self, **overrides):
cfg = {
"model_type": "nemotron_h",
"vocab_size": 1000,
"hidden_size": 128,
"intermediate_size": 64,
"num_hidden_layers": 4,
"max_position_embeddings": 1000,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"attention_bias": False,
"mamba_num_heads": 4,
"mamba_head_dim": 32,
"mamba_proj_bias": False,
"ssm_state_size": 32,
"conv_kernel": 4,
"n_groups": 2,
"time_step_min": 0.001,
"mlp_bias": False,
"layer_norm_epsilon": 1e-5,
"use_bias": False,
"use_conv_bias": True,
"hybrid_override_pattern": ["M", "E", "*", "E"],
"n_routed_experts": 8,
"num_experts_per_tok": 2,
"moe_intermediate_size": 64,
}
cfg.update(overrides)
return ModelArgs(**cfg)
def test_moe_latent_size_parsed(self):
args = self._base_args(moe_latent_size=32)
self.assertEqual(args.moe_latent_size, 32)
def test_moe_latent_size_none_by_default(self):
args = self._base_args()
self.assertIsNone(args.moe_latent_size)
def test_layers_block_type_normalization(self):
args = self._base_args(
hybrid_override_pattern=None,
layers_block_type=["mamba", "moe", "attention", "moe"],
)
self.assertEqual(args.hybrid_override_pattern, ["M", "E", "*", "E"])
self.assertEqual(args.num_hidden_layers, 4)
def test_hybrid_override_pattern_string(self):
args = self._base_args(hybrid_override_pattern="ME*E")
self.assertEqual(len(args.hybrid_override_pattern), 4)
self.assertEqual(list(args.hybrid_override_pattern), ["M", "E", "*", "E"])
def test_time_step_limit_no_upper_bound(self):
args = self._base_args(time_step_min=0.001)
self.assertEqual(args.time_step_limit[0], 0.001)
self.assertEqual(args.time_step_limit[1], float("inf"))
def test_time_step_limit_explicit_overrides(self):
args = self._base_args(time_step_limit=(0.01, 0.5), time_step_min=0.001)
self.assertEqual(args.time_step_limit, (0.01, 0.5))
class TestNemotronHMoELatent(unittest.TestCase):
def _make_config(self, moe_latent_size=None):
return self._base_args(moe_latent_size=moe_latent_size)
def _base_args(self, **overrides):
cfg = {
"model_type": "nemotron_h",
"vocab_size": 1000,
"hidden_size": 64,
"intermediate_size": 32,
"num_hidden_layers": 2,
"max_position_embeddings": 512,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"attention_bias": False,
"mamba_num_heads": 4,
"mamba_head_dim": 16,
"mamba_proj_bias": False,
"ssm_state_size": 16,
"conv_kernel": 4,
"n_groups": 2,
"time_step_min": 0.001,
"mlp_bias": False,
"layer_norm_epsilon": 1e-5,
"use_bias": False,
"use_conv_bias": True,
"hybrid_override_pattern": ["E", "E"],
"n_routed_experts": 4,
"num_experts_per_tok": 2,
"moe_intermediate_size": 32,
"n_group": 1,
"topk_group": 1,
"routed_scaling_factor": 1.0,
"norm_topk_prob": True,
}
cfg.update(overrides)
return ModelArgs(**cfg)
def test_latent_projection_shapes(self):
config = self._make_config(moe_latent_size=16)
moe = NemotronHMoE(config)
mx.eval(moe.parameters())
x = mx.random.normal((1, 1, 64))
y = moe(x)
mx.eval(y)
self.assertEqual(y.shape, (1, 1, 64))
def test_no_latent_projection(self):
config = self._make_config(moe_latent_size=None)
moe = NemotronHMoE(config)
mx.eval(moe.parameters())
x = mx.random.normal((1, 1, 64))
y = moe(x)
mx.eval(y)
self.assertEqual(y.shape, (1, 1, 64))
def test_latent_projection_has_layers(self):
config = self._make_config(moe_latent_size=16)
moe = NemotronHMoE(config)
self.assertTrue(hasattr(moe, "fc1_latent_proj"))
self.assertTrue(hasattr(moe, "fc2_latent_proj"))
self.assertEqual(moe.fc1_latent_proj.weight.shape, (16, 64))
self.assertEqual(moe.fc2_latent_proj.weight.shape, (64, 16))
def test_shared_expert_gets_original_input(self):
config = self._make_config(moe_latent_size=16)
config.n_shared_experts = 1
config.moe_shared_expert_intermediate_size = 32
moe = NemotronHMoE(config)
mx.eval(moe.parameters())
x = mx.random.normal((1, 1, 64))
y = moe(x)
mx.eval(y)
self.assertEqual(y.shape, (1, 1, 64))
class TestSanitizeMTP(unittest.TestCase):
def test_mtp_weights_stripped(self):
config = ModelArgs(
model_type="nemotron_h", vocab_size=100, hidden_size=64,
intermediate_size=32, num_hidden_layers=2, max_position_embeddings=256,
num_attention_heads=4, num_key_value_heads=2, attention_bias=False,
mamba_num_heads=4, mamba_head_dim=16, mamba_proj_bias=False,
ssm_state_size=16, conv_kernel=4, n_groups=2, time_step_min=0.001,
mlp_bias=False, layer_norm_epsilon=1e-5, use_bias=False,
use_conv_bias=True, hybrid_override_pattern=["*", "M"],
)
model = Model(config)
weights = {
"model.embed_tokens.weight": mx.zeros((100, 64)),
"model.layers.0.norm.weight": mx.zeros((64,)),
"mtp.layers.0.weight": mx.zeros((64, 64)),
"mtp.head.weight": mx.zeros((100, 64)),
}
sanitized = model.sanitize(weights)
self.assertNotIn("mtp.layers.0.weight", sanitized)
self.assertNotIn("mtp.head.weight", sanitized)
self.assertIn("model.embed_tokens.weight", sanitized) |
|
Hi @Thump604 I don't see the same thing, #988 is not needed to produce coherent output. I am trying to find a way to gauge whether keeping the state in fp32 has any effect at all. Contrary to #997 there is a speed regression by moving the SSM to fp32 because the batched path is not via a custom kernel. I will likely still move it to fp32 in this PR just to be more compatible to other implementations. |
|
Thanks for consolidating into this PR. On the fp32 state question — I'll put together a comparison with before/after output from the 4.5-bit Nemotron quant on M2 Ultra. The degradation we observed was during autoregressive generation: output became incoherent after ~15 tokens with bf16 state, coherent with fp32. It's possible the effect is more pronounced with aggressive quantization (4.5-bit) than with higher-precision weights — I'll capture concrete samples and share them here. Glad to hear you're planning to adopt fp32 regardless for cross-implementation compatibility. |
|
I did try 4.5, 5, 6.5 and 8.5 bpw didn't really make a noticeable difference so I would love to have a reproducible issue. |
|
Here's the bf16 vs fp32 state comparison you asked about. Tested on M2 Ultra 128GB with the 5-bit Nemotron-3-Super-120B-A12B quant, Prompt: "Explain step by step how to solve this equation: 3x² + 7x - 20 = 0. Show your complete working, including the discriminant calculation and both roots." fp32 state (3/3 coherent)All three trials immediately begin solving the equation with correct math: Trial 1: Directly applies quadratic formula → correct roots (5/3 and -4) bf16 state (1/3 failed, 1/3 noisy, 1/3 coherent)Trial 1 — failure: Model never solves the equation. Instead generates additional prompt-like text: Then meta-reasons about the format rather than answering. Trial 3 — noisy: Blurts answer first, then generates user-like instructions ("I need to see the step-by-step solution clearly. Use the quadratic formula or factoring method.") before re-solving. Trial 2 — correct: Coherent step-by-step solution. AnalysisThe bf16 failure mode is "prompt regeneration" — the model generates text that looks like additional user instructions rather than its own response, as if it loses track of the conversation boundary. This is consistent with SSM recurrence state precision loss: small errors compound across decode steps, causing the model's internal state to drift enough that it confuses the user/assistant boundary. The effect is intermittent (1-2 out of 3 trials), not deterministic, and manifests as contextual confusion rather than gibberish. It may be more pronounced with aggressive quantization (5-bit weights amplify the bf16 state precision gap). fp32 state shows no instances of this behavior across all trials. |
|
Well in all honesty I don't see a difference at all. Let me know if you can tell which is which without looking at the inference speed from the two runs below 🤷♂️ Run 1Run 2Here is the impact on speed. M3 Ultra prompt 8k 4 bits |
nastya236
left a comment
There was a problem hiding this comment.
Thanks for adding this!!
As title. Will add more as the port progresses.