fix: correct sampler chain order, penalty params, and KV eviction #196
LopezNuance wants to merge 2 commits
Five root causes of degenerate repetitive output identified by side-by-side comparison of Shimmy's generation loop against Ollama's:

1. Penalty parameter order was wrong: frequency_penalty was passed as penalty_repeat, presence_penalty as penalty_freq, and repeat_penalty as penalty_present. With freq=0.3 landing in the repeat slot, the effective repeat penalty was below 1.0, which actively BOOSTED repeated tokens (~3.3x).
2. Sampler chain applied penalties AFTER top-k/top-p filtering, so penalties only operated on ~40 surviving candidates instead of the full vocabulary. Ollama applies penalties first.
3. greedy() final sampler offered no stochastic escape from repetition loops. Replaced with dist(seed), matching Ollama's behavior.
4. KV cache eviction started from position 0, destroying the system prompt tokens. Now preserves num_prompt_tokens (matching Ollama's numKeep approach) and only evicts generated output.
5. Added DRY (Don't Repeat Yourself) sampler for long-range pattern suppression, matching Ollama's default sampler chain.

Also adds frequency_penalty and presence_penalty fields to GenOptions and ChatCompletionRequest, wiring them through the OpenAI-compatible API endpoint.

Signed-off-by: Scott Johnson <rsjohnnson@users.noreply.github.com>
Signed-off-by: scott <scott@procyon.here>
Complements the upstream sampler fixes in this PR: the sampler changes prevent degeneration at the source, while the detector halts residual pathological loops that still slip through.

Detector design:

- trigram_diversity(): unique / total character trigrams over the last 400 chars. Healthy text > 0.25, degenerate loops < 0.10. Clean separation against P0 calibration data, zero false positives. Checked every 32 tokens after 600 chars of output.
- floor_char_boundary(): O(1) walk-back to the nearest valid UTF-8 char boundary. Prevents panics when the detection window boundary lands inside a multi-byte char (math symbols like ᵢ, ×, σ). Without this the Mutex<LlamaContext> would be poisoned and every subsequent request would return 502 forever.

Replaces an earlier windowed-cosine approach that missed loops starting partway through the output.

Adds 6 unit tests covering ASCII, 2/3/4-byte UTF-8, edge cases, and the exact multi-byte panic scenario observed in P0 calibration runs.

Signed-off-by: LopezNuance <m6gmjmjwfw@liamekaens.com>
Added a second commit that adds a degeneration detector. The sampler changes here prevent degeneration at the source. The detector catches any residual pathological loops that still slip through by measuring character-trigram diversity over the last 400 chars of output every 32 tokens. Below a 0.20 diversity threshold, generation halts cleanly. It also adds 6 unit tests (ASCII, 2/3/4-byte UTF-8, edge cases, and the exact panic scenario observed in calibration runs), all of which pass. Happy to split this into a follow-up PR if you prefer a single-concern review.
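The detector described above can be sketched roughly as follows. This is a minimal illustration and not the code from the commit: the function names are taken from the commit message, but the signatures, the byte-based window, and the handling of very short inputs are assumptions.

```rust
use std::collections::HashSet;

/// Walk back from `index` to the nearest UTF-8 char boundary.
/// A char is at most 4 bytes, so the loop runs at most 3 times.
/// (Assumed signature; the commit only names the function.)
pub fn floor_char_boundary(s: &str, index: usize) -> usize {
    if index >= s.len() {
        return s.len();
    }
    let mut i = index;
    // Index 0 is always a boundary, so this cannot underflow.
    while !s.is_char_boundary(i) {
        i -= 1;
    }
    i
}

/// Ratio of unique to total character trigrams in the tail of `text`.
/// Assumption: `window` is measured in bytes and then snapped back to a
/// char boundary, so slicing never lands inside a multi-byte char.
pub fn trigram_diversity(text: &str, window: usize) -> f32 {
    let start = floor_char_boundary(text, text.len().saturating_sub(window));
    let chars: Vec<char> = text[start..].chars().collect();
    if chars.len() < 3 {
        // Too short to judge; treat as fully diverse (an assumption).
        return 1.0;
    }
    let total = chars.len() - 2;
    let unique: HashSet<&[char]> = chars.windows(3).collect();
    unique.len() as f32 / total as f32
}
```

In a generation loop, a caller would presumably evaluate `trigram_diversity(&output, 400)` every 32 tokens once the output exceeds 600 chars, and stop when the value falls below 0.20, as the comment describes.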
Summary
Fixes degenerate repetitive output (repetitive tails, `***...***` patterns) on long generations by correcting five root causes identified through a side-by-side comparison of Shimmy's generation loop against Ollama's.

Note: This PR supersedes #195 and related earlier PRs that were iterating on symptoms while the root causes were being diagnosed. Apologies for the noise.
Root Cause Analysis
Methodology
Identical prompt, identical model (exaone-deep:2.4b), identical parameters (freq=0.3, pres=0.1, temp=0.3), 4 concurrent instances:
This definitively proved the degeneration is Shimmy-specific, not a model capacity issue. The investigation compared Shimmy's `src/engine/llama.rs` against Ollama's `runner/llamarunner/runner.go`, `llama/llama.cpp/common/sampling.cpp`, and `runner/llamarunner/cache.go`.

RC1: Penalty parameter order swapped (CRITICAL)

The `LlamaSampler::penalties()` Rust binding maps directly to `llama_sampler_init_penalties(penalty_last_n, penalty_repeat, penalty_freq, penalty_present)`. But Shimmy was calling it with `frequency_penalty` in the `penalty_repeat` slot, `presence_penalty` in the `penalty_freq` slot, and `repeat_penalty` in the `penalty_present` slot.

With ACMT sending `frequency_penalty=0.3, presence_penalty=0.1`:

- `penalty_repeat` received 0.3 → for positive logits, `logit / 0.3` is a ~3.3x amplification of repeated tokens
- `penalty_freq` received 0.1 (weakened from the intended 0.3)
- `penalty_present` received 1.1 (the default `repeat_penalty`, far too aggressive for presence)

The repeat penalty was literally encouraging repetition. After thousands of tokens, this positive feedback loop collapses the output into repetitive garbage.
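To make the arithmetic concrete, here is a small sketch of the per-token penalty formula as I understand it from llama.cpp's penalties sampler (divide positive logits by the repeat penalty, multiply non-positive ones, then subtract the frequency and presence terms). The `Penalties` struct and `penalize` function are hypothetical names for illustration, not Shimmy's or llama.cpp's API.

```rust
/// Penalty settings, named after the llama.cpp parameters.
/// (Hypothetical struct for illustration only.)
#[derive(Clone, Copy, Debug)]
pub struct Penalties {
    pub repeat: f32,
    pub freq: f32,
    pub present: f32,
}

/// Penalize the logit of a token that has appeared `count` times in the
/// recent window. Tokens that have not appeared are left unchanged.
pub fn penalize(logit: f32, count: u32, p: Penalties) -> f32 {
    if count == 0 {
        return logit;
    }
    // Repeat penalty: shrink positive logits, push non-positive ones lower.
    // With repeat < 1.0 the division *grows* a positive logit instead.
    let scaled = if logit <= 0.0 {
        logit * p.repeat
    } else {
        logit / p.repeat
    };
    // Frequency term scales with the count; presence term is flat.
    scaled - count as f32 * p.freq - p.present
}
```

For a token with logit 3.0 seen once, the intended settings (`repeat: 1.1, freq: 0.3, present: 0.1`) lower the logit to about 2.33, while the swapped settings (`repeat: 0.3, freq: 0.1, present: 1.1`) raise it to about 8.8.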
RC2: Sampler chain order inverted
Shimmy was applying penalties to ~40 candidates that survived top-k/top-p filtering. Ollama applies penalties to the full vocabulary first, then filters. This made Shimmy's penalties far less effective at suppressing repetition.
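The effect of the ordering can be seen in a toy model. The sketch below uses a flat subtractive penalty and a plain top-k filter, both simplifications chosen for illustration; all names are hypothetical and none of this is Shimmy's or Ollama's actual code.

```rust
/// A candidate token: (token id, logit).
pub type Candidate = (usize, f32);

/// Subtract a flat penalty from every candidate whose id is in `recent`.
pub fn apply_penalty(cands: &mut [Candidate], recent: &[usize], penalty: f32) {
    for (id, logit) in cands.iter_mut() {
        if recent.contains(&*id) {
            *logit -= penalty;
        }
    }
}

/// Keep only the `k` highest-logit candidates, best first.
pub fn top_k(mut cands: Vec<Candidate>, k: usize) -> Vec<Candidate> {
    cands.sort_by(|a, b| b.1.total_cmp(&a.1));
    cands.truncate(k);
    cands
}

/// Penalties over the full vocabulary, then filtering.
pub fn penalties_first(logits: &[f32], recent: &[usize], penalty: f32, k: usize) -> Vec<usize> {
    let mut cands: Vec<Candidate> = logits.iter().copied().enumerate().collect();
    apply_penalty(&mut cands, recent, penalty);
    top_k(cands, k).into_iter().map(|(id, _)| id).collect()
}

/// Filtering first, then penalties on the survivors only.
pub fn filter_first(logits: &[f32], recent: &[usize], penalty: f32, k: usize) -> Vec<usize> {
    let cands: Vec<Candidate> = logits.iter().copied().enumerate().collect();
    let mut kept = top_k(cands, k);
    apply_penalty(&mut kept, recent, penalty);
    kept.into_iter().map(|(id, _)| id).collect()
}
```

With logits `[5.0, 4.9, 4.8, 1.0]`, recently used tokens `[0, 1]`, a penalty of 1.0 and `k = 2`, `penalties_first` keeps tokens 2 and 0, so a fresh token becomes the top candidate. `filter_first` keeps tokens 0 and 1: both survivors are repeats, and no penalty applied afterwards can bring token 2 back into the candidate set.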
RC3: Greedy vs probabilistic final sampling
Shimmy used `greedy()` (always picks the highest-probability token). Ollama uses `dist(seed)` (samples from the distribution). With greedy, once a repeated token survives penalties and has the highest logit, there is zero probability of escape. Replaced with `dist(seed)`, matching Ollama.

RC4: KV cache eviction destroyed system prompt
Ollama's `ShiftCacheSlot` preserves `numKeep` tokens at the start (system prompt + BOS). Shimmy was evicting from position 0, destroying all instruction context after the first eviction. This explains why models produced garbage after long generations even when other bugs were fixed: they had lost their system prompt.

RC5: No DRY sampler
Ollama includes DRY (Don't Repeat Yourself) as a default sampler for long-range repetition detection. Added to Shimmy's chain with the same default parameters (`multiplier=0.8, base=1.75, allowed_length=2, penalty_last_n=64, seq_breakers=["\n", ":", "\"", "*"]`).

Additional Changes
- Added `frequency_penalty` and `presence_penalty` fields to `GenOptions` (defaulting to 0.0, matching OpenAI API semantics)
- Wired them through `ChatCompletionRequest` and the OpenAI-compatible handler
- Calls `ctx.clear_kv_cache()` before each generation to prevent stale KV contamination across requests

Test Plan
- `cargo fmt -- --check` passes
- `cargo test --features llama --lib`: 327 passed, 0 failed
- Remaining failures (`issue_101_all_fixes_integrated`, `issue_129_precompiled_gpu_support`) are tests that also fail on clean `main`