# aina-bg-rmv / handler.py
# Uploaded by udman99 with huggingface_hub (commit ea40a1d).
from typing import Dict, List, Any
from transformers import pipeline
from PIL import Image
class EndpointHandler():
    """Custom Inference Endpoint handler for background removal.

    Wraps the ``briaai/RMBG-1.4`` image-segmentation pipeline. Each request
    yields a foreground mask plus a segmented (background-removed) image; the
    segmented image is written to disk and its path returned with the mask.
    """

    # Hub id of the segmentation model loaded in ``__init__``.
    MODEL_ID = "briaai/RMBG-1.4"
    # File the segmented image is saved to when the request names no path.
    DEFAULT_OUTPUT_PATH = "segmented_image.png"

    def __init__(self, path=""):
        """Load the image-segmentation pipeline.

        Args:
            path: Directory supplied by the endpoint runtime. It is not used:
                the model is always loaded from the Hub id ``MODEL_ID``.
                NOTE(review): confirm that ignoring ``path`` is intended.
        """
        # trust_remote_code lets the model repo supply its own pipeline code;
        # the non-standard ``return_mask`` option used in __call__ is assumed
        # to come from that custom code.
        self.pipe = pipeline(
            "image-segmentation", model=self.MODEL_ID, trust_remote_code=True
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Segment one image and save the background-removed result.

        Args:
            data: Request payload. The image is taken from ``"image_path"``;
                if that key is missing or an empty string, the standard
                endpoint key ``"inputs"`` is used instead. The value is passed
                to the pipeline unchanged. An optional ``"output_path"``
                overrides where the segmented image is saved (default
                ``DEFAULT_OUTPUT_PATH``).

        Returns:
            ``[{"image_path": <saved file path>, "mask": <mask>}]``, where
            ``mask`` is what the pipeline returns for ``return_mask=True``
            (a pillow mask).

        Raises:
            ValueError: If no non-empty image was supplied.
        """
        image = data.get("image_path")
        if self._is_missing(image):
            image = data.get("inputs")
        if self._is_missing(image):
            # Fail early with a clear message instead of handing "" to the
            # pipeline and getting an obscure loader error.
            raise ValueError(
                "No input image: provide a non-empty 'image_path' (or 'inputs')."
            )

        output_image_path = data.get("output_path") or self.DEFAULT_OUTPUT_PATH

        # NOTE(review): the model runs twice on the same image here -- once
        # for the mask and once for the segmented image, which the pipeline
        # composes itself. Deriving the image from the mask would halve the
        # cost, but it must reproduce the pipeline's own postprocessing exactly.
        pillow_mask = self.pipe(image, return_mask=True)
        pillow_image = self.pipe(image)

        # Requests that share the default path overwrite each other's file;
        # pass a distinct "output_path" per request when that matters.
        pillow_image.save(output_image_path)
        return [{"image_path": output_image_path, "mask": pillow_mask}]

    @staticmethod
    def _is_missing(value: Any) -> bool:
        """Return True if *value* is ``None`` or an empty string."""
        # Only strings are checked for emptiness: image/array payloads must
        # not be truth-tested (numpy arrays raise on ``bool``).
        return value is None or (isinstance(value, str) and not value)