# aina-bg-rmv / handler.py
# Author: udman99
# Last commit: "Update handler.py" (9fdfc4a, verified)
from typing import Dict, List, Any
from transformers import pipeline
from PIL import Image
import base64
from io import BytesIO
class EndpointHandler:
    """Custom Inference Endpoint handler wrapping an image-segmentation pipeline.

    The pipeline is created with ``trust_remote_code=True``. Its outputs, and
    the ``return_mask`` keyword used below, are therefore defined by the model
    repository's own pipeline code rather than by ``transformers`` itself.
    """

    def __init__(self, path="."):
        """Load the image-segmentation pipeline.

        Args:
            path: Passed to ``transformers.pipeline`` as ``model``. It
                defaults to the current directory.
        """
        self.pipe = pipeline("image-segmentation", model=path, trust_remote_code=True)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run the pipeline on a single request payload.

        Args:
            data: Request payload. The value stored under ``"inputs"`` is
                handed to the pipeline. If that key is absent, the whole
                payload is handed over instead. ``data`` is not modified.

        Returns:
            A one-element list holding a dict with two entries:
            ``"mask"`` is the pipeline result when called with
            ``return_mask=True``, and ``"image"`` is the result of the plain
            call.
        """
        # Read instead of pop: popping deleted "inputs" from the caller's dict
        # as a hidden side effect.
        inputs = data.get("inputs", data)

        # NOTE(review): the model runs twice per request (once per output).
        # If the remote pipeline can derive the image from the mask, a single
        # call would halve inference cost. TODO confirm against the remote code.
        # Per the original author this is a PIL mask; confirm against the
        # remote pipeline.
        mask = self.pipe(inputs, return_mask=True)
        # Per the original author this is the segmented image.
        image = self.pipe(inputs)

        # NOTE(review): the pipeline's raw objects are returned as-is. The
        # module imports base64/BytesIO without using them, which suggests
        # encoding the images for the HTTP response may have been intended.
        # Confirm how the serving toolkit serializes this.
        return [{"image": image, "mask": mask}]