Update handler.py
Browse files- handler.py +20 -20
handler.py
CHANGED
@@ -1,21 +1,21 @@
|
|
1 |
-
from typing import Dict, List, Any
|
2 |
-
from transformers import pipeline
|
3 |
-
from PIL import Image
|
4 |
-
class EndpointHandler():
|
5 |
-
def __init__(self, path=""):
|
6 |
-
# Initialize the image segmentation pipeline
|
7 |
-
self.pipe = pipeline("image-segmentation", model=
|
8 |
-
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
9 |
-
# Extract the image path from the input data
|
10 |
-
image_path = data.get("image_path", "")
|
11 |
-
|
12 |
-
# Perform image segmentation
|
13 |
-
pillow_mask = self.pipe(image_path, return_mask=True) # outputs a pillow mask
|
14 |
-
pillow_image = self.pipe(image_path) # outputs the segmented image
|
15 |
-
|
16 |
-
# Save the segmented image at the root folder
|
17 |
-
output_image_path = "segmented_image.png"
|
18 |
-
pillow_image.save(output_image_path)
|
19 |
-
|
20 |
-
# Return the result as a list of dictionaries
|
21 |
return [{"image_path": output_image_path, "mask": pillow_mask}]
|
|
|
1 |
+
from typing import Dict, List, Any
|
2 |
+
from transformers import pipeline
|
3 |
+
from PIL import Image
|
4 |
+
class EndpointHandler():
|
5 |
+
def __init__(self, path="."):
|
6 |
+
# Initialize the image segmentation pipeline
|
7 |
+
self.pipe = pipeline("image-segmentation", model=path, trust_remote_code=True)
|
8 |
+
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
9 |
+
# Extract the image path from the input data
|
10 |
+
image_path = data.get("image_path", "")
|
11 |
+
|
12 |
+
# Perform image segmentation
|
13 |
+
pillow_mask = self.pipe(image_path, return_mask=True) # outputs a pillow mask
|
14 |
+
pillow_image = self.pipe(image_path) # outputs the segmented image
|
15 |
+
|
16 |
+
# Save the segmented image at the root folder
|
17 |
+
output_image_path = "segmented_image.png"
|
18 |
+
pillow_image.save(output_image_path)
|
19 |
+
|
20 |
+
# Return the result as a list of dictionaries
|
21 |
return [{"image_path": output_image_path, "mask": pillow_mask}]
|