| | from copy import deepcopy |
| | from typing import Dict, List, Any, Optional |
| |
|
| | import faiss |
| |
|
| | from langchain.docstore import InMemoryDocstore |
| | from langchain.embeddings import OpenAIEmbeddings |
| | from langchain.schema import Document |
| | from langchain.vectorstores import Chroma, FAISS |
| | from langchain.vectorstores.base import VectorStoreRetriever |
| |
|
| | from flows.base_flows import AtomicFlow |
| |
|
| |
|
class VectorStoreFlow(AtomicFlow):
    """Atomic flow that writes text into a vector store and reads it back.

    The store is accessed through a retriever built from the flow config.
    Config keys read by this class:

    * ``type`` -- ``"chroma"`` or ``"faiss"``.
    * ``api_keys`` -- mapping whose ``"openai"`` entry is passed to
      ``OpenAIEmbeddings``.
    * ``name`` -- first positional argument given to ``Chroma``; read only
      when ``type`` is ``"chroma"``.
    * ``embedding_size`` -- dimension of the FAISS index (default 1536);
      read only when ``type`` is ``"faiss"``.
    * ``retriever_config`` -- optional mapping forwarded as keyword
      arguments to ``as_retriever``.
    """

    REQUIRED_KEYS_CONFIG = ["type", "api_keys"]

    # Retriever wrapping the configured vector store; `run` uses it for both
    # the "read" and the "write" operation.
    vector_db: VectorStoreRetriever

    def __init__(self, vector_db, **kwargs):
        """Store the retriever and forward all other arguments to the base flow."""
        super().__init__(**kwargs)
        self.vector_db = vector_db

    @classmethod
    def _set_up_retriever(cls, config: Dict[str, Any]) -> Dict[str, Any]:
        """Build the retriever described by ``config``.

        Args:
            config: Flow configuration (see the class docstring for the keys).

        Returns:
            ``{"vector_db": retriever}``, ready to be passed to ``__init__``.

        Raises:
            KeyError: If ``type``, ``api_keys["openai"]`` or (for chroma)
                ``name`` is missing from ``config``.
            NotImplementedError: If ``type`` is neither ``"chroma"`` nor
                ``"faiss"``.
        """
        embeddings = OpenAIEmbeddings(openai_api_key=config["api_keys"]["openai"])
        vs_type = config["type"]

        if vs_type == "chroma":
            vectorstore = Chroma(config["name"], embedding_function=embeddings)
        elif vs_type == "faiss":
            # Flat L2 index whose dimension must match the embedding size.
            index = faiss.IndexFlatL2(config.get("embedding_size", 1536))
            vectorstore = FAISS(
                embedding_function=embeddings.embed_query,
                index=index,
                docstore=InMemoryDocstore({}),
                index_to_docstore_id={},
            )
        else:
            raise NotImplementedError(f"Vector store '{vs_type}' not implemented")

        retriever = vectorstore.as_retriever(**config.get("retriever_config", {}))
        return {"vector_db": retriever}

    @classmethod
    def instantiate_from_config(cls, config: Dict[str, Any]):
        """Create a flow from ``config``.

        The config is deep-copied first, so the caller's dictionary is not
        shared with the new instance.
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}
        kwargs.update(cls._set_up_retriever(flow_config))

        return cls(**kwargs)

    @staticmethod
    def package_documents(documents: List[str]) -> List[Document]:
        """Wrap each string in a ``Document``.

        Every document gets the placeholder metadata ``{"": ""}``.
        NOTE(review): presumably a workaround for backends that reject empty
        metadata -- confirm before changing it.
        """
        return [Document(page_content=doc, metadata={"": ""}) for doc in documents]

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Execute a ``"read"`` or ``"write"`` operation against the store.

        Args:
            input_data: Must contain ``"operation"`` (``"read"`` or
                ``"write"``) and ``"content"``. For a read, ``content`` is the
                query string. For a write, it is a string or a list of
                strings to store.

        Returns:
            ``{"retrieved": [...]}`` with the page contents of the matching
            documents for a read; ``{"retrieved": ""}`` for a write.

        Raises:
            KeyError: If ``"operation"`` or ``"content"`` is missing.
            AssertionError: If the operation is unsupported or ``content``
                has the wrong type.
        """
        # Validation raises explicitly instead of using `assert`, so it is
        # still enforced under `python -O`. The AssertionError type and the
        # messages are kept so existing handlers keep working.
        operation = input_data["operation"]
        if operation not in ("write", "read"):
            raise AssertionError(f"Operation '{operation}' not supported")

        content = input_data["content"]

        if operation == "read":
            if not isinstance(content, str):
                raise AssertionError(f"Content must be a string, got {type(content)}")
            retrieved_documents = self.vector_db.get_relevant_documents(content)
            return {"retrieved": [doc.page_content for doc in retrieved_documents]}

        # operation == "write": a single string is treated as a one-item batch.
        if isinstance(content, str):
            content = [content]
        if not isinstance(content, list):
            raise AssertionError(f"Content must be a list of strings, got {type(content)}")

        # An empty batch has nothing to embed or store, so skip the backend call.
        if content:
            self.vector_db.add_documents(self.package_documents(content))

        return {"retrieved": ""}
| |
|