Fix Gemma 4 KV-shared layers creating unused projections#1158

Merged
angeloskath merged 4 commits into
ml-explore:mainfrom
glyphVault:fix/gemma4-kv-shared-layers
Apr 21, 2026
@glyphVault
Contributor

Summary

  • Gemma 4 E4B/E2B models share KV projections across later layers (num_kv_shared_layers). The Attention class was creating k_proj, v_proj, k_norm, and v_norm for all layers, but shared layers never use them — the forward pass routes KV from earlier layers via shared_kv.
  • This caused load_weights(strict=True) to fail for any Gemma 4 model saved through transformers (save_pretrained), since transformers correctly omits these weights for shared layers. This affects all derivative models: fine-tunes, merges, abliterations, etc.
  • Skip creating k_proj/v_proj/k_norm/v_norm for KV-shared layers, matching the transformers implementation.
  • Add a defensive ValueError if a shared layer somehow receives no shared_kv at runtime.
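The change described above can be sketched in plain Python. This is only an illustration, not the mlx-lm code: the `Linear` stand-in, the `kv` method, and the constructor arguments are hypothetical, and it assumes the shared layers are the last `num_kv_shared_layers` layers of the stack.

```python
from typing import Optional, Tuple


class Linear:
    """Hypothetical stand-in for a projection or norm module; it only stores a name."""

    def __init__(self, name: str):
        self.name = name


class Attention:
    """Illustrative attention block that skips KV modules on KV-shared layers."""

    def __init__(self, layer_idx: int, num_layers: int, num_kv_shared_layers: int):
        # Assumption: the final num_kv_shared_layers layers reuse KV from earlier layers.
        first_shared = num_layers - num_kv_shared_layers
        self.is_kv_shared_layer = num_kv_shared_layers > 0 and layer_idx >= first_shared

        self.q_proj = Linear("q_proj")
        self.o_proj = Linear("o_proj")
        if not self.is_kv_shared_layer:
            # Only layers that compute their own keys and values own these modules.
            self.k_proj = Linear("k_proj")
            self.v_proj = Linear("v_proj")
            self.k_norm = Linear("k_norm")
            self.v_norm = Linear("v_norm")

    def kv(self, x: str, shared_kv: Optional[Tuple[str, str]] = None) -> Tuple[str, str]:
        if self.is_kv_shared_layer:
            if shared_kv is None:
                # Defensive check: a shared layer has no projections to fall back on.
                raise ValueError("KV-shared layer called without shared_kv")
            return shared_kv
        # Placeholder for the real projection and normalization of x.
        return (f"k({x})", f"v({x})")


layers = [Attention(i, num_layers=4, num_kv_shared_layers=2) for i in range(4)]
```

With four layers and two shared layers, layers 0 and 1 own `k_proj`/`v_proj`, while layers 2 and 3 have no such attributes and must be given `shared_kv`.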

Test plan

  • All 8 existing gemma4 tests pass
  • New test test_gemma4_kv_shared_layers_omit_kv_projections verifies shared layers don't create KV modules
  • Verified forward pass produces identical top-5 logits vs transformers on OBLITERATUS/gemma-4-E4B-it-OBLITERATED
  • Verified cached generation produces coherent output
  • Formatted with black
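To illustrate why strict loading failed and what the new test guards against, here is a minimal sketch of a strict key comparison. The key naming scheme and both helper functions are hypothetical and do not reproduce the mlx-lm or transformers implementations; norm weights are left out for brevity.

```python
def expected_keys(num_layers: int, num_kv_shared_layers: int, skip_shared_kv: bool) -> set:
    """Parameter names a model would declare (hypothetical naming scheme)."""
    # Assumption: the final num_kv_shared_layers layers are the KV-shared ones.
    first_shared = num_layers - num_kv_shared_layers
    keys = set()
    for i in range(num_layers):
        names = ["q_proj", "o_proj"]
        is_shared = i >= first_shared
        if not (is_shared and skip_shared_kv):
            names += ["k_proj", "v_proj"]
        keys.update(f"layers.{i}.self_attn.{n}.weight" for n in names)
    return keys


def strict_check(model_keys: set, checkpoint_keys: set) -> None:
    """Raise if the model and checkpoint parameter names differ in either direction."""
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    if missing or unexpected:
        raise ValueError(f"missing: {missing}; unexpected: {unexpected}")


# A checkpoint that omits KV weights for the shared layers.
checkpoint = expected_keys(4, 2, skip_shared_kv=True)

# After the fix: the model declares the same keys, so the strict check passes.
strict_check(expected_keys(4, 2, skip_shared_kv=True), checkpoint)

# Before the fix: the model also declared k_proj/v_proj on shared layers.
old_model_keys = expected_keys(4, 2, skip_shared_kv=False)
```

Calling `strict_check(old_model_keys, checkpoint)` raises a `ValueError` that lists the `k_proj` and `v_proj` weights of layers 2 and 3 as missing, which mirrors the failure mode described in the summary.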

🤖 Generated with Claude Code

glyphVault and others added 4 commits April 15, 2026 15:43
Gemma 4 E4B/E2B models share KV projections across later layers
(controlled by num_kv_shared_layers). The Attention class was creating
k_proj, v_proj, k_norm, and v_norm for all layers, but shared layers
never use them — the forward pass routes KV from earlier layers via
shared_kv.

This caused strict weight loading to fail for any Gemma 4 model saved
through transformers (fine-tunes, merges, abliterations), since
transformers correctly omits these weights for shared layers.

- Skip creating k_proj/v_proj/k_norm/v_norm for KV-shared layers
- Add defensive ValueError if a shared layer receives no shared_kv
- Add test verifying shared layers omit KV projections

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Member

@angeloskath angeloskath left a comment

Thanks, great catch.

@angeloskath angeloskath merged commit 4f5cbd2 into ml-explore:main Apr 21, 2026
2 checks passed
ericcurtin pushed a commit to vllm-project/vllm-metal that referenced this pull request Apr 27, 2026
fix #299 

* Passed the Qwen3 deterministic test; no regression.
* Passed the Gemma4-E4B end-to-end smoke test: the model loaded successfully
and produced normal tokens (not gibberish).

Future plan: keep watching ml-explore/mlx-lm#1158, since more upstream bug
fixes may follow.

---------

Signed-off-by: Ranran Haoran Zhang <haorzhang@ebay.com>
Signed-off-by: ran <hzz5361@psu.edu>

2 participants