feat: Add RAG service and modelfile
37  rag_service/rag_api.py  Normal file
@@ -0,0 +1,37 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OllamaEmbeddings

# Configuration
PERSIST_DIRECTORY = "/data/db"

# Initialize FastAPI app
app = FastAPI()

# Load the vector store
embeddings = OllamaEmbeddings(model="nomic-embed-text")
db = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embeddings)
retriever = db.as_retriever()

class RetrieveRequest(BaseModel):
    query: str

class RetrieveResponse(BaseModel):
    context: str

@app.post("/retrieve", response_model=RetrieveResponse)
async def retrieve_context(request: RetrieveRequest):
    """
    Retrieves context from the vector store for a given query.
    """
    try:
        docs = retriever.get_relevant_documents(request.query)
        context = "\n\n".join([doc.page_content for doc in docs])
        return RetrieveResponse(context=context)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
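For a quick manual check of the new /retrieve endpoint, a request like the following should return the concatenated context. This is a minimal sketch, not part of the commit; it assumes the service is reachable on the default host and port used in the code above.

import requests

# Call the /retrieve endpoint of the RAG service (assumes it runs on localhost:8000).
resp = requests.post(
    "http://localhost:8000/retrieve",
    json={"query": "How is the vector store configured?"},
    timeout=30,
)
resp.raise_for_status()
# The response body matches RetrieveResponse: {"context": "..."}
print(resp.json()["context"])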
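The service assumes a Chroma store already exists under /data/db; the indexing step is not part of this diff. As a rough illustration only, a store compatible with the service above could be built with the same embedding model, for example:

from langchain_core.documents import Document
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OllamaEmbeddings

# Hypothetical ingestion sketch: the real indexing pipeline is not shown in this commit.
embeddings = OllamaEmbeddings(model="nomic-embed-text")
docs = [Document(page_content="Example document text to index.")]
Chroma.from_documents(docs, embedding=embeddings, persist_directory="/data/db")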