diff --git a/tests/test_tui_gateway_server.py b/tests/test_tui_gateway_server.py index a652cb86..0c626366 100644 --- a/tests/test_tui_gateway_server.py +++ b/tests/test_tui_gateway_server.py @@ -59,6 +59,28 @@ def test_write_json_returns_false_on_broken_pipe(monkeypatch): assert server.write_json({"ok": True}) is False +def test_dispatch_rejects_non_object_request(): + resp = server.dispatch([]) + + assert resp == { + "jsonrpc": "2.0", + "id": None, + "error": {"code": -32600, "message": "invalid request: expected an object"}, + } + + +def test_dispatch_rejects_non_object_params(): + resp = server.dispatch( + {"id": "1", "method": "session.create", "params": []} + ) + + assert resp == { + "jsonrpc": "2.0", + "id": "1", + "error": {"code": -32602, "message": "invalid params: expected an object"}, + } + + def test_load_enabled_toolsets_prefers_tui_env(monkeypatch): monkeypatch.setenv("HERMES_TUI_TOOLSETS", "web, terminal, ,memory") diff --git a/tui_gateway/server.py b/tui_gateway/server.py index 47f25e7e..f5035495 100644 --- a/tui_gateway/server.py +++ b/tui_gateway/server.py @@ -417,11 +417,35 @@ def method(name: str): return dec +def _normalize_request(req: Any) -> tuple[Any, str, dict] | dict: + """Validate a JSON-RPC request enough for safe local dispatch.""" + if not isinstance(req, dict): + return _err(None, -32600, "invalid request: expected an object") + + rid = req.get("id") + method = req.get("method") + if not isinstance(method, str) or not method: + return _err(rid, -32600, "invalid request: method must be a non-empty string") + + params = req.get("params", {}) + if params is None: + params = {} + elif not isinstance(params, dict): + return _err(rid, -32602, "invalid params: expected an object") + + return rid, method, params + + def handle_request(req: dict) -> dict | None: - fn = _methods.get(req.get("method", "")) + normalized = _normalize_request(req) + if isinstance(normalized, dict): + return normalized + + rid, method, params = normalized + fn = _methods.get(method) if not fn: - return _err(req.get("id"), -32601, f"unknown method: {req.get('method')}") - return fn(req.get("id"), req.get("params", {})) + return _err(rid, -32601, f"unknown method: {method}") + return fn(rid, params) def dispatch(req: dict, transport: Optional[Transport] = None) -> dict | None: @@ -439,7 +463,12 @@ def dispatch(req: dict, transport: Optional[Transport] = None) -> dict | None: t = transport or _stdio_transport token = bind_transport(t) try: - if req.get("method") not in _LONG_HANDLERS: + normalized = _normalize_request(req) + if isinstance(normalized, dict): + return normalized + + _rid, method, _params = normalized + if method not in _LONG_HANDLERS: return handle_request(req) # Snapshot the context so the pool worker sees the bound transport.