""" server_app.py Standalone, editable FastAPI server with batching worker and compact embedding responses. Designed for manual deployment to HF Space or local testing (do not run here if model is too large). Features: - Batch worker that collects /embed requests and calls model.encode_text - Thread-safe future resolution using caller event loop and loop.call_soon_threadsafe(...) - Returns compressed `.npz` either as base64 JSON (`{"b64":...}`) or as raw `application/x-npz` bytes when `return_npz_raw=True`. - GZip middleware to reduce JSON transfer size. """ import io import os import time import base64 import threading import asyncio from typing import List, Optional import numpy as np import torch from fastapi import FastAPI, Response from fastapi.middleware.gzip import GZipMiddleware from pydantic import BaseModel from transformers import AutoTokenizer, AutoModel # ---------- Configuration ---------- MODEL_NAME = "jinaai/jina-embeddings-v4" DEVICE = "cuda" if torch.cuda.is_available() else "cpu" BATCH_SIZE = 16 BATCH_TIMEOUT = 0.01 # seconds # (Optional) caches - adjust for your deployment os.environ.setdefault("HF_HOME", "/tmp/huggingface") os.environ.setdefault("HF_HUB_CACHE", "/tmp/huggingface/hub") os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers") # ---------- Model + Tokenizer (lazy load if desired) ---------- # These lines will attempt to load the model if you run the server locally. # On HF Spaces, this is required; for local testing with a small model swap # the same code will work. tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True).to(DEVICE) model.eval() # ---------- Pydantic Schemas ---------- class EmbedRequest(BaseModel): text: str task: Optional[str] = "retrieval" prompt_name: Optional[str] = None truncate_dim: Optional[int] = None debug: bool = False # request compressed npz wrapped in base64 JSON return_npz: Optional[bool] = False # request raw application/x-npz bytes (faster, avoids base64) return_npz_raw: Optional[bool] = False # request raw contiguous bytes (arr.tobytes()) with shape/dtype headers return_raw_tobytes: Optional[bool] = False class EmbedResponse(BaseModel): embeddings: Optional[List[List[float]]] = None b64: Optional[str] = None class TokenizeRequest(BaseModel): text: str class TokenizeResponse(BaseModel): input_ids: List[int] class DecodeRequest(BaseModel): input_ids: List[int] class DecodeResponse(BaseModel): text: str class EmbedImageRequest(BaseModel): image: str task: Optional[str] = "retrieval" return_multivector: Optional[bool] = True truncate_dim: Optional[int] = None class EmbedImageResponse(BaseModel): embeddings: List[List[float]] # ---------- Queue / Worker infra ---------- request_queue = [] queue_lock = threading.Lock() new_item_event = threading.Event() def batch_worker() -> None: """Background thread that batches requests and runs model.encode_text.""" while True: new_item_event.wait() time.sleep(BATCH_TIMEOUT) with queue_lock: if not request_queue: new_item_event.clear() continue batch = request_queue[:BATCH_SIZE] del request_queue[: len(batch)] if not request_queue: new_item_event.clear() print(f"[{time.strftime('%H:%M:%S')}] Processing batch of {len(batch)} requests") # Separate queries and passages, carry their event loops query_reqs = [] # list of (req, fut, loop) passage_reqs = [] for item in batch: req, fut, loop = item["req"], item["future"], item["loop"] if (req.prompt_name or "").lower() == "query": 
def batch_worker() -> None:
    """Background thread that batches requests and runs model.encode_text."""
    while True:
        new_item_event.wait()
        time.sleep(BATCH_TIMEOUT)

        with queue_lock:
            if not request_queue:
                new_item_event.clear()
                continue
            batch = request_queue[:BATCH_SIZE]
            del request_queue[: len(batch)]
            if not request_queue:
                new_item_event.clear()

        print(f"[{time.strftime('%H:%M:%S')}] Processing batch of {len(batch)} requests")

        # Separate queries and passages, carry their event loops
        query_reqs = []  # list of (req, fut, loop)
        passage_reqs = []
        for item in batch:
            req, fut, loop = item["req"], item["future"], item["loop"]
            if (req.prompt_name or "").lower() == "query":
                query_reqs.append((req, fut, loop))
            else:
                passage_reqs.append((req, fut, loop))

        # Handle queries individually
        for req, fut, loop in query_reqs:
            start_t = time.perf_counter()
            try:
                with torch.no_grad():
                    outputs = model.encode_text(
                        texts=[req.text],
                        task=req.task,
                        prompt_name="query",
                        return_multivector=True,
                        truncate_dim=req.truncate_dim,
                    )
                    pooled = outputs[0].mean(dim=0).cpu().tolist()
                loop.call_soon_threadsafe(fut.set_result, {"embeddings": [pooled]})
            except Exception as e:
                loop.call_soon_threadsafe(fut.set_exception, e)
            end_t = time.perf_counter()
            print(f"[{time.strftime('%H:%M:%S')}] Query embed took {end_t - start_t:.3f}s")

        # Handle passages: window, encode as a batch, then regroup
        if passage_reqs:
            start_t = time.perf_counter()
            all_windows = []
            window_map = []  # (req_idx, win_id, has_stride)

            for idx, (req, fut, loop) in enumerate(passage_reqs):
                enc = tokenizer(req.text, add_special_tokens=False, return_tensors="pt")
                input_ids = enc["input_ids"].squeeze(0).to(DEVICE)
                total_tokens = input_ids.size(0)
                max_len = min(15_000, getattr(model.config, "max_position_embeddings", 4096))
                stride = 50
                pos = 0
                win_id = 0
                while pos < total_tokens:
                    end = min(pos + max_len, total_tokens)
                    window_ids = input_ids[pos:end]
                    window_text = tokenizer.decode(window_ids, skip_special_tokens=True)
                    all_windows.append(window_text)
                    window_map.append((idx, win_id, pos > 0))
                    pos += max_len - stride
                    win_id += 1

            print(f"[{time.strftime('%H:%M:%S')}] Encoding {len(all_windows)} passage windows")
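            # Window arithmetic, with illustrative numbers (not read from the config):
            # if max_len were 4096 and stride is 50, window starts would be 0, 4046,
            # 8092, ... so consecutive windows share a 50-token overlap. The regrouping
            # step below drops the first 50 rows of every window after the first,
            # which is intended to remove that overlap from the concatenated output.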
            # Encode all windows in one call
            try:
                with torch.no_grad():
                    if DEVICE == "cuda":
                        with torch.autocast(device_type=DEVICE, dtype=torch.float16):
                            outputs = model.encode_text(
                                texts=all_windows,
                                task="retrieval",
                                prompt_name="passage",
                                return_multivector=True,
                            )
                    else:
                        outputs = model.encode_text(
                            texts=all_windows,
                            task="retrieval",
                            prompt_name="passage",
                            return_multivector=True,
                        )
            except Exception as e:
                # propagate exceptions back to all futures
                for _, fut, loop in passage_reqs:
                    loop.call_soon_threadsafe(fut.set_exception, e)
                continue

            # regroup outputs per request
            passage_embeds = [[] for _ in passage_reqs]
            for out, (req_idx, win_id, has_stride) in zip(outputs, window_map):
                emb = out.cpu()
                if has_stride:
                    emb = emb[50:]
                passage_embeds[req_idx].append(emb)

            # deliver results
            for (req, fut, loop), embeds in zip(passage_reqs, passage_embeds):
                if not embeds:
                    loop.call_soon_threadsafe(fut.set_result, {"embeddings": []})
                    continue
                full_tensor = torch.cat(embeds, dim=0).cpu()
                try:
                    arr = full_tensor.numpy().astype(np.float32)
                except Exception:
                    loop.call_soon_threadsafe(fut.set_result, {"embeddings": full_tensor.tolist()})
                    continue

                BYTES_THRESHOLD = 200_000
                nbytes = arr.nbytes
                if getattr(req, "return_npz", False) or nbytes > BYTES_THRESHOLD:
                    # measure serialization time to npz and payload size
                    t_save0 = time.perf_counter()
                    buf = io.BytesIO()
                    # save positional to produce 'arr_0' key for backward compatibility
                    np.savez_compressed(buf, arr)
                    buf.seek(0)
                    npz_bytes = buf.read()
                    t_save1 = time.perf_counter()
                    ser_time = t_save1 - t_save0
                    payload_len = len(npz_bytes)

                    if getattr(req, "return_npz_raw", False):
                        # Return raw binary (best performance)
                        resp = Response(content=npz_bytes, media_type="application/x-npz")
                        # attach diagnostic headers so clients can see server-side timings
                        try:
                            resp.headers["X-Serialize-Time"] = f"{ser_time:.3f}"
                            resp.headers["X-Payload-Bytes"] = str(payload_len)
                            resp.headers["X-Encode-Time"] = f"{(t_save0 - start_t):.3f}"
                        except Exception:
                            pass
                        print(f"[{time.strftime('%H:%M:%S')}] Prepared raw npz payload bytes={payload_len} serialize_time={ser_time:.3f}s")
                        loop.call_soon_threadsafe(fut.set_result, resp)
                    elif getattr(req, "return_raw_tobytes", False):
                        # Return raw contiguous bytes for zero-copy client parse
                        raw_bytes = arr.tobytes()
                        resp = Response(content=raw_bytes, media_type="application/octet-stream")
                        try:
                            # shape as rows,cols where arr is 2D
                            shape = arr.shape
                            resp.headers["X-Shape"] = ",".join(str(x) for x in shape)
                            resp.headers["X-Dtype"] = str(arr.dtype)
                            resp.headers["X-Serialize-Time"] = f"{ser_time:.3f}"
                            resp.headers["X-Payload-Bytes"] = str(len(raw_bytes))
                            resp.headers["X-Encode-Time"] = f"{(t_save0 - start_t):.3f}"
                        except Exception:
                            pass
                        print(f"[{time.strftime('%H:%M:%S')}] Prepared raw tobytes payload bytes={len(raw_bytes)} serialize_time={ser_time:.3f}s shape={getattr(arr, 'shape', None)}")
                        loop.call_soon_threadsafe(fut.set_result, resp)
                    else:
                        t_b640 = time.perf_counter()
                        b64 = base64.b64encode(npz_bytes).decode("ascii")
                        t_b641 = time.perf_counter()
                        b64_time = t_b641 - t_b640
                        payload_chars = len(b64)
                        print(f"[{time.strftime('%H:%M:%S')}] Prepared base64 payload chars={payload_chars} serialize_time={ser_time:.3f}s b64_time={b64_time:.3f}s")
                        loop.call_soon_threadsafe(fut.set_result, {"b64": b64, "embeddings": []})
                else:
                    loop.call_soon_threadsafe(fut.set_result, {"embeddings": arr.tolist()})

            end_t = time.perf_counter()
            print(f"[{time.strftime('%H:%M:%S')}] Passage batch took {end_t - start_t:.3f}s ({len(all_windows)} windows)")

        try:
            if DEVICE == "cuda":
                torch.cuda.empty_cache()
                try:
                    torch.cuda.ipc_collect()
                except Exception:
                    pass
        except Exception:
            pass


# start worker thread
threading.Thread(target=batch_worker, daemon=True).start()

# ---------- FastAPI app ----------
app = FastAPI()
app.add_middleware(GZipMiddleware, minimum_size=1000)


@app.get("/")
def ping():
    return {"status": "ok", "message": "server alive"}


@app.post("/embed")
async def embed(req: EmbedRequest):
    print(f"[{time.strftime('%H:%M:%S')}] Received /embed len={len(req.text)}")
    loop = asyncio.get_running_loop()
    fut = loop.create_future()
    with queue_lock:
        request_queue.append({"req": req, "future": fut, "loop": loop})
    new_item_event.set()

    t_wait0 = time.perf_counter()
    result = await fut
    t_wait1 = time.perf_counter()
    wait_time = t_wait1 - t_wait0
    print(f"[{time.strftime('%H:%M:%S')}] Handler resumed after wait {wait_time:.3f}s for /embed (len={len(req.text)})")

    # If the worker returned a Response object (raw npz), return it directly
    if isinstance(result, Response):
        try:
            result.headers["X-Handler-Wait"] = f"{wait_time:.3f}"
        except Exception:
            pass
        return result
    try:
        # if returning JSON, include a hint for diagnostics
        if isinstance(result, dict):
            result["_handler_wait"] = wait_time
    except Exception:
        pass
    return result
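# ---------- Client-side decode sketch (illustrative only, not executed) ----------
# The /embed endpoint can answer in three shapes, selected by the request flags above.
# A minimal client decode might look like the following; SERVER_URL is a placeholder
# and the `requests` dependency is an assumption, not something this module provides.
#
#   import io, base64, requests, numpy as np
#
#   # base64-wrapped npz inside JSON (return_npz=True)
#   r = requests.post(f"{SERVER_URL}/embed", json={"text": "some passage", "return_npz": True})
#   arr = np.load(io.BytesIO(base64.b64decode(r.json()["b64"])))["arr_0"]
#
#   # raw application/x-npz bytes (return_npz=True, return_npz_raw=True)
#   r = requests.post(f"{SERVER_URL}/embed",
#                     json={"text": "some passage", "return_npz": True, "return_npz_raw": True})
#   arr = np.load(io.BytesIO(r.content))["arr_0"]
#
#   # raw contiguous bytes with shape/dtype headers (return_npz=True, return_raw_tobytes=True)
#   r = requests.post(f"{SERVER_URL}/embed",
#                     json={"text": "some passage", "return_npz": True, "return_raw_tobytes": True})
#   shape = tuple(int(x) for x in r.headers["X-Shape"].split(","))
#   arr = np.frombuffer(r.content, dtype=r.headers["X-Dtype"]).reshape(shape)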
{"embeddings": [pooled.tolist()]}