feat: sort tool search results by score and add corresponding unit test
This commit is contained in:
parent
b867171291
commit
9bdfcd1b93
@ -509,19 +509,24 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
||||
result = resp.get("result", {})
|
||||
|
||||
# Format results for the model — keep it concise
|
||||
formatted = []
|
||||
scored_entries = []
|
||||
for ctx_type in ("memories", "resources", "skills"):
|
||||
items = result.get(ctx_type, [])
|
||||
for item in items:
|
||||
raw_score = item.get("score")
|
||||
sort_score = raw_score if raw_score is not None else 0.0
|
||||
entry = {
|
||||
"uri": item.get("uri", ""),
|
||||
"type": ctx_type.rstrip("s"),
|
||||
"score": round(item.get("score", 0), 3),
|
||||
"score": round(raw_score, 3) if raw_score is not None else 0.0,
|
||||
"abstract": item.get("abstract", ""),
|
||||
}
|
||||
if item.get("relations"):
|
||||
entry["related"] = [r.get("uri") for r in item["relations"][:3]]
|
||||
formatted.append(entry)
|
||||
scored_entries.append((sort_score, entry))
|
||||
|
||||
scored_entries.sort(key=lambda x: x[0], reverse=True)
|
||||
formatted = [entry for _, entry in scored_entries]
|
||||
|
||||
return json.dumps({
|
||||
"results": formatted,
|
||||
|
||||
62
tests/plugins/memory/test_openviking_provider.py
Normal file
62
tests/plugins/memory/test_openviking_provider.py
Normal file
@ -0,0 +1,62 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from plugins.memory.openviking import OpenVikingMemoryProvider
|
||||
|
||||
|
||||
def test_tool_search_sorts_by_raw_score_across_buckets():
|
||||
provider = OpenVikingMemoryProvider()
|
||||
provider._client = MagicMock()
|
||||
provider._client.post.return_value = {
|
||||
"result": {
|
||||
"memories": [
|
||||
{"uri": "viking://memories/1", "score": 0.9003, "abstract": "memory result"},
|
||||
],
|
||||
"resources": [
|
||||
{"uri": "viking://resources/1", "score": 0.9004, "abstract": "resource result"},
|
||||
],
|
||||
"skills": [
|
||||
{"uri": "viking://skills/1", "score": 0.8999, "abstract": "skill result"},
|
||||
],
|
||||
"total": 3,
|
||||
}
|
||||
}
|
||||
|
||||
result = json.loads(provider._tool_search({"query": "ranking"}))
|
||||
|
||||
assert [entry["uri"] for entry in result["results"]] == [
|
||||
"viking://resources/1",
|
||||
"viking://memories/1",
|
||||
"viking://skills/1",
|
||||
]
|
||||
assert [entry["score"] for entry in result["results"]] == [0.9, 0.9, 0.9]
|
||||
assert result["total"] == 3
|
||||
|
||||
|
||||
def test_tool_search_sorts_missing_raw_score_after_negative_scores():
|
||||
provider = OpenVikingMemoryProvider()
|
||||
provider._client = MagicMock()
|
||||
provider._client.post.return_value = {
|
||||
"result": {
|
||||
"memories": [
|
||||
{"uri": "viking://memories/missing", "abstract": "missing score"},
|
||||
],
|
||||
"resources": [
|
||||
{"uri": "viking://resources/negative", "score": -0.25, "abstract": "negative score"},
|
||||
],
|
||||
"skills": [
|
||||
{"uri": "viking://skills/positive", "score": 0.1, "abstract": "positive score"},
|
||||
],
|
||||
"total": 3,
|
||||
}
|
||||
}
|
||||
|
||||
result = json.loads(provider._tool_search({"query": "ranking"}))
|
||||
|
||||
assert [entry["uri"] for entry in result["results"]] == [
|
||||
"viking://skills/positive",
|
||||
"viking://memories/missing",
|
||||
"viking://resources/negative",
|
||||
]
|
||||
assert [entry["score"] for entry in result["results"]] == [0.1, 0.0, -0.25]
|
||||
assert result["total"] == 3
|
||||
Loading…
Reference in New Issue
Block a user