diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 5ba269c2..4f9c8dd8 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -1527,6 +1527,22 @@ def select_provider_and_model(args=None): all_providers = [(p.slug, p.tui_desc) for p in CANONICAL_PROVIDERS] def _named_custom_provider_map(cfg) -> dict[str, dict[str, str]]: + from hermes_cli.config import read_raw_config + + def _identity(entry): + return ( + str(entry.get("provider_key", "") or "").strip(), + str(entry.get("name", "") or "").strip(), + str(entry.get("base_url", "") or "").strip().rstrip("/"), + str(entry.get("model", "") or "").strip(), + ) + + raw_api_key_refs = {} + for raw_entry in get_compatible_custom_providers(read_raw_config()): + raw_api_key = str(raw_entry.get("api_key", "") or "").strip() + if "${" in raw_api_key: + raw_api_key_refs[_identity(raw_entry)] = raw_api_key + custom_provider_map = {} for entry in get_compatible_custom_providers(cfg): if not isinstance(entry, dict): @@ -1550,6 +1566,7 @@ def select_provider_and_model(args=None): "model": entry.get("model", ""), "api_mode": entry.get("api_mode", ""), "provider_key": provider_key, + "api_key_ref": raw_api_key_refs.get(_identity(entry), ""), } return custom_provider_map @@ -2782,6 +2799,19 @@ def _auto_provider_name(base_url: str) -> str: return name +def _custom_provider_api_key_config_value(provider_info, resolved_api_key=""): + """Return the value that should be persisted for a custom provider key.""" + api_key_ref = str(provider_info.get("api_key_ref", "") or "").strip() + if api_key_ref: + return api_key_ref + + key_env = str(provider_info.get("key_env", "") or "").strip() + if key_env and not str(provider_info.get("api_key", "") or "").strip(): + return f"${{{key_env}}}" + + return str(resolved_api_key or "").strip() + + def _save_custom_provider( base_url, api_key="", model="", context_length=None, name=None ): @@ -2923,6 +2953,7 @@ def _model_flow_named_custom(config, provider_info): # Resolve key from env var if api_key not set directly if not api_key and key_env: api_key = os.environ.get(key_env, "") + config_api_key = _custom_provider_api_key_config_value(provider_info, api_key) print(f" Provider: {name}") print(f" URL: {base_url}") @@ -3019,8 +3050,8 @@ def _model_flow_named_custom(config, provider_info): else: model["provider"] = "custom" model["base_url"] = base_url - if api_key: - model["api_key"] = api_key + if config_api_key: + model["api_key"] = config_api_key # Apply api_mode from custom_providers entry, or clear stale value custom_api_mode = provider_info.get("api_mode", "") if custom_api_mode: @@ -3038,15 +3069,15 @@ def _model_flow_named_custom(config, provider_info): provider_entry = providers_cfg.get(provider_key) if isinstance(provider_entry, dict): provider_entry["default_model"] = model_name - if api_key and not str(provider_entry.get("api_key", "") or "").strip(): - provider_entry["api_key"] = api_key + if config_api_key and not str(provider_entry.get("api_key", "") or "").strip(): + provider_entry["api_key"] = config_api_key if key_env and not str(provider_entry.get("key_env", "") or "").strip(): provider_entry["key_env"] = key_env cfg["providers"] = providers_cfg save_config(cfg) else: # Save model name to the custom_providers entry for next time - _save_custom_provider(base_url, api_key, model_name) + _save_custom_provider(base_url, config_api_key, model_name) print(f"\n✅ Model set to: {model_name}") print(f" Provider: {name} ({base_url})") diff --git a/tests/hermes_cli/test_custom_provider_model_switch.py b/tests/hermes_cli/test_custom_provider_model_switch.py index a0123670..8235c930 100644 --- a/tests/hermes_cli/test_custom_provider_model_switch.py +++ b/tests/hermes_cli/test_custom_provider_model_switch.py @@ -52,7 +52,12 @@ class TestCustomProviderModelSwitch: _model_flow_named_custom({}, provider_info) # fetch_api_models MUST be called even though model was saved - mock_fetch.assert_called_once_with("sk-test", "https://vllm.example.com/v1", timeout=8.0) + mock_fetch.assert_called_once_with( + "sk-test", + "https://vllm.example.com/v1", + timeout=8.0, + api_mode=None, + ) def test_can_switch_to_different_model(self, config_home): """User selects a different model than the saved one.""" @@ -173,3 +178,82 @@ class TestCustomProviderModelSwitch: model = config.get("model") assert isinstance(model, dict) assert "api_mode" not in model, "Stale api_mode should be removed" + + def test_env_template_api_key_is_preserved_in_model_config(self, config_home, monkeypatch): + """Selecting an env-backed custom provider must not inline the secret.""" + import yaml + from hermes_cli.main import _model_flow_named_custom + + config_path = config_home / "config.yaml" + config_path.write_text( + "model:\n" + " default: old-model\n" + " provider: openrouter\n" + "custom_providers:\n" + "- name: Example Provider\n" + " base_url: https://api.example-provider.test/v1\n" + " api_key: ${EXAMPLE_PROVIDER_API_KEY}\n" + " model: qwen3.6-35b-fast\n" + ) + monkeypatch.setenv("EXAMPLE_PROVIDER_API_KEY", "sk-live-example-provider") + + provider_info = { + "name": "Example Provider", + "base_url": "https://api.example-provider.test/v1", + "api_key": "sk-live-example-provider", + "api_key_ref": "${EXAMPLE_PROVIDER_API_KEY}", + "model": "qwen3.6-35b-fast", + } + + with patch("hermes_cli.models.fetch_api_models", return_value=["qwen3.6-35b-fast"]) as mock_fetch, \ + patch.dict("sys.modules", {"simple_term_menu": None}), \ + patch("builtins.input", return_value="1"), \ + patch("builtins.print"): + _model_flow_named_custom({}, provider_info) + + mock_fetch.assert_called_once_with( + "sk-live-example-provider", + "https://api.example-provider.test/v1", + timeout=8.0, + api_mode=None, + ) + config = yaml.safe_load(config_path.read_text()) or {} + assert config["model"]["api_key"] == "${EXAMPLE_PROVIDER_API_KEY}" + assert config["custom_providers"][0]["api_key"] == "${EXAMPLE_PROVIDER_API_KEY}" + assert "sk-live-example-provider" not in config_path.read_text() + + def test_key_env_custom_provider_persists_reference_not_secret(self, config_home, monkeypatch): + """key_env custom providers should also avoid writing plaintext keys.""" + import yaml + from hermes_cli.main import _model_flow_named_custom + + config_path = config_home / "config.yaml" + config_path.write_text( + "model:\n" + " default: old-model\n" + "custom_providers:\n" + "- name: Example Provider\n" + " base_url: https://api.example-provider.test/v1\n" + " key_env: EXAMPLE_PROVIDER_API_KEY\n" + " model: qwen3.6-35b-fast\n" + ) + monkeypatch.setenv("EXAMPLE_PROVIDER_API_KEY", "sk-live-example-provider") + + provider_info = { + "name": "Example Provider", + "base_url": "https://api.example-provider.test/v1", + "api_key": "", + "key_env": "EXAMPLE_PROVIDER_API_KEY", + "model": "qwen3.6-35b-fast", + } + + with patch("hermes_cli.models.fetch_api_models", return_value=["qwen3.6-35b-fast"]), \ + patch.dict("sys.modules", {"simple_term_menu": None}), \ + patch("builtins.input", return_value="1"), \ + patch("builtins.print"): + _model_flow_named_custom({}, provider_info) + + config = yaml.safe_load(config_path.read_text()) or {} + assert config["model"]["api_key"] == "${EXAMPLE_PROVIDER_API_KEY}" + assert config["custom_providers"][0]["key_env"] == "EXAMPLE_PROVIDER_API_KEY" + assert "sk-live-example-provider" not in config_path.read_text()