udman99 commited on
Commit
dbf18a7
·
verified ·
1 Parent(s): eacc646

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +20 -20
handler.py CHANGED
@@ -1,21 +1,21 @@
1
- from typing import Dict, List, Any
2
- from transformers import pipeline
3
- from PIL import Image
4
- class EndpointHandler():
5
- def __init__(self, path=""):
6
- # Initialize the image segmentation pipeline
7
- self.pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
8
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
9
- # Extract the image path from the input data
10
- image_path = data.get("image_path", "")
11
-
12
- # Perform image segmentation
13
- pillow_mask = self.pipe(image_path, return_mask=True) # outputs a pillow mask
14
- pillow_image = self.pipe(image_path) # outputs the segmented image
15
-
16
- # Save the segmented image at the root folder
17
- output_image_path = "segmented_image.png"
18
- pillow_image.save(output_image_path)
19
-
20
- # Return the result as a list of dictionaries
21
  return [{"image_path": output_image_path, "mask": pillow_mask}]
 
1
+ from typing import Dict, List, Any
2
+ from transformers import pipeline
3
+ from PIL import Image
4
+ class EndpointHandler():
5
+ def __init__(self, path="."):
6
+ # Initialize the image segmentation pipeline
7
+ self.pipe = pipeline("image-segmentation", model=path, trust_remote_code=True)
8
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
9
+ # Extract the image path from the input data
10
+ image_path = data.get("image_path", "")
11
+
12
+ # Perform image segmentation
13
+ pillow_mask = self.pipe(image_path, return_mask=True) # outputs a pillow mask
14
+ pillow_image = self.pipe(image_path) # outputs the segmented image
15
+
16
+ # Save the segmented image at the root folder
17
+ output_image_path = "segmented_image.png"
18
+ pillow_image.save(output_image_path)
19
+
20
+ # Return the result as a list of dictionaries
21
  return [{"image_path": output_image_path, "mask": pillow_mask}]