"""FastAPI service returning the corpus captions most similar to a query.

A BM25 index (together with its corpus) is loaded once from the Hugging Face
Hub at import time; ``POST /retrieval?caption=...`` tokenizes the query and
returns the best-matching corpus entries as ``Item`` objects.
"""
import json
import os
import traceback

import bm25s
from bm25s.hf import BM25HF
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Hugging Face access token; a missing variable fails fast at startup.
token = os.environ["token_huggingface"]

app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["*"],
)


class Item(BaseModel):
    """One retrieved corpus entry."""

    id: int
    text: str


# Load the index. load_corpus=True makes retrieve() return corpus entries
# (assumed to be dicts with "id" and "text" keys — TODO confirm against the
# hub dataset) instead of bare document indices.
retriever = BM25HF.load_from_hub(
    "anhdt-dsai-02/caption_1_2_3_4", load_corpus=True, token=token
)


def get_similar_images_based_on_caption(caption, retriever, k=4):
    """Return the top-``k`` corpus entries matching ``caption``.

    Args:
        caption: Free-text query string.
        retriever: A loaded bm25s retriever whose corpus was loaded.
        k: Number of results to return (default 4).

    Returns:
        The ``k`` best-scoring corpus entries for the single query.
    """
    # Progress bars are useful for batch jobs but only add log noise per request.
    query_tokens = bm25s.tokenize(caption, show_progress=False)
    docs, _scores = retriever.retrieve(query_tokens, k=k, show_progress=False)
    # retrieve() returns one row per query; we always send exactly one query.
    return docs[0]


@app.post("/retrieval")
def gate(caption: str):
    """Return the corpus items most similar to the ``caption`` query parameter.

    Declared as a plain ``def`` (not ``async def``) on purpose: BM25 scoring is
    blocking CPU work, and FastAPI runs sync endpoints in a threadpool so the
    event loop stays free for other requests.

    Raises:
        HTTPException: 400 if ``caption`` is empty/whitespace, 500 if retrieval
            fails.
    """
    if not caption or not caption.strip():
        raise HTTPException(status_code=400, detail="caption must not be empty")

    try:
        docs = get_similar_images_based_on_caption(caption, retriever)
        return [Item(**doc) for doc in docs]
    except Exception as err:
        # Log the full traceback server-side, then surface an HTTPException:
        # unlike an unhandled error, it is rendered inside the CORS middleware,
        # so browser clients receive a readable 500 instead of a CORS failure.
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="retrieval failed") from err