| from typing import Dict, List, Any | |
| from transformers import pipeline | |
| from PIL import Image | |
class EndpointHandler:
    """Inference handler that segments an image with the BRIA RMBG-1.4 pipeline.

    Each call runs the segmentation pipeline on one image source, saves the
    segmented image to ``self.output_path`` and returns that path together
    with the mask produced by the pipeline.
    """

    # Hub model id handed to ``transformers.pipeline``; ``trust_remote_code=True``
    # is passed because this model ships its own pipeline code.
    MODEL_ID = "briaai/RMBG-1.4"
    # File the segmented image is written to unless another path is configured.
    DEFAULT_OUTPUT_PATH = "segmented_image.png"

    def __init__(self, path: str = "", output_path: str = DEFAULT_OUTPUT_PATH):
        """Build the image-segmentation pipeline.

        Args:
            path: Unused; the model is always loaded from ``MODEL_ID``.
            output_path: File the segmented image is saved to on every call
                (defaults to ``"segmented_image.png"`` in the working directory).
        """
        self.output_path = output_path
        # Initialize the image segmentation pipeline
        self.pipe = pipeline("image-segmentation", model=self.MODEL_ID, trust_remote_code=True)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Segment one image and save the segmented result.

        Args:
            data: Request payload; ``data["image_path"]`` is the image source
                passed unchanged to the pipeline.

        Returns:
            A one-element list ``[{"image_path": <saved file>, "mask": <mask>}]``,
            where ``mask`` is the pipeline's ``return_mask=True`` output.

        Raises:
            ValueError: If ``image_path`` is missing, ``None`` or a blank string.
        """
        image_source = data.get("image_path")
        # Reject an absent/blank source up front rather than handing an empty
        # string to the model (the original defaulted to "" and ran anyway).
        if image_source is None or (isinstance(image_source, str) and not image_source.strip()):
            raise ValueError("Request data must include a non-empty 'image_path'.")

        # NOTE(review): the model runs twice on the same input — once for the
        # mask and once for the segmented image. If the segmented image can be
        # composed from the mask, a single inference would halve the cost;
        # confirm against the RMBG pipeline code before changing.
        pillow_mask = self.pipe(image_source, return_mask=True)  # outputs a pillow mask
        pillow_image = self.pipe(image_source)  # outputs the segmented image

        output_path = self.output_path
        # NOTE(review): the same file is overwritten on every call, so
        # concurrent requests sharing this handler can clobber each other's output.
        pillow_image.save(output_path)

        # NOTE(review): "mask" is a PIL image object; confirm the serving layer
        # can serialize it in the response.
        return [{"image_path": output_path, "mask": pillow_mask}]