diff --git a/mnemosyne/library/services/search.py b/mnemosyne/library/services/search.py index 634a57c..bd114f2 100644 --- a/mnemosyne/library/services/search.py +++ b/mnemosyne/library/services/search.py @@ -26,15 +26,32 @@ from .fusion import ImageSearchResult, SearchCandidate, reciprocal_rank_fusion logger = logging.getLogger(__name__) +# Workspace scoping clause appended to every search Cypher query. +# +# A request with workspace_id set returns ONLY that workspace's content. +# A request with workspace_id null returns ONLY global content (libraries +# with no workspace_id). There is no third mode. +_WORKSPACE_SCOPE_CLAUSE = ( + " AND ($workspace_id IS NULL AND lib.workspace_id IS NULL OR " + "lib.workspace_id = $workspace_id)" +) + + @dataclass class SearchRequest: - """Parameters for a search query.""" + """Parameters for a search query. + + Scope is single-mode: a request is either workspace-scoped (workspace_id + set) or global (workspace_id is None). There is no parameter combination + that returns both workspace and global content in one call. + """ query: str query_image: Optional[bytes] = None library_uid: Optional[str] = None library_type: Optional[str] = None collection_uid: Optional[str] = None + workspace_id: Optional[str] = None search_types: list[str] = field( default_factory=lambda: ["vector", "fulltext", "graph"] ) @@ -45,6 +62,18 @@ class SearchRequest: rerank: bool = True include_images: bool = True + def __post_init__(self): + # Normalize empty strings to None so "" doesn't slip through as + # truthy at the Cypher boundary. + if self.workspace_id == "": + self.workspace_id = None + if self.library_uid == "": + self.library_uid = None + if self.library_type == "": + self.library_type = None + if self.collection_uid == "": + self.collection_uid = None + @dataclass class SearchResponse: @@ -243,7 +272,8 @@ class SearchService: top_k = request.vector_top_k # Build Cypher with optional filtering - cypher = """ + cypher = ( + """ CALL db.index.vector.queryNodes('chunk_embedding_index', $top_k, $query_vector) YIELD node AS chunk, score MATCH (item:Item)-[:HAS_CHUNK]->(chunk) @@ -251,13 +281,17 @@ class SearchService: WHERE ($library_uid IS NULL OR lib.uid = $library_uid) AND ($library_type IS NULL OR lib.library_type = $library_type) AND ($collection_uid IS NULL OR col.uid = $collection_uid) + """ + + _WORKSPACE_SCOPE_CLAUSE + + """ RETURN chunk.uid AS chunk_uid, chunk.text_preview AS text_preview, chunk.chunk_s3_key AS chunk_s3_key, chunk.chunk_index AS chunk_index, item.uid AS item_uid, item.title AS item_title, lib.library_type AS library_type, score ORDER BY score DESC LIMIT $top_k - """ + """ + ) params = { "top_k": top_k, @@ -265,6 +299,7 @@ class SearchService: "library_uid": request.library_uid, "library_type": request.library_type, "collection_uid": request.collection_uid, + "workspace_id": request.workspace_id, } try: @@ -348,7 +383,8 @@ class SearchService: candidates: dict[str, SearchCandidate], ): """Search chunk_text_fulltext index and add to candidates dict.""" - cypher = """ + cypher = ( + """ CALL db.index.fulltext.queryNodes('chunk_text_fulltext', $query) YIELD node AS chunk, score MATCH (item:Item)-[:HAS_CHUNK]->(chunk) @@ -356,13 +392,17 @@ class SearchService: WHERE ($library_uid IS NULL OR lib.uid = $library_uid) AND ($library_type IS NULL OR lib.library_type = $library_type) AND ($collection_uid IS NULL OR col.uid = $collection_uid) + """ + + _WORKSPACE_SCOPE_CLAUSE + + """ RETURN chunk.uid AS chunk_uid, chunk.text_preview AS text_preview, chunk.chunk_s3_key AS chunk_s3_key, chunk.chunk_index AS chunk_index, item.uid AS item_uid, item.title AS item_title, lib.library_type AS library_type, score ORDER BY score DESC LIMIT $top_k - """ + """ + ) params = { "query": request.query, @@ -370,6 +410,7 @@ class SearchService: "library_uid": request.library_uid, "library_type": request.library_type, "collection_uid": request.collection_uid, + "workspace_id": request.workspace_id, } try: @@ -402,7 +443,8 @@ class SearchService: candidates: dict[str, SearchCandidate], ): """Search concept_name_fulltext and traverse to chunks.""" - cypher = """ + cypher = ( + """ CALL db.index.fulltext.queryNodes('concept_name_fulltext', $query) YIELD node AS concept, score AS concept_score MATCH (chunk:Chunk)-[:MENTIONS]->(concept) @@ -410,6 +452,9 @@ class SearchService: MATCH (lib:Library)-[:CONTAINS]->(:Collection)-[:CONTAINS]->(item) WHERE ($library_uid IS NULL OR lib.uid = $library_uid) AND ($library_type IS NULL OR lib.library_type = $library_type) + """ + + _WORKSPACE_SCOPE_CLAUSE + + """ RETURN chunk.uid AS chunk_uid, chunk.text_preview AS text_preview, chunk.chunk_s3_key AS chunk_s3_key, chunk.chunk_index AS chunk_index, item.uid AS item_uid, item.title AS item_title, @@ -417,13 +462,15 @@ class SearchService: concept_score * 0.8 AS score ORDER BY score DESC LIMIT $top_k - """ + """ + ) params = { "query": request.query, "top_k": top_k, "library_uid": request.library_uid, "library_type": request.library_type, + "workspace_id": request.workspace_id, } try: @@ -465,7 +512,8 @@ class SearchService: """ start = time.time() - cypher = """ + cypher = ( + """ CALL db.index.fulltext.queryNodes('concept_name_fulltext', $query) YIELD node AS concept, score AS concept_score WITH concept, concept_score @@ -476,6 +524,9 @@ class SearchService: MATCH (lib:Library)-[:CONTAINS]->(:Collection)-[:CONTAINS]->(item) WHERE ($library_uid IS NULL OR lib.uid = $library_uid) AND ($library_type IS NULL OR lib.library_type = $library_type) + """ + + _WORKSPACE_SCOPE_CLAUSE + + """ WITH chunk, item, lib, max(concept_score) AS score, collect(DISTINCT concept.name)[..5] AS concept_names @@ -486,13 +537,15 @@ class SearchService: score, concept_names ORDER BY score DESC LIMIT $limit - """ + """ + ) params = { "query": request.query, "limit": request.fulltext_top_k, "library_uid": request.library_uid, "library_type": request.library_type, + "workspace_id": request.workspace_id, } try: @@ -550,7 +603,8 @@ class SearchService: """ start = time.time() - cypher = """ + cypher = ( + """ CALL db.index.vector.queryNodes('image_embedding_index', $top_k, $query_vector) YIELD node AS emb_node, score MATCH (img:Image)-[:HAS_EMBEDDING]->(emb_node) @@ -558,19 +612,24 @@ class SearchService: MATCH (lib:Library)-[:CONTAINS]->(:Collection)-[:CONTAINS]->(item) WHERE ($library_uid IS NULL OR lib.uid = $library_uid) AND ($library_type IS NULL OR lib.library_type = $library_type) + """ + + _WORKSPACE_SCOPE_CLAUSE + + """ RETURN img.uid AS image_uid, img.image_type AS image_type, img.description AS description, img.s3_key AS s3_key, item.uid AS item_uid, item.title AS item_title, score ORDER BY score DESC LIMIT 10 - """ + """ + ) params = { "top_k": 10, "query_vector": query_vector, "library_uid": request.library_uid, "library_type": request.library_type, + "workspace_id": request.workspace_id, } try: diff --git a/mnemosyne/library/tests/test_search_scoping.py b/mnemosyne/library/tests/test_search_scoping.py new file mode 100644 index 0000000..95a1b64 --- /dev/null +++ b/mnemosyne/library/tests/test_search_scoping.py @@ -0,0 +1,71 @@ +""" +Tests for workspace scoping in SearchRequest and the Cypher scope clause. + +These exercise the dataclass-level normalization and the construction +of Cypher parameter dicts. The actual Cypher execution against Neo4j +is validated by the manual end-to-end test plan. +""" + +from django.test import TestCase + +from library.services.search import _WORKSPACE_SCOPE_CLAUSE, SearchRequest + + +class SearchRequestScopingTests(TestCase): + """SearchRequest workspace_id behavior.""" + + def test_default_workspace_id_is_none(self): + req = SearchRequest(query="hello") + self.assertIsNone(req.workspace_id) + + def test_explicit_workspace_id_preserved(self): + req = SearchRequest(query="hello", workspace_id="ws_abc") + self.assertEqual(req.workspace_id, "ws_abc") + + def test_empty_string_workspace_id_normalized_to_none(self): + """Empty strings must NOT slip through as a truthy filter at the Cypher boundary.""" + req = SearchRequest(query="hello", workspace_id="") + self.assertIsNone(req.workspace_id) + + def test_empty_string_library_uid_normalized_to_none(self): + req = SearchRequest(query="hello", library_uid="") + self.assertIsNone(req.library_uid) + + def test_empty_string_library_type_normalized_to_none(self): + req = SearchRequest(query="hello", library_type="") + self.assertIsNone(req.library_type) + + def test_empty_string_collection_uid_normalized_to_none(self): + req = SearchRequest(query="hello", collection_uid="") + self.assertIsNone(req.collection_uid) + + +class WorkspaceScopeClauseTests(TestCase): + """Sanity checks on the Cypher snippet itself. + + The clause must produce two distinct, non-overlapping result sets: + 1. workspace_id IS NULL → only global libraries (lib.workspace_id IS NULL) + 2. workspace_id = X → only libraries with workspace_id = X + + A "leaks both" bug would be a Cypher OR that fails to bracket properly. + Verifying the literal string here is a cheap regression guard against + refactors that accidentally change the operator precedence. + """ + + def test_clause_references_lib_workspace_id(self): + self.assertIn("lib.workspace_id", _WORKSPACE_SCOPE_CLAUSE) + + def test_clause_references_workspace_id_param(self): + self.assertIn("$workspace_id", _WORKSPACE_SCOPE_CLAUSE) + + def test_clause_handles_both_modes(self): + """Both 'IS NULL' and '=' branches must be present.""" + self.assertIn("IS NULL", _WORKSPACE_SCOPE_CLAUSE) + self.assertIn("=", _WORKSPACE_SCOPE_CLAUSE) + + def test_clause_starts_with_AND_so_it_appends_safely(self): + """The clause is appended to existing WHERE filters.""" + self.assertTrue( + _WORKSPACE_SCOPE_CLAUSE.lstrip().startswith("AND"), + f"Clause must start with AND: {_WORKSPACE_SCOPE_CLAUSE!r}", + ) diff --git a/mnemosyne/mcp_server/server.py b/mnemosyne/mcp_server/server.py index 390a38a..0f4815e 100644 --- a/mnemosyne/mcp_server/server.py +++ b/mnemosyne/mcp_server/server.py @@ -5,7 +5,11 @@ from __future__ import annotations from fastmcp import FastMCP from .auth import MCPAuthMiddleware -from .tools import register_discovery_tools, register_search_tools +from .tools import ( + register_discovery_tools, + register_health_tools, + register_search_tools, +) INSTRUCTIONS = """\ Mnemosyne is a content-type-aware, multimodal knowledge base. It indexes @@ -23,6 +27,8 @@ shapes how content is chunked, embedded, and re-ranked: - film — Scripts, synopses, stills. - art — Catalogs, descriptions, artwork itself. - journal — Personal entries; temporal/reflective. +- business — Proposals, marketing, sales, strategy. Commercial context. +- finance — Statements, tax, market commentary. Quote figures exactly. Tools: - search Hybrid retrieval. Filter by library_uid, library_type, @@ -32,6 +38,7 @@ Tools: - list_libraries Discover libraries (and their library_type). - list_collections Discover collections, optionally per library. - list_items Discover indexed items (documents). +- get_health Health check (used by Pallas/Daedalus pollers). Workflow: list_libraries → search(query, library_type=...) → get_chunk(chunk_uid) when the preview isn't enough. The calling LLM is responsible for synthesis @@ -47,6 +54,7 @@ def build_server() -> FastMCP: mcp.add_middleware(MCPAuthMiddleware()) register_search_tools(mcp) register_discovery_tools(mcp) + register_health_tools(mcp) return mcp diff --git a/mnemosyne/mcp_server/tests/test_server.py b/mnemosyne/mcp_server/tests/test_server.py index 0110114..70b2e0c 100644 --- a/mnemosyne/mcp_server/tests/test_server.py +++ b/mnemosyne/mcp_server/tests/test_server.py @@ -7,7 +7,14 @@ from django.test import TestCase from mcp_server.server import mcp -EXPECTED_TOOLS = {"search", "get_chunk", "list_libraries", "list_collections", "list_items"} +EXPECTED_TOOLS = { + "search", + "get_chunk", + "list_libraries", + "list_collections", + "list_items", + "get_health", +} class ServerRegistrationTest(TestCase): diff --git a/mnemosyne/mcp_server/tools/__init__.py b/mnemosyne/mcp_server/tools/__init__.py index 1160a60..c3b68a5 100644 --- a/mnemosyne/mcp_server/tools/__init__.py +++ b/mnemosyne/mcp_server/tools/__init__.py @@ -1,4 +1,9 @@ from .discovery import register_discovery_tools +from .health import register_health_tools from .search import register_search_tools -__all__ = ["register_search_tools", "register_discovery_tools"] +__all__ = [ + "register_search_tools", + "register_discovery_tools", + "register_health_tools", +] diff --git a/mnemosyne/mcp_server/tools/discovery.py b/mnemosyne/mcp_server/tools/discovery.py index aeadf5e..413c2f5 100644 --- a/mnemosyne/mcp_server/tools/discovery.py +++ b/mnemosyne/mcp_server/tools/discovery.py @@ -20,16 +20,21 @@ def _clamp(limit: int) -> int: def register_discovery_tools(mcp): @mcp.tool - async def list_libraries(limit: int = DEFAULT_LIMIT, offset: int = 0) -> dict[str, Any]: + async def list_libraries( + limit: int = DEFAULT_LIMIT, + offset: int = 0, + # System-injected; deliberately absent from the docstring. + workspace_id: str | None = None, + ) -> dict[str, Any]: """List Mnemosyne libraries. Each library has a content-aware library_type - (fiction, nonfiction, technical, music, film, art, journal) that drives - chunking, embedding, and re-ranking. Returns uid, name, library_type, - description for each library — use the uid or library_type to scope a - subsequent search. + (fiction, nonfiction, technical, music, film, art, journal, business, + finance) that drives chunking, embedding, and re-ranking. Returns uid, + name, library_type, description for each library — use the uid or + library_type to scope a subsequent search. """ with record_tool_call("list_libraries"): return await sync_to_async(_query_libraries, thread_sensitive=True)( - _clamp(limit), max(offset, 0) + _clamp(limit), max(offset, 0), workspace_id ) @mcp.tool @@ -37,6 +42,8 @@ def register_discovery_tools(mcp): library_uid: str | None = None, limit: int = DEFAULT_LIMIT, offset: int = 0, + # System-injected; deliberately absent from the docstring. + workspace_id: str | None = None, ) -> dict[str, Any]: """List collections, optionally filtered by parent library_uid. Collections group related items inside a library (e.g. a series of novels, @@ -45,7 +52,7 @@ def register_discovery_tools(mcp): """ with record_tool_call("list_collections"): return await sync_to_async(_query_collections, thread_sensitive=True)( - library_uid, _clamp(limit), max(offset, 0) + library_uid, _clamp(limit), max(offset, 0), workspace_id ) @mcp.tool @@ -54,6 +61,8 @@ def register_discovery_tools(mcp): library_uid: str | None = None, limit: int = DEFAULT_LIMIT, offset: int = 0, + # System-injected; deliberately absent from the docstring. + workspace_id: str | None = None, ) -> dict[str, Any]: """List items (the indexed documents/files), optionally filtered by collection_uid or library_uid. Returns uid, title, item_type, file_type, @@ -63,17 +72,27 @@ def register_discovery_tools(mcp): """ with record_tool_call("list_items"): return await sync_to_async(_query_items, thread_sensitive=True)( - collection_uid, library_uid, _clamp(limit), max(offset, 0) + collection_uid, library_uid, _clamp(limit), max(offset, 0), workspace_id ) -def _query_libraries(limit: int, offset: int) -> dict[str, Any]: +_WORKSPACE_SCOPE = ( + "($workspace_id IS NULL AND l.workspace_id IS NULL OR " + "l.workspace_id = $workspace_id)" +) + + +def _query_libraries( + limit: int, offset: int, workspace_id: str | None = None +) -> dict[str, Any]: from neomodel import db rows, _ = db.cypher_query( - "MATCH (l:Library) RETURN l.uid, l.name, l.library_type, l.description " + "MATCH (l:Library) " + f"WHERE {_WORKSPACE_SCOPE} " + "RETURN l.uid, l.name, l.library_type, l.description " "ORDER BY l.name SKIP $offset LIMIT $limit", - {"offset": offset, "limit": limit}, + {"offset": offset, "limit": limit, "workspace_id": workspace_id}, ) return { "libraries": [ @@ -91,24 +110,33 @@ def _query_libraries(limit: int, offset: int) -> dict[str, Any]: def _query_collections( - library_uid: str | None, limit: int, offset: int + library_uid: str | None, limit: int, offset: int, + workspace_id: str | None = None, ) -> dict[str, Any]: from neomodel import db if library_uid: cypher = ( "MATCH (l:Library {uid: $library_uid})-[:CONTAINS]->(c:Collection) " + f"WHERE {_WORKSPACE_SCOPE} " "RETURN c.uid, c.name, c.description, l.uid, l.name " "ORDER BY c.name SKIP $offset LIMIT $limit" ) - params = {"library_uid": library_uid, "offset": offset, "limit": limit} + params = { + "library_uid": library_uid, "offset": offset, "limit": limit, + "workspace_id": workspace_id, + } else: cypher = ( "MATCH (l:Library)-[:CONTAINS]->(c:Collection) " + f"WHERE {_WORKSPACE_SCOPE} " "RETURN c.uid, c.name, c.description, l.uid, l.name " "ORDER BY l.name, c.name SKIP $offset LIMIT $limit" ) - params = {"offset": offset, "limit": limit} + params = { + "offset": offset, "limit": limit, + "workspace_id": workspace_id, + } rows, _ = db.cypher_query(cypher, params) return { @@ -132,11 +160,15 @@ def _query_items( library_uid: str | None, limit: int, offset: int, + workspace_id: str | None = None, ) -> dict[str, Any]: from neomodel import db - where = [] - params: dict[str, Any] = {"offset": offset, "limit": limit} + where = [_WORKSPACE_SCOPE] + params: dict[str, Any] = { + "offset": offset, "limit": limit, + "workspace_id": workspace_id, + } if collection_uid: where.append("c.uid = $collection_uid") params["collection_uid"] = collection_uid @@ -144,7 +176,7 @@ def _query_items( where.append("l.uid = $library_uid") params["library_uid"] = library_uid - where_clause = ("WHERE " + " AND ".join(where)) if where else "" + where_clause = "WHERE " + " AND ".join(where) cypher = ( "MATCH (l:Library)-[:CONTAINS]->(c:Collection)-[:CONTAINS]->(i:Item) " f"{where_clause} " diff --git a/mnemosyne/mcp_server/tools/health.py b/mnemosyne/mcp_server/tools/health.py new file mode 100644 index 0000000..905df65 --- /dev/null +++ b/mnemosyne/mcp_server/tools/health.py @@ -0,0 +1,123 @@ +"""Health-check MCP tool — used by Pallas/Daedalus health pollers. + +Per the Pallas health spec, returns one of: +- ok — all dependencies reachable +- degraded — non-critical dependency unhealthy (chat allowed) +- error — critical dependency unhealthy (chat blocked) + +The tool is intercepted by the FastMCP server and never invokes an LLM — +it executes synchronously against Neo4j, S3, and the embedding model +endpoint, and returns within the poller's timeout. +""" + +from __future__ import annotations + +import time +from typing import Any + +from asgiref.sync import sync_to_async + +from ..metrics import record_tool_call + + +def register_health_tools(mcp): + @mcp.tool + async def get_health() -> dict[str, Any]: + """Health check for Mnemosyne. + + Returns a status object compatible with the Pallas health spec: + {status: "ok"|"degraded"|"error", checks: {neo4j, s3, embedding}}. + """ + with record_tool_call("get_health"): + return await sync_to_async(_run_health_check, thread_sensitive=True)() + + +def _run_health_check() -> dict[str, Any]: + """Synchronous health check across Neo4j, S3, and embedding model.""" + checks: dict[str, dict[str, Any]] = {} + + checks["neo4j"] = _check_neo4j() + checks["s3"] = _check_s3() + checks["embedding"] = _check_embedding_model() + + # Aggregate status: error if any critical check failed; degraded if a + # non-critical check failed; ok otherwise. + if checks["neo4j"]["status"] == "error" or checks["s3"]["status"] == "error": + status = "error" + elif any(c["status"] != "ok" for c in checks.values()): + status = "degraded" + else: + status = "ok" + + return { + "status": status, + "checks": checks, + } + + +def _check_neo4j() -> dict[str, Any]: + start = time.time() + try: + from neomodel import db + + db.cypher_query("RETURN 1") + return { + "status": "ok", + "duration_ms": round((time.time() - start) * 1000, 1), + } + except Exception as exc: + return { + "status": "error", + "error": str(exc), + "duration_ms": round((time.time() - start) * 1000, 1), + } + + +def _check_s3() -> dict[str, Any]: + start = time.time() + try: + from django.core.files.storage import default_storage + + # `exists` on a path that won't exist is the cheapest round-trip + # we have. It returns False rather than raising on most backends. + default_storage.exists("__healthcheck__") + return { + "status": "ok", + "duration_ms": round((time.time() - start) * 1000, 1), + } + except Exception as exc: + return { + "status": "error", + "error": str(exc), + "duration_ms": round((time.time() - start) * 1000, 1), + } + + +def _check_embedding_model() -> dict[str, Any]: + """Soft check: confirm a system embedding model is configured. + + We don't hit the model endpoint here — that would burn GPU time on + every poll. The poller-level check is "is a model registered." + """ + start = time.time() + try: + from llm_manager.models import LLMModel + + model = LLMModel.get_system_embedding_model() + if model is None: + return { + "status": "degraded", + "error": "no system embedding model configured", + "duration_ms": round((time.time() - start) * 1000, 1), + } + return { + "status": "ok", + "model": model.name, + "duration_ms": round((time.time() - start) * 1000, 1), + } + except Exception as exc: + return { + "status": "degraded", + "error": str(exc), + "duration_ms": round((time.time() - start) * 1000, 1), + } diff --git a/mnemosyne/mcp_server/tools/search.py b/mnemosyne/mcp_server/tools/search.py index f75e8e5..d428eb9 100644 --- a/mnemosyne/mcp_server/tools/search.py +++ b/mnemosyne/mcp_server/tools/search.py @@ -27,13 +27,19 @@ def register_search_tools(mcp): rerank: bool = True, include_images: bool = True, search_types: list[str] | None = None, + # workspace_id is system-injected by Daedalus's chat path. It is + # intentionally absent from the docstring so the calling LLM is + # never told it exists. Whatever value the LLM produces here is + # overwritten by Daedalus before the call reaches Mnemosyne. + workspace_id: str | None = None, ctx: Context | None = None, ) -> dict[str, Any]: """Hybrid retrieval over Mnemosyne: vector + full-text + concept-graph candidates fused by RRF and optionally re-ranked by Synesis. Filters: library_uid (exact library), library_type (one of fiction, - nonfiction, technical, music, film, art, journal), or collection_uid. + nonfiction, technical, music, film, art, journal, business, finance), + or collection_uid. Set rerank=False to skip re-ranking. search_types defaults to all three. Returns ranked candidates with chunk_uid (use get_chunk for full text), @@ -49,6 +55,7 @@ def register_search_tools(mcp): library_uid=library_uid, library_type=library_type, collection_uid=collection_uid, + workspace_id=workspace_id, limit=limit, rerank=rerank, include_images=include_images, @@ -56,7 +63,12 @@ def register_search_tools(mcp): ) @mcp.tool - async def get_chunk(chunk_uid: str, ctx: Context | None = None) -> dict[str, Any]: + async def get_chunk( + chunk_uid: str, + # System-injected; deliberately absent from the docstring. + workspace_id: str | None = None, + ctx: Context | None = None, + ) -> dict[str, Any]: """Fetch the full text of a chunk by its uid (typically obtained from `search`). Returns the chunk text plus parent item context: chunk_uid, chunk_index, @@ -64,11 +76,13 @@ def register_search_tools(mcp): text_preview from `search` isn't enough. """ with record_tool_call("get_chunk"): - return await sync_to_async(_load_chunk, thread_sensitive=True)(chunk_uid) + return await sync_to_async(_load_chunk, thread_sensitive=True)( + chunk_uid, workspace_id + ) -def _run_search(*, user, query, library_uid, library_type, collection_uid, limit, - rerank, include_images, search_types) -> dict[str, Any]: +def _run_search(*, user, query, library_uid, library_type, collection_uid, + workspace_id, limit, rerank, include_images, search_types) -> dict[str, Any]: from library.services.search import SearchRequest, SearchService req = SearchRequest( @@ -76,6 +90,7 @@ def _run_search(*, user, query, library_uid, library_type, collection_uid, limit library_uid=library_uid, library_type=library_type, collection_uid=collection_uid, + workspace_id=workspace_id, search_types=search_types, limit=limit, vector_top_k=getattr(settings, "SEARCH_VECTOR_TOP_K", 50), @@ -97,15 +112,17 @@ def _run_search(*, user, query, library_uid, library_type, collection_uid, limit } -def _load_chunk(chunk_uid: str) -> dict[str, Any]: +def _load_chunk(chunk_uid: str, workspace_id: str | None = None) -> dict[str, Any]: from neomodel import db rows, _ = db.cypher_query( "MATCH (l:Library)-[:CONTAINS]->(:Collection)-[:CONTAINS]->" "(i:Item)-[:HAS_CHUNK]->(c:Chunk {uid: $uid}) " + "WHERE ($workspace_id IS NULL AND l.workspace_id IS NULL OR " + " l.workspace_id = $workspace_id) " "RETURN c.uid, c.chunk_index, c.chunk_s3_key, " "i.uid, i.title, l.library_type LIMIT 1", - {"uid": chunk_uid}, + {"uid": chunk_uid, "workspace_id": workspace_id}, ) if not rows: raise ValueError(f"Chunk not found: {chunk_uid}")