diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c784ce..1aa1673 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,32 @@ # Changelog +## [2.1.0] - 2026-04-20 + +### Features +- **LM Studio provider** (closes #2): first-class support for [LM Studio](https://lmstudio.ai)'s OpenAI-compatible local server. Run benchmarks with `gauntlet run --model lmstudio/`; `gauntlet discover` lists currently-loaded models. Host configurable via `LMSTUDIO_HOST` env var, `gauntlet config --lmstudio-host`, or the default `http://localhost:1234`. Metadata (family, parameter size, quantization) inferred from the model ID. +- **Cloud ChatClient wiring**: `gauntlet run --model openai/`, `anthropic/`, and `google/` now work directly (previously `NotImplementedError`). Enables leaderboard baselines for GPT-4o, Claude, and Gemini — typical full-sweep cost is under $5, and Gemini has a free tier. +- **MCP server improvements**: + - Self-driving tool instructions so MCP clients (Claude Code, Gemini CLI, Cursor) can run the full suite without custom user prompts — includes explicit "do NOT shell out" directives. + - Auto-detects the client app via `Context.session.client_params.clientInfo`, with a clear separation between client app and model identifier. + - New `gauntlet_status(session_id)` tool replays the current probe on demand for resumability. + +### Fixes +- **Temporal Reasoning probe**: prompt previously said "Reply with ONLY the name" despite the correct answer being neither Alice nor Bob. Some models (notably Gemini 2.5 Pro) looped for minutes trying to resolve the bind before their own CLI aborted. Prompt now lists `'Alice' | 'Bob' | 'Neither'` explicitly. Verify function unchanged (still accepts equal/both/tie/neither). +- **Leaderboard provider mis-attribution**: `collect_fingerprint(r.model, "ollama")` hardcoded the provider when submitting results, so non-Ollama runs appeared on the leaderboard as Ollama. Now derived via `detect_provider()`. Affected `gauntlet quick` and the TUI path. + +### Safety +- **Agent-invocation guard on `gauntlet run`**: when stdin/stdout aren't TTYs (the tell for MCP-client subprocess spawns), refuse to benchmark local models (Ollama / LM Studio / llama.cpp) unless `GAUNTLET_ALLOW_LOCAL=1` is set. Prevents MCP agents from accidentally loading large local models and overloading the user's machine. Cloud providers and interactive humans are unaffected. + +### Polish +- Error messages, `_auto_select_models()`, and the interactive setup now include LM Studio alongside Ollama — no more "Is Ollama running?" when LM Studio is loaded. +- Host resolution honors the config file for Ollama and LM Studio (env > file > default); persistent `gauntlet config --ollama-host` and `--lmstudio-host` flags now actually take effect. +- README: new LM Studio and Cloud Baselines sections, updated provider filter tables. + +### Tests +- 12 new LM Studio tests: host resolution precedence (env > config > default), spec parsing, factory wiring, metadata inference across 5 model-id patterns. + +--- + ## [2.0.3] - 2026-04-17 ### Fixes diff --git a/README.md b/README.md index bf9efbe..8d3c087 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@

- version + version

Gauntlet

@@ -15,6 +15,7 @@ PerplexityLeaderboardDashboard • + LM Studiollama.cppTaxonomyScoring • @@ -212,7 +213,7 @@ Gauntlet can. Every test submission includes anonymous hardware metadata: | CPU architecture | arm64, x86_64 | | RAM | 8GB, 16GB, 32GB, 64GB | | OS | macOS, Linux, Windows | -| Provider | Ollama, OpenAI, Anthropic, Google | +| Provider | Ollama, LM Studio, llama.cpp, OpenAI, Anthropic, Google | Filter the leaderboard by any combination to see how models compare on comparable hardware configurations. @@ -245,7 +246,7 @@ Public read-only endpoints at `https://gauntlet.basaltlabs.app` for building too |---|---| | `gpu_class` | apple_silicon, nvidia, amd, none | | `quantization` | Q4, Q8, fp16 | -| `provider` | ollama, openai, anthropic | +| `provider` | ollama, lmstudio, llamacpp, openai, anthropic, google | | `os_platform` | darwin, linux, windows | | `source` | cli, tui, dashboard, mcp | | `exclude_source` | mcp (default for community dashboard) | @@ -496,12 +497,53 @@ pip install gauntlet-cli[all-providers] # All cloud providers | Provider | Configuration | Cost | |---|---|---| | [Ollama](https://ollama.com) (local) | `ollama pull qwen3.5:4b` | Free | +| [LM Studio](https://lmstudio.ai) (local) | Load a model, then start Developer > Local Server | Free | | [llama.cpp](https://github.com/ggml-org/llama.cpp) (local) | `llama-server -m model.gguf` | Free | | OpenAI API | `export OPENAI_API_KEY=sk-...` | Pay-per-use | | Anthropic API | `export ANTHROPIC_API_KEY=sk-ant-...` | Pay-per-use | | Google AI API | `export GOOGLE_API_KEY=AI...` | Pay-per-use | -Ollama and llama.cpp run models locally with zero external dependency. Cloud providers are optional and can be combined with local models. +Ollama, LM Studio, and llama.cpp run models locally with zero external dependency. Cloud providers are optional and can be combined with local models. + +### LM Studio + +Gauntlet supports [LM Studio](https://lmstudio.ai) via its OpenAI-compatible local server. Load a model in LM Studio, start the server under **Developer > Local Server**, then use the `lmstudio:` prefix: + +```bash +# Default host: http://localhost:1234 +gauntlet discover # lists currently-loaded models +gauntlet run --model lmstudio/llama-3.2-8b-q4_K_M + +# Custom port (LM Studio lets users change it in-app) +export LMSTUDIO_HOST=http://localhost:4321 +gauntlet run --model lmstudio/qwen-7b + +# Or persist it +gauntlet config --lmstudio-host=http://localhost:4321 +``` + +### Cloud Baselines + +Gauntlet can run the suite directly against OpenAI, Anthropic, and Google Gemini APIs so leaderboard entries for frontier models sit alongside local runs on comparable axes. Export an API key and use the provider prefix: + +```bash +# Google Gemini (free tier available at https://aistudio.google.com/apikey) +export GOOGLE_API_KEY=AI... +gauntlet run --model google/gemini-2.5-flash +gauntlet run --model google/gemini-2.5-pro + +# OpenAI +export OPENAI_API_KEY=sk-... +gauntlet run --model openai/gpt-4o-mini +gauntlet run --model openai/gpt-4o + +# Anthropic Claude +export ANTHROPIC_API_KEY=sk-ant-... +gauntlet run --model anthropic/claude-haiku-4-5 +gauntlet run --model anthropic/claude-sonnet-4-6 +``` + +A full frontier sweep (6 models across 3 providers) typically costs under $5. Gemini Flash is free on the API's free tier. ### llama.cpp diff --git a/gauntlet/__init__.py b/gauntlet/__init__.py index 6e8270a..7442d8d 100644 --- a/gauntlet/__init__.py +++ b/gauntlet/__init__.py @@ -1,5 +1,5 @@ """Gauntlet - Behavioral reliability under pressure.""" -__version__ = "2.0.3" +__version__ = "2.1.0" __all__ = ["__version__"] diff --git a/gauntlet/cli/app.py b/gauntlet/cli/app.py index 86e1d02..79795a7 100644 --- a/gauntlet/cli/app.py +++ b/gauntlet/cli/app.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import os from typing import Optional import typer @@ -53,7 +54,7 @@ def _default(ctx: typer.Context) -> None: def run( model: list[str] = typer.Option( None, "--model", "-m", - help="Model to test (e.g. ollama/qwen2.5:14b). Can specify multiple.", + help="Model to test (e.g. ollama/qwen2.5:14b, lmstudio/llama-3.2-8b, openai/gpt-4o). Can specify multiple.", ), profile: Optional[str] = typer.Option( None, "--profile", "-p", @@ -158,9 +159,29 @@ def run( console.print(f"[dim]Profiles: {', '.join(available_profiles())}[/dim]") raise typer.Exit(0) + # Detect whether we're being invoked by an agent (non-interactive shell) + # vs. a human at a terminal. When an MCP client (Gemini CLI, Claude Code, + # Cursor, etc.) spawns `gauntlet run`, stdin/stdout aren't TTYs. We use + # this to hard-block local Ollama runs that would overload the user's + # machine without their awareness. + import sys as _sys + _is_agent = not (_sys.stdin.isatty() and _sys.stdout.isatty()) + _agent_override = os.environ.get("GAUNTLET_ALLOW_LOCAL") == "1" + # Parse model specs if not model: - print_error("No models specified. Use --model ollama/qwen2.5:14b") + print_error("No models specified.") + console.print() + console.print("[bold]Benchmarking yourself (the AI reading this)?[/bold]") + console.print(" -> STOP. Do NOT run this shell command. Use the") + console.print(" [cyan]gauntlet_run[/cyan] MCP tool from the gauntlet MCP server instead.") + console.print(" This CLI benchmarks OTHER local models — it does not score you.") + console.print() + console.print("[bold]Benchmarking a local Ollama model?[/bold]") + console.print(" -> Pass [cyan]--model ollama/[/cyan] (e.g. ollama/qwen2.5:14b)") + console.print() + console.print("[bold]Benchmarking a cloud API directly?[/bold]") + console.print(" -> Pass [cyan]--model openai/gpt-4o[/cyan] or [cyan]google/gemini-2.5-flash[/cyan] etc.") raise typer.Exit(1) model_specs = [] @@ -173,6 +194,45 @@ def run( name = m model_specs.append((provider, name)) + # Agent-invoked + local models (Ollama / LM Studio / llama.cpp) = footgun. + # Refuse hard unless the human explicitly unlocked with + # GAUNTLET_ALLOW_LOCAL=1. This prevents MCP clients that shell out to + # "gauntlet run" from accidentally loading large local models and tanking + # the user's machine. + _LOCAL_PROVIDERS = {"ollama", "lmstudio", "llamacpp"} + has_local = any(p in _LOCAL_PROVIDERS for (p, _) in model_specs) + if _is_agent and has_local and not _agent_override: + print_error( + "Refusing to benchmark LOCAL models from a non-interactive shell." + ) + console.print() + console.print( + "[bold]This invocation looks like it came from an AI agent " + "(MCP client / subprocess).[/bold]" + ) + console.print( + "Loading local models (Ollama / LM Studio / llama.cpp) here can overload the user's machine." + ) + console.print() + console.print("If you are the AI reading this:") + console.print(" -> Do NOT retry this command. Use the [cyan]gauntlet_run[/cyan] MCP tool.") + console.print(" The MCP tool benchmarks YOU (the AI); this CLI does not.") + console.print() + console.print( + "If you are a human and you really meant to benchmark a local model " + "non-interactively, set [cyan]GAUNTLET_ALLOW_LOCAL=1[/cyan] and retry." + ) + raise typer.Exit(2) + + # Friendly banner for interactive local benchmarks (human user). + if has_local and not _is_agent: + console.print( + "[yellow]Notice:[/yellow] benchmarking LOCAL model(s). " + "If you meant to benchmark a cloud LLM, use [cyan]--model openai/..." + "[/cyan] or [cyan]google/...[/cyan] — or the MCP server for self-scoring." + ) + console.print() + # Module filter module_names = None if module: @@ -410,8 +470,9 @@ def compare( models = args[:-1] if not models: - print_error("No models found. Is Ollama running? Do you have models installed?") - console.print("[dim]Run: ollama pull gemma4:e2b[/dim]") + print_error("No models found. Is Ollama or LM Studio running with a model loaded?") + console.print("[dim]Ollama: ollama pull gemma4:e2b[/dim]") + console.print("[dim]LM Studio: load a model, then Developer > Local Server > Start[/dim]") raise typer.Exit(1) if dashboard: @@ -534,10 +595,17 @@ def on_token(model: str, text: str, metrics): async def _auto_select_models(max_models: int = 2) -> list[str]: - """Auto-detect installed models and pick the best ones to compare.""" - from gauntlet.core.discover import discover_ollama + """Auto-detect installed models and pick the best ones to compare. + + Checks Ollama first for backward compatibility; falls back to LM Studio + if Ollama has nothing loaded. Users running only LM Studio still get a + sensible auto-pick. + """ + from gauntlet.core.discover import discover_ollama, discover_lmstudio models = await discover_ollama() + if not models: + models = await discover_lmstudio() if not models: return [] @@ -566,18 +634,19 @@ async def _auto_select_models(max_models: int = 2) -> list[str]: async def _interactive_setup() -> tuple[list[str], str]: """Interactive model selection and prompt input.""" - from gauntlet.core.discover import discover_ollama + from gauntlet.core.discover import discover_ollama, discover_lmstudio from rich.prompt import Prompt print_header() with console.status("[bold cyan]Finding installed models..."): available = await discover_ollama() + available += await discover_lmstudio() if not available: - print_error("No models found. Is Ollama running?") - console.print("[dim]Install: https://ollama.com/download[/dim]") - console.print("[dim]Then: ollama pull gemma4:e2b[/dim]") + print_error("No models found. Is Ollama or LM Studio running with a model loaded?") + console.print("[dim]Ollama: https://ollama.com/download then ollama pull gemma4:e2b[/dim]") + console.print("[dim]LM Studio: https://lmstudio.ai then load a model and start the local server[/dim]") raise typer.Exit(1) # Show available models with numbers @@ -744,8 +813,9 @@ def benchmark( console.print("[dim]Auto-detecting installed models...[/dim]") detected = asyncio.run(_auto_select_models(max_models=5)) if not detected: - print_error("No models found. Is Ollama running?") - console.print("[dim]Run: ollama pull gemma4:e2b[/dim]") + print_error("No models found. Is Ollama or LM Studio running with a model loaded?") + console.print("[dim]Ollama: ollama pull gemma4:e2b[/dim]") + console.print("[dim]LM Studio: load a model, then Developer > Local Server > Start[/dim]") raise typer.Exit(1) models = detected @@ -792,9 +862,15 @@ def benchmark( from gauntlet.core.submit import submit_result def _submit_benchmarks(): + from gauntlet.core.config import detect_provider for r in results: try: - fp = collect_fingerprint(r.model, "ollama") + # Derive the real provider from the model spec so fingerprint + # metadata lines up on the community leaderboard. Defaults to + # ollama when the spec is bare (e.g. "qwen2.5:14b"), matching + # existing behaviour for plain Ollama names. + detected_provider, _ = detect_provider(r.model) + fp = collect_fingerprint(r.model, detected_provider) hw, rt, mc = fp.to_storage_dicts() # Scale scores from 0-1 to 0-100 for the community API raw_cats = getattr(r, "category_scores", {}) @@ -1094,14 +1170,16 @@ def leaderboard() -> None: @app.command() def config( ollama_host: Optional[str] = typer.Option(None, "--ollama-host", help="Set Ollama API host"), + lmstudio_host: Optional[str] = typer.Option(None, "--lmstudio-host", help="Set LM Studio local server host (e.g. http://localhost:1234)"), show_config: bool = typer.Option(False, "--show", help="Show current config"), ) -> None: """View or modify Gauntlet configuration.""" - from gauntlet.core.config import load_config, save_config, get_ollama_host + from gauntlet.core.config import load_config, save_config, get_ollama_host, get_lmstudio_host if show_config: cfg = load_config() console.print(f"Ollama host: {get_ollama_host()}") + console.print(f"LM Studio host: {get_lmstudio_host()}") for k, v in cfg.items(): console.print(f"{k}: {v}") return @@ -1112,6 +1190,13 @@ def config( save_config(cfg) console.print(f"[green]Ollama host set to: {ollama_host}[/green]") + if lmstudio_host: + cfg = load_config() + cfg["lmstudio_host"] = lmstudio_host + save_config(cfg) + console.print(f"[green]LM Studio host set to: {lmstudio_host}[/green]") + console.print("[dim]Note: LMSTUDIO_HOST env var takes precedence if set.[/dim]") + @app.command() def mcp( diff --git a/gauntlet/cli/tui.py b/gauntlet/cli/tui.py index 84eba1d..9181c2e 100644 --- a/gauntlet/cli/tui.py +++ b/gauntlet/cli/tui.py @@ -1234,10 +1234,15 @@ def on_probe(idx, total, name, passed): f" [{idx}/{total}] {icon} {name}\n", ) + # Derive the real provider from the model spec — users may pass + # lmstudio:* or llamacpp:* into the TUI, not just plain Ollama names. + from gauntlet.core.config import detect_provider + provider, clean_model = detect_provider(model_name) + try: results, score, _trust = asyncio.run(run_gauntlet( - model_name=model_name, - provider="ollama", + model_name=clean_model, + provider=provider, profile="assistant", quick=self._quick, config={"on_probe_complete": on_probe}, diff --git a/gauntlet/core/client.py b/gauntlet/core/client.py index 1054ac3..820a927 100644 --- a/gauntlet/core/client.py +++ b/gauntlet/core/client.py @@ -17,7 +17,15 @@ import httpx -from gauntlet.core.config import get_ollama_host, get_llamacpp_host +from gauntlet.core.config import ( + get_api_key, + get_llamacpp_host, + get_lmstudio_host, + get_ollama_host, + PROVIDER_ANTHROPIC, + PROVIDER_GOOGLE, + PROVIDER_OPENAI, +) # Cache thinking-model detection per model name (survives across ChatClient instances) _thinking_model_cache: dict[str, bool] = {} @@ -64,6 +72,11 @@ def __post_init__(self): if not self._host: if self.provider == "llamacpp": self._host = get_llamacpp_host() + elif self.provider == "lmstudio": + self._host = get_lmstudio_host() + elif self.provider in (PROVIDER_OPENAI, PROVIDER_ANTHROPIC, PROVIDER_GOOGLE): + # Cloud providers use hardcoded endpoints; _host is unused. + self._host = "" else: self._host = get_ollama_host() @@ -123,6 +136,14 @@ async def _complete(self, temperature: float | None = None) -> str: return await self._ollama_chat(temp) elif self.provider == "llamacpp": return await self._llamacpp_chat(temp) + elif self.provider == "lmstudio": + return await self._lmstudio_chat(temp) + elif self.provider == PROVIDER_OPENAI: + return await self._openai_chat(temp) + elif self.provider == PROVIDER_ANTHROPIC: + return await self._anthropic_chat(temp) + elif self.provider == PROVIDER_GOOGLE: + return await self._google_chat(temp) else: raise NotImplementedError(f"Provider {self.provider} not yet supported for ChatClient") @@ -225,6 +246,63 @@ async def _ollama_chat(self, temperature: float) -> str: return content + async def _lmstudio_chat(self, temperature: float) -> str: + """Call LM Studio's local server via OpenAI-compatible /v1/chat/completions. + + LM Studio exposes an OpenAI-compatible API on port 1234 by default. + Override with the LMSTUDIO_HOST env var (some users run on custom + ports). The model_name maps to whichever model is currently loaded + in LM Studio — use `gauntlet discover` to list loaded models. + """ + url = f"{self._host}/v1/chat/completions" + + payload = { + "model": self.model_name, + "messages": [ + {"role": m.role, "content": m.content} + for m in self._history + ], + "temperature": temperature, + "max_tokens": self.max_tokens, + "stream": False, + } + + timeout = httpx.Timeout( + connect=30.0, + read=self.timeout_s, + write=30.0, + pool=30.0, + ) + + try: + async with httpx.AsyncClient(timeout=timeout) as http: + resp = await http.post(url, json=payload) + resp.raise_for_status() + data = resp.json() + except httpx.ReadTimeout: + raise TimeoutError( + f"LM Studio did not respond within {self.timeout_s:.0f}s. " + f"Model may be too large for available memory, or still loading." + ) + except httpx.ConnectError: + raise ConnectionError( + f"Cannot connect to LM Studio at {self._host}. " + f"Open LM Studio, go to Developer > Local Server, and start the server. " + f"Override the host with LMSTUDIO_HOST=http://localhost:." + ) + + choices = data.get("choices", []) + if not choices: + raise ValueError("LM Studio returned no choices") + + content = choices[0].get("message", {}).get("content", "") + + usage = data.get("usage", {}) + self._total_tokens += usage.get("completion_tokens", 0) + + self._history.append(ChatMessage(role="assistant", content=content)) + return content + async def _llamacpp_chat(self, temperature: float) -> str: """Call llama.cpp server via OpenAI-compatible /v1/chat/completions. @@ -285,6 +363,196 @@ async def _llamacpp_chat(self, temperature: float) -> str: self._history.append(ChatMessage(role="assistant", content=content)) return content + async def _openai_chat(self, temperature: float) -> str: + """Call OpenAI's Chat Completions API. + + Uses OPENAI_API_KEY from the environment. Supports any OpenAI + model (gpt-4o, gpt-4o-mini, o1, o3, etc.). System messages pass + through as-is in the messages array. + """ + api_key = get_api_key(PROVIDER_OPENAI) + if not api_key: + raise RuntimeError( + "OPENAI_API_KEY is not set. Export it before running cloud benchmarks." + ) + + url = "https://api.openai.com/v1/chat/completions" + + payload: dict = { + "model": self.model_name, + "messages": [ + {"role": m.role, "content": m.content} + for m in self._history + ], + "stream": False, + } + # o-series reasoning models reject temperature/max_tokens in favor of + # max_completion_tokens and default temp=1. + is_reasoning = self.model_name.startswith(("o1", "o3", "o4")) + if is_reasoning: + payload["max_completion_tokens"] = self.max_tokens + else: + payload["temperature"] = temperature + payload["max_tokens"] = self.max_tokens + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + + timeout = httpx.Timeout(connect=30.0, read=self.timeout_s, write=30.0, pool=30.0) + + try: + async with httpx.AsyncClient(timeout=timeout) as http: + resp = await http.post(url, json=payload, headers=headers) + resp.raise_for_status() + data = resp.json() + except httpx.ReadTimeout: + raise TimeoutError( + f"OpenAI did not respond within {self.timeout_s:.0f}s." + ) + + choices = data.get("choices", []) + if not choices: + raise ValueError("OpenAI returned no choices") + content = choices[0].get("message", {}).get("content", "") or "" + + usage = data.get("usage", {}) + self._total_tokens += usage.get("completion_tokens", 0) + + self._history.append(ChatMessage(role="assistant", content=content)) + return content + + async def _anthropic_chat(self, temperature: float) -> str: + """Call Anthropic's Messages API. + + Uses ANTHROPIC_API_KEY from the environment. Extracts system + messages into Anthropic's dedicated `system` field (Anthropic + doesn't accept role='system' inside the messages array). + """ + api_key = get_api_key(PROVIDER_ANTHROPIC) + if not api_key: + raise RuntimeError( + "ANTHROPIC_API_KEY is not set. Export it before running cloud benchmarks." + ) + + url = "https://api.anthropic.com/v1/messages" + + # Anthropic wants system as a top-level string; merge all system turns. + system_parts = [m.content for m in self._history if m.role == "system"] + chat_messages = [ + {"role": m.role, "content": m.content} + for m in self._history + if m.role in ("user", "assistant") + ] + + payload: dict = { + "model": self.model_name, + "messages": chat_messages, + "max_tokens": self.max_tokens, + "temperature": temperature, + } + if system_parts: + payload["system"] = "\n\n".join(system_parts) + + headers = { + "x-api-key": api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + } + + timeout = httpx.Timeout(connect=30.0, read=self.timeout_s, write=30.0, pool=30.0) + + try: + async with httpx.AsyncClient(timeout=timeout) as http: + resp = await http.post(url, json=payload, headers=headers) + resp.raise_for_status() + data = resp.json() + except httpx.ReadTimeout: + raise TimeoutError( + f"Anthropic did not respond within {self.timeout_s:.0f}s." + ) + + # Concatenate text blocks (ignore tool-use blocks for probe answers) + parts = data.get("content", []) or [] + content = "".join(b.get("text", "") for b in parts if b.get("type") == "text") + + usage = data.get("usage", {}) + self._total_tokens += usage.get("output_tokens", 0) + + self._history.append(ChatMessage(role="assistant", content=content)) + return content + + async def _google_chat(self, temperature: float) -> str: + """Call Google's Generative Language API (Gemini). + + Uses GOOGLE_API_KEY from the environment. Gemini uses role='model' + instead of 'assistant' and puts system instructions in a dedicated + `system_instruction` field. + """ + api_key = get_api_key(PROVIDER_GOOGLE) + if not api_key: + raise RuntimeError( + "GOOGLE_API_KEY is not set. Export it before running cloud benchmarks." + ) + + url = ( + f"https://generativelanguage.googleapis.com/v1beta/" + f"models/{self.model_name}:generateContent?key={api_key}" + ) + + system_parts = [m.content for m in self._history if m.role == "system"] + contents = [] + for m in self._history: + if m.role == "system": + continue + # Gemini expects 'user' | 'model' (not 'assistant') + role = "model" if m.role == "assistant" else m.role + contents.append({"role": role, "parts": [{"text": m.content}]}) + + payload: dict = { + "contents": contents, + "generationConfig": { + "temperature": temperature, + "maxOutputTokens": self.max_tokens, + }, + } + if system_parts: + payload["system_instruction"] = { + "parts": [{"text": "\n\n".join(system_parts)}] + } + + timeout = httpx.Timeout(connect=30.0, read=self.timeout_s, write=30.0, pool=30.0) + + try: + async with httpx.AsyncClient(timeout=timeout) as http: + resp = await http.post(url, json=payload) + resp.raise_for_status() + data = resp.json() + except httpx.ReadTimeout: + raise TimeoutError( + f"Google Gemini did not respond within {self.timeout_s:.0f}s." + ) + + # candidates[0].content.parts[0].text — but parts can be multiple + candidates = data.get("candidates", []) + if not candidates: + # Safety block or empty response + block_reason = (data.get("promptFeedback") or {}).get("blockReason") + raise ValueError( + f"Gemini returned no candidates" + f"{' (blocked: ' + block_reason + ')' if block_reason else ''}" + ) + + parts = (candidates[0].get("content") or {}).get("parts", []) or [] + content = "".join(p.get("text", "") for p in parts) + + usage = data.get("usageMetadata", {}) + self._total_tokens += usage.get("candidatesTokenCount", 0) + + self._history.append(ChatMessage(role="assistant", content=content)) + return content + @property def history(self) -> list[ChatMessage]: """Get conversation history (read-only view).""" diff --git a/gauntlet/core/config.py b/gauntlet/core/config.py index 7a7d431..b3c3a60 100644 --- a/gauntlet/core/config.py +++ b/gauntlet/core/config.py @@ -36,10 +36,14 @@ def _resolve_gauntlet_dir() -> Path: PROVIDER_GOOGLE = "google" PROVIDER_OPENAI_COMPAT = "openai-compatible" PROVIDER_LLAMACPP = "llamacpp" +PROVIDER_LMSTUDIO = "lmstudio" # llama.cpp defaults DEFAULT_LLAMACPP_HOST = "http://localhost:8080" +# LM Studio defaults (LM Studio's built-in local server) +DEFAULT_LMSTUDIO_HOST = "http://localhost:1234" + @dataclass class ProviderConfig: @@ -57,14 +61,38 @@ def ensure_gauntlet_dir() -> Path: return GAUNTLET_DIR +def _host_from_config(env_var: str, config_key: str, default: str) -> str: + """Resolve a host with precedence: env var > config file > default.""" + env_val = os.environ.get(env_var) + if env_val: + return env_val + try: + cfg = load_config() + if cfg.get(config_key): + return cfg[config_key] + except Exception: + pass + return default + + def get_ollama_host() -> str: - """Get the Ollama API host from env or default.""" - return os.environ.get("OLLAMA_HOST", DEFAULT_OLLAMA_HOST) + """Get the Ollama API host (env > config > default).""" + return _host_from_config("OLLAMA_HOST", "ollama_host", DEFAULT_OLLAMA_HOST) def get_llamacpp_host() -> str: - """Get the llama.cpp server host from env or default.""" - return os.environ.get("LLAMACPP_HOST", DEFAULT_LLAMACPP_HOST) + """Get the llama.cpp server host (env > config > default).""" + return _host_from_config("LLAMACPP_HOST", "llamacpp_host", DEFAULT_LLAMACPP_HOST) + + +def get_lmstudio_host() -> str: + """Get the LM Studio local server host (env > config > default). + + LM Studio lets users change the server port inside the app. Set the + LMSTUDIO_HOST environment variable (e.g. http://localhost:4321) or + persist via `gauntlet config --lmstudio-host=http://localhost:4321`. + """ + return _host_from_config("LMSTUDIO_HOST", "lmstudio_host", DEFAULT_LMSTUDIO_HOST) def detect_provider(model_name: str) -> tuple[str, str]: @@ -76,13 +104,15 @@ def detect_provider(model_name: str) -> tuple[str, str]: "openai:gpt-4o" -> (openai, gpt-4o) "anthropic:claude-sonnet-4-20250514" -> (anthropic, claude-sonnet-4-20250514) "google:gemini-2.0-flash" -> (google, gemini-2.0-flash) + "lmstudio:llama-3.2-8b" -> (lmstudio, llama-3.2-8b) + "llamacpp:model" -> (llamacpp, model) "http://host:port/v1:model" -> (openai-compatible, model) with base_url """ if ":" in model_name: prefix, _, rest = model_name.partition(":") # Check for known providers - if prefix in (PROVIDER_OLLAMA, PROVIDER_OPENAI, PROVIDER_ANTHROPIC, PROVIDER_GOOGLE, PROVIDER_LLAMACPP): + if prefix in (PROVIDER_OLLAMA, PROVIDER_OPENAI, PROVIDER_ANTHROPIC, PROVIDER_GOOGLE, PROVIDER_LLAMACPP, PROVIDER_LMSTUDIO): return prefix, rest # Check for URL-based custom endpoint (openai-compatible) @@ -133,6 +163,10 @@ def resolve_model(model_spec: str) -> ProviderConfig: base_url, model_name = model_name.split("||", 1) elif provider == PROVIDER_OLLAMA: base_url = get_ollama_host() + elif provider == PROVIDER_LMSTUDIO: + base_url = get_lmstudio_host() + elif provider == PROVIDER_LLAMACPP: + base_url = get_llamacpp_host() api_key = get_api_key(provider) diff --git a/gauntlet/core/discover.py b/gauntlet/core/discover.py index f50de22..e2c5068 100644 --- a/gauntlet/core/discover.py +++ b/gauntlet/core/discover.py @@ -10,6 +10,7 @@ from gauntlet.core.providers.openai_provider import OpenAIProvider from gauntlet.core.providers.anthropic_provider import AnthropicProvider from gauntlet.core.providers.google_provider import GoogleProvider +from gauntlet.core.providers.lmstudio import LMStudioProvider def get_system_memory() -> dict: @@ -80,6 +81,11 @@ def spec(self) -> str: return self.name return f"{self.provider}:{self.name}" + @property + def is_local(self) -> bool: + """Whether this model runs on local hardware (not a cloud API).""" + return self.provider in ("ollama", "lmstudio", "llamacpp") + async def discover_ollama() -> list[DiscoveredModel]: """Discover models from local Ollama installation.""" @@ -106,6 +112,30 @@ async def discover_ollama() -> list[DiscoveredModel]: return models +async def discover_lmstudio() -> list[DiscoveredModel]: + """Discover currently-loaded models from LM Studio's local server. + + Returns an empty list if LM Studio isn't running or the server is off. + The list reflects only the models the user has loaded in LM Studio, + not every model on disk — loaded models are the ones runnable. + """ + lmstudio = LMStudioProvider() + if not await lmstudio.check_connection(): + return [] + + raw_models = await lmstudio.list_models() + models = [] + for m in raw_models: + models.append( + DiscoveredModel( + name=m["name"], + provider="lmstudio", + display_name=m["name"], + ) + ) + return models + + async def discover_openai() -> list[DiscoveredModel]: """Discover models from OpenAI (if API key is set).""" api_key = get_api_key(PROVIDER_OPENAI) @@ -171,6 +201,7 @@ async def discover_all() -> list[DiscoveredModel]: results = await asyncio.gather( discover_ollama(), + discover_lmstudio(), discover_openai(), discover_anthropic(), discover_google(), diff --git a/gauntlet/core/providers/factory.py b/gauntlet/core/providers/factory.py index 288ebc3..4f5aae2 100644 --- a/gauntlet/core/providers/factory.py +++ b/gauntlet/core/providers/factory.py @@ -5,6 +5,7 @@ from gauntlet.core.config import ( PROVIDER_ANTHROPIC, PROVIDER_GOOGLE, + PROVIDER_LMSTUDIO, PROVIDER_OLLAMA, PROVIDER_OPENAI, PROVIDER_OPENAI_COMPAT, @@ -14,6 +15,7 @@ from gauntlet.core.providers.anthropic_provider import AnthropicProvider from gauntlet.core.providers.base import LLMProvider from gauntlet.core.providers.google_provider import GoogleProvider +from gauntlet.core.providers.lmstudio import LMStudioProvider from gauntlet.core.providers.ollama import OllamaProvider from gauntlet.core.providers.openai_provider import OpenAIProvider @@ -36,6 +38,10 @@ def create_provider(config: ProviderConfig) -> tuple[LLMProvider, str]: provider = OllamaProvider(base_url=config.base_url or "http://localhost:11434") return provider, model + if config.provider == PROVIDER_LMSTUDIO: + provider = LMStudioProvider(base_url=config.base_url) + return provider, model + if config.provider == PROVIDER_OPENAI: if not config.api_key: raise ValueError( diff --git a/gauntlet/core/providers/lmstudio.py b/gauntlet/core/providers/lmstudio.py new file mode 100644 index 0000000..15fc422 --- /dev/null +++ b/gauntlet/core/providers/lmstudio.py @@ -0,0 +1,78 @@ +"""LM Studio provider - local server with OpenAI-compatible /v1 API. + +LM Studio (https://lmstudio.ai) runs GGUF models locally and exposes an +OpenAI-compatible server on localhost:1234 by default. This provider +reuses the OpenAI streaming protocol but defaults to LM Studio's host +and a dummy API key (LM Studio accepts any non-empty key). + +Users can override the host with the LMSTUDIO_HOST environment variable +since LM Studio lets users change the port inside the app. +""" + +from __future__ import annotations + +from typing import AsyncIterator, Optional + +import httpx + +from gauntlet.core.config import get_lmstudio_host +from gauntlet.core.providers.base import LLMProvider, StreamChunk +from gauntlet.core.providers.openai_provider import OpenAIProvider + + +class LMStudioProvider(LLMProvider): + """Provider for LM Studio's local OpenAI-compatible server.""" + + provider_name = "lmstudio" + + def __init__(self, base_url: Optional[str] = None, api_key: str = "lm-studio"): + self.base_url = (base_url or get_lmstudio_host()).rstrip("/") + # LM Studio doesn't validate the key but requires the header to exist. + self.api_key = api_key or "lm-studio" + self._openai_base = f"{self.base_url}/v1" + self._delegate = OpenAIProvider(api_key=self.api_key, base_url=self._openai_base) + + async def stream_generate( + self, + model: str, + prompt: str, + system: Optional[str] = None, + image_path: Optional[str] = None, + ) -> AsyncIterator[StreamChunk]: + async for chunk in self._delegate.stream_generate( + model=model, prompt=prompt, system=system, image_path=image_path + ): + yield chunk + + async def list_models(self) -> list[dict]: + """List currently-loaded models in LM Studio. + + LM Studio's /v1/models returns only the models the user has loaded + in the app (not every model on disk). That's the right set for + benchmarking — models must be loaded to be runnable. + """ + try: + async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client: + resp = await client.get(f"{self._openai_base}/models") + resp.raise_for_status() + data = resp.json() + return [ + { + "name": m["id"], + "size": None, + "owned_by": m.get("owned_by"), + "object": m.get("object"), + } + for m in data.get("data", []) + ] + except (httpx.HTTPError, KeyError): + return [] + + async def check_connection(self) -> bool: + """Check if LM Studio's local server is reachable.""" + try: + async with httpx.AsyncClient(timeout=httpx.Timeout(5.0)) as client: + resp = await client.get(f"{self._openai_base}/models") + return resp.status_code == 200 + except (httpx.ConnectError, httpx.TimeoutException): + return False diff --git a/gauntlet/core/system_info.py b/gauntlet/core/system_info.py index bc34250..e042460 100644 --- a/gauntlet/core/system_info.py +++ b/gauntlet/core/system_info.py @@ -309,9 +309,76 @@ def _get_model_metadata(model_name: str, provider: str) -> dict: return _get_ollama_metadata(model_name) elif provider == "llamacpp": return _get_llamacpp_metadata(model_name) + elif provider == "lmstudio": + return _get_lmstudio_metadata(model_name) return {} +def _get_lmstudio_metadata(model_name: str) -> dict: + """Infer LM Studio model metadata from /v1/models and the model ID. + + LM Studio's OpenAI-compatible API doesn't expose quantization directly, + but the model ID usually contains enough signal (e.g. "llama-3.2-8b-q4_k_m") + to infer family, parameter size, and quantization. + """ + import re + + meta: dict = { + "family": "unknown", + "parameter_size": "", + "quantization": "unknown", + "format": "gguf", + "families": [], + } + + try: + import httpx + from gauntlet.core.config import get_lmstudio_host + + host = get_lmstudio_host() + + # Prefer the current model_name; fall back to whichever is loaded. + candidate_id = model_name + try: + resp = httpx.get(f"{host}/v1/models", timeout=5) + if resp.status_code == 200: + data = resp.json() + models = data.get("data", []) + ids = [m.get("id", "") for m in models if m.get("id")] + if model_name not in ids and ids: + candidate_id = ids[0] + except Exception: + pass + + lower = candidate_id.lower() + + # Family inference + for fam in ["llama", "qwen", "gemma", "phi", "mistral", + "deepseek", "yi", "falcon", "mamba", "starcoder", + "codellama", "mixtral"]: + if fam in lower: + meta["family"] = fam + break + + # Quantization from ID suffix (e.g. "-q4_k_m", "-q8_0", "-f16") + quant_match = re.search( + r"[_.-](q\d[_a-z0-9]*|f16|f32|fp16|fp32|bf16)\b", + lower, + ) + if quant_match: + meta["quantization"] = quant_match.group(1).upper() + + # Parameter size (e.g. "8b", "70b", "1.5b") + param_match = re.search(r"(\d+(?:\.\d+)?)[_.-]?b\b", lower) + if param_match: + meta["parameter_size"] = f"{param_match.group(1)}B" + + except Exception: + pass + + return meta + + def _get_ollama_metadata(model_name: str) -> dict: """Get model metadata from Ollama /api/show.""" try: @@ -483,6 +550,15 @@ def collect_fingerprint( fp.quantization = meta.get("quantization", "unknown") fp.model_format = meta.get("format", "unknown") + # LM Studio / llama.cpp metadata (GGUF-based, inferred from model id) + if provider in ("lmstudio", "llamacpp"): + meta = _get_model_metadata(model_name, provider) + if meta: + fp.model_family = meta.get("family", "unknown") + fp.model_parameter_size = meta.get("parameter_size", "") + fp.quantization = meta.get("quantization", "unknown") + fp.model_format = meta.get("format", "gguf") + # Infer quant_method from format + quantization pattern. # GGUF quants have distinctive patterns (Q4_K_M, IQ4_XS, etc.) # GPTQ/AWQ/EXL2 typically use safetensors format. diff --git a/gauntlet/mcp/probes.py b/gauntlet/mcp/probes.py index 5afed4d..050b24d 100644 --- a/gauntlet/mcp/probes.py +++ b/gauntlet/mcp/probes.py @@ -767,7 +767,7 @@ def _verify_word(text): "steps": [ {"prompt": ( "Alice started at 12:00 and finished at 12:45. Bob started at 12:30 and finished at 1:15. " - "Who took longer? Reply with ONLY the name." + "Who took longer? Reply with one word: 'Alice', 'Bob', or 'Neither' if they took the same time." )}, ], "verify": lambda responses: _verify_temporal(responses[0]), diff --git a/gauntlet/mcp/server.py b/gauntlet/mcp/server.py index b8c85c6..f2dcf44 100644 --- a/gauntlet/mcp/server.py +++ b/gauntlet/mcp/server.py @@ -25,7 +25,7 @@ import uuid from typing import Optional -from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp import Context, FastMCP from gauntlet.mcp.runner import GauntletRunner @@ -163,38 +163,94 @@ def _normalize_model_name(name: str) -> str: mcp = FastMCP( "Gauntlet", instructions=( - "Behavioral benchmark — YOU are the test subject.\n" - "1. Call gauntlet_run(client_name='your-model-name') to start.\n" - "2. Read the PROMPT in the response.\n" - "3. Call gauntlet_respond(response='your actual answer text', session_id='...') to answer.\n" - "4. Repeat step 2-3 until the benchmark is complete.\n" - "IMPORTANT: gauntlet_respond() response must be your real answer, never null or empty." + "Behavioral benchmark — YOU (the AI reading this) are the test subject.\n" + "\n" + "CRITICAL — do NOT shell out:\n" + " If the user asks you to 'run gauntlet', 'benchmark yourself', 'test\n" + " your behavior', or anything similar, you MUST use the gauntlet_*\n" + " MCP tools below. Do NOT execute `gauntlet` in a shell. The shell\n" + " command benchmarks OTHER local models via Ollama — it does NOT\n" + " benchmark you, and it can crash the user's machine by trying to\n" + " run large models locally. The MCP tools are the only correct path.\n" + "\n" + "Run loop:\n" + " 1. Call gauntlet_run(client_name='') to start.\n" + " Pass your real model identifier (e.g. 'gemini-2.5-pro',\n" + " 'claude-sonnet-4-6', 'gpt-4o'). Results land on a public leaderboard,\n" + " so accuracy matters. The client app you're running inside will be\n" + " detected automatically but is NOT your model name.\n" + " 2. Read the PROMPT in the response.\n" + " 3. Answer it yourself using only your own reasoning — do NOT use other\n" + " tools, do NOT shell out, do NOT look things up.\n" + " 4. Call gauntlet_respond(response='', session_id='').\n" + " 5. Repeat 2–4 until you receive a final score.\n" + "\n" + "Other tools:\n" + " - gauntlet_status(session_id) — check progress / resume a session.\n" + " - gauntlet_results() — view the most recent run.\n" + " - gauntlet_leaderboard(tier='') — view rankings.\n" + "\n" + "IMPORTANT: response must be your real answer, never null or empty." ), ) +def _detect_client_app(ctx: Optional[Context]) -> str: + """Best-effort detection of the MCP client application name. + + Returns the client app's self-reported name (e.g. 'gemini-cli', + 'claude-code', 'cursor'). This is NOT the model — clients don't + report which LLM is driving them over MCP. + """ + if ctx is None: + return "" + try: + client_info = ctx.session.client_params.clientInfo + return (client_info.title or client_info.name or "").strip() + except Exception: + return "" + + @mcp.tool() def gauntlet_run( - client_name: str = "unknown", + client_name: str = "", quick: bool = False, + ctx: Context = None, ) -> str: """Start a new Gauntlet behavioral benchmark. YOU are the test subject. - This returns a PROMPT. You must ANSWER it by calling gauntlet_respond() - with your answer text and the session_id shown below. + Use this tool (NOT the shell) whenever the user asks to run Gauntlet, + benchmark themselves, or test your behavior. Running `gauntlet` in the + shell benchmarks OTHER local models via Ollama — it does not benchmark + you and can overload the user's machine. - Flow: gauntlet_run() -> read prompt -> gauntlet_respond(answer, session_id) -> repeat + Returns a SESSION_ID and the first PROMPT. Read the prompt, answer it + yourself using only your own reasoning (no shell, no other tools, no + lookups), then call gauntlet_respond(response, session_id). Repeat until + the benchmark completes. Args: - client_name: REQUIRED. Your model name (e.g. 'claude-sonnet-4-6', 'gpt-4o'). - quick: Quick suite (17 tests) vs full (56). + client_name: Your exact model identifier — e.g. 'gemini-2.5-pro', + 'claude-sonnet-4-6', 'gpt-4o'. Results land on a public leaderboard, + so pass the real model ID, not the client app name. If omitted, + we fall back to the detected client app, which is usually wrong + for scoring (e.g. 'gemini-cli' ≠ 'gemini-2.5-pro'). + quick: Quick suite (~17 probes) vs full suite (~84 probes). """ - normalized = _normalize_model_name(client_name) + client_app = _detect_client_app(ctx) + raw_name = client_name.strip() if client_name else "" + + # Fall back to detected client app only if nothing was passed + if not raw_name and client_app: + raw_name = client_app + + normalized = _normalize_model_name(raw_name) if not normalized: + hint = f" (detected client app: '{client_app}' — pass your MODEL id, not the app)" if client_app else "" return ( - "ERROR: client_name is required. Pass your model name " - "(e.g. client_name='claude-sonnet-4-6' or 'gpt-4o'). " - "Results without a model name are not saved." + "ERROR: client_name is required. Pass your exact model id " + "(e.g. client_name='gemini-2.5-pro' or 'claude-sonnet-4-6')." + f"{hint} Results without a model name are not saved." ) # Opportunistic cleanup: purge orphaned sessions older than 1 hour @@ -202,10 +258,18 @@ def gauntlet_run( sid = str(uuid.uuid4()) runner = GauntletRunner(quick=quick, client_name=normalized) + # Stash the detected client app for observability (not used for scoring) + try: + setattr(runner, "client_app", client_app) + except Exception: + pass result = runner.advance() _save_runner(sid, runner) - header = f"SESSION: {sid}\n{'=' * 50}\n\n" + header = f"SESSION: {sid}\nModel: {normalized}" + if client_app and client_app.lower() != normalized.lower(): + header += f" (client app: {client_app})" + header += f"\n{'=' * 50}\n\n" return header + result["message"] + ( f"\n\n---\nIMPORTANT: To answer, call gauntlet_respond(response=\"\", session_id=\"{sid}\")" ) @@ -263,6 +327,65 @@ def gauntlet_respond( ) +@mcp.tool() +def gauntlet_status(session_id: str) -> str: + """Check progress of an in-flight Gauntlet session. + + Use this to resume a dropped connection or verify where you are in the + suite. Returns the current probe prompt (so you can continue) plus + progress counters. + + Args: + session_id: The session_id returned by gauntlet_run(). + """ + if not session_id or not session_id.strip(): + return "ERROR: session_id is required." + + runner = _get_runner(session_id) + if not runner: + return ( + f"ERROR: Unknown or expired session '{session_id}'. " + "Start a new run by calling gauntlet_run()." + ) + + if runner.finished: + return ( + f"Session {session_id} is already complete. " + "Call gauntlet_results() to view the scorecard." + ) + + model = getattr(runner, "client_name", "unknown") + current_idx = getattr(runner, "current_test_idx", 0) + total = getattr(runner, "total_tests", 0) or 0 + current_test = getattr(runner, "current_test", None) + + # Reconstruct the current prompt without mutating runner state. + # advance(None) on a started session errors out ("expected a response"), + # so we replay via _send_step using the current probe + step index. + current_prompt = "" + try: + if current_test is not None and current_idx > 0 and current_idx <= len(runner.suite): + probe = runner.suite[current_idx - 1] + step_idx = current_test.current_step + replay = runner._send_step(probe, step_idx) + current_prompt = replay.get("message", "") if isinstance(replay, dict) else "" + except Exception as e: + logger.debug(f"gauntlet_status replay failed: {e}") + + progress = f"Progress: {current_idx}/{total} probes\n" if total else "" + header = ( + f"SESSION: {session_id}\n" + f"Model: {model}\n" + f"{progress}" + f"{'=' * 50}\n\n" + ) + footer = ( + f"\n\n---\nTo answer, call gauntlet_respond(response=\"\", " + f"session_id=\"{session_id}\")" + ) + return header + (current_prompt or "(Could not replay current prompt. Session state may be corrupted — start a new run with gauntlet_run().)") + footer + + @mcp.tool() def gauntlet_results() -> str: """View results from the most recent Gauntlet run. diff --git a/tests/test_lmstudio.py b/tests/test_lmstudio.py new file mode 100644 index 0000000..1bc519d --- /dev/null +++ b/tests/test_lmstudio.py @@ -0,0 +1,143 @@ +"""Unit tests for LM Studio provider integration. + +These tests don't require a running LM Studio instance — they verify +config resolution, metadata inference, provider wiring, and factory +behaviour. Live-server tests would be integration tests, out of scope +for this file. +""" + +from __future__ import annotations + +import os + +import pytest + +from gauntlet.core.config import ( + DEFAULT_LMSTUDIO_HOST, + PROVIDER_LMSTUDIO, + detect_provider, + get_lmstudio_host, + load_config, + resolve_model, + save_config, +) +from gauntlet.core.providers.factory import create_provider +from gauntlet.core.providers.lmstudio import LMStudioProvider +from gauntlet.core.system_info import _get_lmstudio_metadata + + +# --------------------------------------------------------------------------- +# Host resolution precedence: env > config file > default +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def _isolate_host_env(monkeypatch, tmp_path): + """Clear env vars and config file state between tests.""" + monkeypatch.delenv("LMSTUDIO_HOST", raising=False) + # Snapshot + restore config file + cfg_before = load_config() + cfg_clean = {k: v for k, v in cfg_before.items() if k != "lmstudio_host"} + save_config(cfg_clean) + yield + save_config(cfg_before) + + +def test_default_host_when_no_env_or_config(): + assert get_lmstudio_host() == DEFAULT_LMSTUDIO_HOST + assert DEFAULT_LMSTUDIO_HOST == "http://localhost:1234" + + +def test_env_var_overrides_default(monkeypatch): + monkeypatch.setenv("LMSTUDIO_HOST", "http://localhost:9999") + assert get_lmstudio_host() == "http://localhost:9999" + + +def test_config_file_overrides_default(): + cfg = load_config() + cfg["lmstudio_host"] = "http://192.168.1.50:4321" + save_config(cfg) + assert get_lmstudio_host() == "http://192.168.1.50:4321" + + +def test_env_wins_over_config_file(monkeypatch): + cfg = load_config() + cfg["lmstudio_host"] = "http://config-value:1234" + save_config(cfg) + monkeypatch.setenv("LMSTUDIO_HOST", "http://env-value:5678") + assert get_lmstudio_host() == "http://env-value:5678" + + +# --------------------------------------------------------------------------- +# Model spec parsing +# --------------------------------------------------------------------------- + +def test_detect_provider_lmstudio_prefix(): + provider, name = detect_provider("lmstudio:llama-3.2-8b") + assert provider == PROVIDER_LMSTUDIO + assert name == "llama-3.2-8b" + + +def test_detect_provider_preserves_model_with_colons(): + # Some LM Studio IDs include colons (rare but possible) + provider, name = detect_provider("lmstudio:publisher/model-7b") + assert provider == PROVIDER_LMSTUDIO + assert name == "publisher/model-7b" + + +def test_resolve_model_lmstudio_includes_base_url(): + cfg = resolve_model("lmstudio:qwen-8b") + assert cfg.provider == PROVIDER_LMSTUDIO + assert cfg.base_url == DEFAULT_LMSTUDIO_HOST + assert cfg.extra["model"] == "qwen-8b" + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + +def test_factory_creates_lmstudio_provider(): + cfg = resolve_model("lmstudio:llama-3.2-8b") + provider, model = create_provider(cfg) + assert isinstance(provider, LMStudioProvider) + assert model == "llama-3.2-8b" + assert provider.base_url == DEFAULT_LMSTUDIO_HOST + assert provider._openai_base == f"{DEFAULT_LMSTUDIO_HOST}/v1" + + +def test_factory_honors_custom_base_url(monkeypatch): + monkeypatch.setenv("LMSTUDIO_HOST", "http://10.0.0.5:4321") + cfg = resolve_model("lmstudio:qwen-8b") + provider, _ = create_provider(cfg) + assert provider.base_url == "http://10.0.0.5:4321" + assert provider._openai_base == "http://10.0.0.5:4321/v1" + + +# --------------------------------------------------------------------------- +# Metadata inference from model ID (no live server needed) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize( + "model_id,expected_family,expected_params,expected_quant", + [ + ("llama-3.2-8b-q4_K_M", "llama", "8B", "Q4_K_M"), + ("qwen2.5-7b-instruct-q8_0", "qwen", "7B", "Q8_0"), + ("gemma-2-9b-f16", "gemma", "9B", "F16"), + ("mistral-7b-v0.1-q5_k_s", "mistral", "7B", "Q5_K_S"), + ("deepseek-coder-6.7b-q4_0", "deepseek", "6.7B", "Q4_0"), + ], +) +def test_metadata_inference_from_model_id( + model_id, expected_family, expected_params, expected_quant +): + # No live server — should fall through to ID-based inference. + meta = _get_lmstudio_metadata(model_id) + assert meta["family"] == expected_family + assert meta["parameter_size"] == expected_params + assert meta["quantization"] == expected_quant + assert meta["format"] == "gguf" + + +def test_metadata_unknown_model_returns_sensible_defaults(): + meta = _get_lmstudio_metadata("some-obscure-unknown-model") + assert meta["family"] == "unknown" + assert meta["format"] == "gguf"