Spaces:
Running
Running
File size: 5,688 Bytes
60b97da 24e05bd 60b97da 24e05bd 60b97da 24e05bd 60b97da 24e05bd 60b97da 24e05bd 60b97da 24e05bd 60b97da 24e05bd ee4d71c 60b97da 24e05bd 60b97da 24e05bd 60b97da 24e05bd 60b97da 24e05bd 60b97da 24e05bd 60b97da 24e05bd 60b97da 24e05bd 60b97da 24e05bd 60b97da 24e05bd 60b97da 24e05bd 60b97da 24e05bd 60b97da 24e05bd 35c1d2c 24e05bd 35c1d2c 60b97da 24e05bd 60b97da 24e05bd 60b97da 24e05bd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 | import os
from pathlib import Path
from typing import List, Optional, Tuple
from uuid import uuid4
import numpy as np
from chromadb import Client
from chromadb.config import Settings
class ChromaVectorStore:
    """Vector store backed by a Chroma collection configured for cosine distance.

    Configuration is read from the environment when the store is constructed:

    * ``CHROMA_COLLECTION`` -- collection name (default ``"repo_qa_chunks"``).
    * ``CHROMA_UPSERT_BATCH_SIZE`` -- records sent per ``collection.add`` call
      (default ``64``, clamped to a minimum of 1).
    * ``CHROMA_PATH`` -- storage directory; when set it takes precedence over
      the ``index_path`` constructor argument.
    """

    def __init__(self, embedding_dim: int, index_path: Optional[str] = None, persist: bool = True):
        """Create the Chroma client and open (or create) the collection.

        Args:
            embedding_dim: Dimensionality reported by :meth:`get_stats`; it is
                stored but not enforced on the vectors passed in.
            index_path: Storage directory used when ``CHROMA_PATH`` is unset;
                falls back to ``"./data/chroma"``.
            persist: When true, a persistent client rooted at the storage
                directory is created; otherwise a default client is used.
        """
        self.embedding_dim = embedding_dim
        self.collection_name = os.getenv("CHROMA_COLLECTION", "repo_qa_chunks")
        self.upsert_batch_size = max(1, int(os.getenv("CHROMA_UPSERT_BATCH_SIZE", "64")))
        # The environment variable wins over the explicit argument.
        self.persist_path = os.getenv("CHROMA_PATH", index_path or "./data/chroma")
        self.persist = persist
        self.client = self._create_client()
        self.collection = self._ensure_collection()

    def _create_client(self):
        """Build the Chroma client, creating the storage directory if persisting."""
        if self.persist:
            Path(self.persist_path).mkdir(parents=True, exist_ok=True)
            return Client(
                Settings(
                    is_persistent=True,
                    persist_directory=self.persist_path,
                    anonymized_telemetry=False,
                )
            )
        return Client(Settings(anonymized_telemetry=False))

    def _ensure_collection(self):
        """Return the configured collection, creating it if it does not exist."""
        return self.client.get_or_create_collection(
            name=self.collection_name,
            # Embeddings are always supplied by the caller.
            embedding_function=None,
            metadata={"hnsw:space": "cosine"},
        )

    def add_embeddings(self, embeddings: np.ndarray, metadata: List[dict]) -> List[str]:
        """Add vectors and their metadata to the collection in batches.

        Each record gets a fresh ``uuid4`` hex id. The ``"content"`` entry of a
        metadata dict is stored as the Chroma document; every other entry is
        sanitized and stored as metadata together with the generated ``"id"``.

        Args:
            embeddings: One vector (1-D) or a matrix with one row per record.
            metadata: One dict per embedding row, in the same order.

        Returns:
            The generated ids in input order; ``[]`` when ``embeddings`` is empty.

        Raises:
            ValueError: If the number of rows and metadata entries differ.
        """
        embeddings = np.asarray(embeddings, dtype="float32")
        if embeddings.size == 0:
            return []
        if embeddings.ndim == 1:
            embeddings = embeddings.reshape(1, -1)

        # Validate before writing anything: a mismatch would otherwise drop
        # surplus vectors silently or fail after earlier batches were stored.
        if len(embeddings) != len(metadata):
            raise ValueError(
                f"embeddings has {len(embeddings)} rows but metadata has "
                f"{len(metadata)} entries"
            )

        ids = [uuid4().hex for _ in metadata]
        total_points = len(ids)
        batch_size = self.upsert_batch_size
        total_batches = (total_points + batch_size - 1) // batch_size

        for batch_number, start in enumerate(range(0, total_points, batch_size), start=1):
            end = start + batch_size
            batch_ids = ids[start:end]
            batch_embeddings = embeddings[start:end].tolist()
            batch_metadata = []
            batch_documents = []
            for point_id, meta in zip(batch_ids, metadata[start:end]):
                payload = self._sanitize_metadata(meta)
                payload["id"] = point_id
                batch_metadata.append(payload)
                batch_documents.append(str(meta.get("content") or ""))

            print(
                f"[chroma] Adding batch {batch_number}/{total_batches} "
                f"points={len(batch_ids)} progress={start}/{total_points}",
                flush=True,
            )
            self.collection.add(
                ids=batch_ids,
                embeddings=batch_embeddings,
                metadatas=batch_metadata,
                documents=batch_documents,
            )
        return ids

    def search(
        self,
        query_embedding: np.ndarray,
        k: int = 10,
        repo_filter: Optional[int] = None,
    ) -> List[Tuple[float, dict]]:
        """Query the collection for the nearest records to one vector.

        Args:
            query_embedding: Query vector; if 2-D, only the first row is used.
            k: Maximum number of hits requested.
            repo_filter: When given, restrict hits to records whose
                ``repository_id`` metadata equals this value.

        Returns:
            ``(score, payload)`` tuples in the order Chroma returned them.
            ``score`` is ``1 - distance`` clamped to ``[0, 1]``; ``payload`` is
            the stored metadata plus ``"id"`` and ``"content"``. Returns ``[]``
            when ``k`` is not positive.
        """
        if k <= 0:
            # Nothing was asked for; skip the round trip to Chroma.
            return []

        query_embedding = np.asarray(query_embedding, dtype="float32")
        if query_embedding.ndim == 1:
            query_embedding = query_embedding.reshape(1, -1)

        where = {"repository_id": repo_filter} if repo_filter is not None else None
        results = self.collection.query(
            query_embeddings=[query_embedding[0].tolist()],
            n_results=k,
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        # Each field is a list with one entry per query; only one query is sent.
        ids = (results.get("ids") or [[]])[0]
        documents = (results.get("documents") or [[]])[0]
        metadatas = (results.get("metadatas") or [[]])[0]
        distances = (results.get("distances") or [[]])[0]

        hits = []
        for point_id, document, meta, distance in zip(ids, documents, metadatas, distances):
            payload = dict(meta or {})
            payload["id"] = payload.get("id") or point_id
            payload["content"] = document or ""
            hits.append((self._distance_to_score(distance), payload))
        return hits

    def remove_repository(self, repo_id: int):
        """Delete every record whose ``repository_id`` metadata equals ``repo_id``."""
        self.collection.delete(where={"repository_id": repo_id})

    def clear(self):
        """Drop the collection (best effort) and open a fresh handle to it."""
        try:
            self.client.delete_collection(name=self.collection_name)
        except Exception as exc:
            # Deliberately best-effort. The failure is reported rather than
            # hidden because the collection re-opened below may then still
            # hold its old records.
            print(
                f"[chroma] Could not delete collection "
                f"{self.collection_name!r}: {exc}",
                flush=True,
            )
        self.collection = self._ensure_collection()

    def save(self):
        """Call the client's ``persist()`` if it has one; otherwise do nothing."""
        persist = getattr(self.client, "persist", None)
        if callable(persist):
            persist()

    def load(self):
        """Re-open the collection handle from the current client."""
        self.collection = self._ensure_collection()

    def keep_alive(self) -> dict:
        """Call the client's ``heartbeat()`` if available and return the stats."""
        heartbeat = getattr(self.client, "heartbeat", None)
        if callable(heartbeat):
            heartbeat()
        return self.get_stats()

    def get_stats(self) -> dict:
        """Return the record count plus the store's configuration values."""
        return {
            "total_vectors": self.collection.count(),
            "embedding_dim": self.embedding_dim,
            "collection_name": self.collection_name,
            "persist_path": self.persist_path if self.persist else None,
        }

    @staticmethod
    def _sanitize_metadata(meta: dict) -> dict:
        """Return a copy of ``meta`` reduced to primitive metadata values.

        The ``"content"`` key is dropped, ``None`` becomes ``""``, NumPy
        bool/integer/float scalars are unwrapped to native Python values, and
        any other non-primitive value is replaced by its ``str()`` form.
        """
        sanitized = {}
        for key, value in meta.items():
            if key == "content":
                continue
            if isinstance(value, (np.bool_, np.integer, np.floating)):
                # Keep numeric ids numeric so integer ``where`` filters match.
                value = value.item()
            if value is None:
                sanitized[key] = ""
            elif isinstance(value, (str, int, float, bool)):
                sanitized[key] = value
            else:
                sanitized[key] = str(value)
        return sanitized

    @staticmethod
    def _distance_to_score(distance: float) -> float:
        """Convert a distance into a score in ``[0, 1]`` as ``1 - distance``.

        ``None`` and NaN distances score ``0.0``.
        """
        if distance is None:
            return 0.0
        score = 1.0 - float(distance)
        if np.isnan(score):
            # min()/max() would otherwise turn NaN into a perfect 1.0.
            return 0.0
        return max(0.0, min(1.0, score))
|