File size: 5,688 Bytes
60b97da
24e05bd
60b97da
 
 
 
24e05bd
 
60b97da
 
24e05bd
 
60b97da
24e05bd
 
 
 
60b97da
24e05bd
60b97da
 
24e05bd
 
 
 
 
 
 
 
60b97da
 
24e05bd
ee4d71c
60b97da
24e05bd
 
 
 
60b97da
 
24e05bd
60b97da
 
 
 
 
 
 
 
24e05bd
 
60b97da
24e05bd
 
 
 
 
 
 
 
 
 
 
 
60b97da
 
 
24e05bd
 
60b97da
 
24e05bd
 
 
 
 
60b97da
 
 
 
 
 
 
 
 
 
 
 
 
 
24e05bd
 
 
 
 
 
60b97da
 
24e05bd
 
 
 
 
 
 
 
 
 
 
 
60b97da
 
24e05bd
60b97da
 
24e05bd
 
 
 
 
60b97da
 
24e05bd
 
 
60b97da
 
24e05bd
35c1d2c
 
24e05bd
 
 
 
35c1d2c
60b97da
 
24e05bd
60b97da
 
24e05bd
60b97da
24e05bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
from pathlib import Path
from typing import List, Optional, Tuple
from uuid import uuid4

import numpy as np
from chromadb import Client
from chromadb.config import Settings


class ChromaVectorStore:
    """Vector store backed by a single ChromaDB collection using cosine distance.

    Embeddings are computed by the caller (the collection is created with
    ``embedding_function=None``). Each item's ``"content"`` is stored as the
    Chroma document text and the remaining fields as Chroma metadata.

    Configuration is read from the environment:

    * ``CHROMA_COLLECTION`` -- collection name (default ``"repo_qa_chunks"``).
    * ``CHROMA_UPSERT_BATCH_SIZE`` -- rows per ``collection.add`` call
      (default ``64``, clamped to at least 1).
    * ``CHROMA_PATH`` -- persistence directory; when set it takes precedence
      over the ``index_path`` constructor argument (fallback ``"./data/chroma"``).
    """

    def __init__(self, embedding_dim: int, index_path: Optional[str] = None, persist: bool = True):
        """Create the Chroma client and open (or create) the collection.

        Args:
            embedding_dim: Vector dimensionality; reported by ``get_stats``
                but not validated against inserted data.
            index_path: Persistence directory used only when ``CHROMA_PATH``
                is unset.
            persist: When True, store data on disk under ``persist_path``;
                otherwise use a client built without persistence settings.
        """
        self.embedding_dim = embedding_dim
        self.collection_name = os.getenv("CHROMA_COLLECTION", "repo_qa_chunks")
        # Clamp so a 0/negative env value cannot produce a zero-step range().
        self.upsert_batch_size = max(1, int(os.getenv("CHROMA_UPSERT_BATCH_SIZE", "64")))
        # NOTE: the environment variable wins over an explicit index_path.
        self.persist_path = os.getenv("CHROMA_PATH", index_path or "./data/chroma")
        self.persist = persist
        self.client = self._create_client()
        self.collection = self._ensure_collection()

    def _create_client(self):
        """Build the Chroma client; creates ``persist_path`` when persisting."""
        if self.persist:
            Path(self.persist_path).mkdir(parents=True, exist_ok=True)
            return Client(
                Settings(
                    is_persistent=True,
                    persist_directory=self.persist_path,
                    anonymized_telemetry=False,
                )
            )

        return Client(Settings(anonymized_telemetry=False))

    def _ensure_collection(self):
        """Return the configured collection, creating it with cosine space if missing."""
        return self.client.get_or_create_collection(
            name=self.collection_name,
            # Embeddings are always supplied explicitly by the caller.
            embedding_function=None,
            metadata={"hnsw:space": "cosine"},
        )

    def add_embeddings(self, embeddings: np.ndarray, metadata: List[dict]) -> List[str]:
        """Insert vectors with their metadata in batches and return the new ids.

        Args:
            embeddings: A single vector (1-D) or a matrix with one row per
                item. Any array-like accepted by ``np.asarray`` works; values
                are cast to float32.
            metadata: One dict per embedding row. Its ``"content"`` value
                becomes the stored document text; the remaining fields are
                sanitized into Chroma-compatible metadata.

        Returns:
            Freshly generated hex UUIDs in input order. Each id is also
            written into the stored metadata under ``"id"``. An empty list is
            returned when ``embeddings`` is empty.

        Raises:
            ValueError: If the number of embedding rows differs from
                ``len(metadata)``.
        """
        embeddings = np.asarray(embeddings, dtype="float32")
        if embeddings.size == 0:
            return []
        if embeddings.ndim == 1:
            embeddings = embeddings.reshape(1, -1)

        total_points = len(metadata)
        if embeddings.shape[0] != total_points:
            # Without this check extra embeddings would be silently dropped,
            # or Chroma would fail on mismatched batch lengths.
            raise ValueError(
                f"embeddings has {embeddings.shape[0]} rows but metadata has "
                f"{total_points} entries"
            )

        ids = [uuid4().hex for _ in range(total_points)]
        batch_size = self.upsert_batch_size
        total_batches = (total_points + batch_size - 1) // batch_size

        for batch_number, start in enumerate(range(0, total_points, batch_size), start=1):
            end = start + batch_size
            batch_ids = ids[start:end]
            batch_metadata = []
            batch_documents = []

            for point_id, meta in zip(batch_ids, metadata[start:end]):
                payload = self._sanitize_metadata(meta)
                payload["id"] = point_id
                batch_metadata.append(payload)
                batch_documents.append(str(meta.get("content") or ""))

            print(
                f"[chroma] Adding batch {batch_number}/{total_batches} "
                f"points={len(batch_ids)} progress={start}/{total_points}",
                flush=True,
            )
            self.collection.add(
                ids=batch_ids,
                embeddings=embeddings[start:end].tolist(),
                metadatas=batch_metadata,
                documents=batch_documents,
            )

        return ids

    def search(
        self,
        query_embedding: np.ndarray,
        k: int = 10,
        repo_filter: Optional[int] = None,
    ) -> List[Tuple[float, dict]]:
        """Return up to ``k`` nearest items as ``(score, payload)`` pairs.

        Args:
            query_embedding: Query vector (1-D) or matrix; only the first row
                of a matrix is used.
            k: Maximum number of results; ``k <= 0`` returns ``[]`` without
                querying.
            repo_filter: When given, restrict results to items whose
                ``repository_id`` metadata equals this value.

        Returns:
            A list of ``(score, payload)`` pairs in the order Chroma returns
            them. ``score`` is ``1 - distance`` clamped to ``[0, 1]``.
            ``payload`` is the stored metadata plus ``"id"`` and ``"content"``.
        """
        if k <= 0:
            return []

        query = np.asarray(query_embedding, dtype="float32")
        if query.ndim == 1:
            query = query.reshape(1, -1)

        where = (
            {"repository_id": self._to_native(repo_filter)}
            if repo_filter is not None
            else None
        )
        results = self.collection.query(
            query_embeddings=[query[0].tolist()],
            n_results=k,
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        # Chroma returns one list per query embedding; we sent exactly one.
        ids = (results.get("ids") or [[]])[0]
        documents = (results.get("documents") or [[]])[0]
        metadatas = (results.get("metadatas") or [[]])[0]
        distances = (results.get("distances") or [[]])[0]

        hits = []
        for idx, document, meta, distance in zip(ids, documents, metadatas, distances):
            payload = dict(meta or {})
            payload["id"] = payload.get("id") or idx
            payload["content"] = document or ""
            hits.append((self._distance_to_score(distance), payload))
        return hits

    def remove_repository(self, repo_id: int):
        """Delete every item whose ``repository_id`` metadata equals ``repo_id``."""
        self.collection.delete(where={"repository_id": self._to_native(repo_id)})

    def clear(self):
        """Drop the collection (if it exists) and recreate it empty."""
        try:
            self.client.delete_collection(name=self.collection_name)
        except Exception:
            # Best-effort: the collection may not exist yet. The exception type
            # for a missing collection varies between chromadb releases, so a
            # narrower catch is not portable.
            pass
        self.collection = self._ensure_collection()

    def save(self):
        """Call the client's ``persist()`` when it exposes one; otherwise do nothing."""
        persist = getattr(self.client, "persist", None)
        if callable(persist):
            persist()

    def load(self):
        """Re-open (or create) the configured collection."""
        self.collection = self._ensure_collection()

    def keep_alive(self) -> dict:
        """Ping the client via ``heartbeat()`` when available and return ``get_stats()``."""
        heartbeat = getattr(self.client, "heartbeat", None)
        if callable(heartbeat):
            heartbeat()
        return self.get_stats()

    def get_stats(self) -> dict:
        """Return the vector count and configuration of this store."""
        return {
            "total_vectors": self.collection.count(),
            "embedding_dim": self.embedding_dim,
            "collection_name": self.collection_name,
            "persist_path": self.persist_path if self.persist else None,
        }

    @staticmethod
    def _to_native(value):
        """Convert NumPy numeric/bool scalars to the equivalent Python scalar.

        ``np.int64`` and ``np.float32`` are not instances of ``int``/``float``.
        Without conversion, such values would be stringified as metadata or
        rejected as filter values. Other values are returned unchanged.
        """
        if isinstance(value, (np.integer, np.floating, np.bool_)):
            return value.item()
        return value

    @staticmethod
    def _sanitize_metadata(meta: dict) -> dict:
        """Return a copy of ``meta`` restricted to str/int/float/bool values.

        ``"content"`` is dropped because it is stored as the document text.
        ``None`` becomes ``""``, NumPy scalars become native Python scalars,
        and any other non-primitive value is converted with ``str()``.
        """
        sanitized = {}
        for key, value in meta.items():
            if key == "content":
                continue
            value = ChromaVectorStore._to_native(value)
            if value is None:
                sanitized[key] = ""
            elif isinstance(value, (str, int, float, bool)):
                sanitized[key] = value
            else:
                sanitized[key] = str(value)
        return sanitized

    @staticmethod
    def _distance_to_score(distance: float) -> float:
        """Map a Chroma distance to a similarity score in ``[0, 1]``.

        Computes ``1 - distance`` clamped to ``[0, 1]``. ``None`` and NaN
        distances score 0.0.
        """
        if distance is None:
            return 0.0
        similarity = 1.0 - float(distance)
        if np.isnan(similarity):
            # Without this guard min/max would turn NaN into a perfect 1.0.
            return 0.0
        return max(0.0, min(1.0, similarity))