|
|
import gradio as gr |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import os |
|
|
from ultralytics import YOLO |
|
|
|
|
|
|
|
|
# Load the trained YOLO detection weights at import time.
# NOTE(review): 'best.pt' must sit in the working directory — confirm deployment layout.
model = YOLO('best.pt')
|
|
|
|
|
|
|
|
# Detection class labels: inner hair cells and the three outer hair cell rows.
class_names = ['IHC', 'OHC-1', 'OHC-2', 'OHC-3']

# One drawing color per class, in the same order as class_names.
colors = [
    (255, 255, 255),  # IHC   -> white
    (255, 0, 0),      # OHC-1 -> red
    (0, 255, 0),      # OHC-2 -> green
    (0, 0, 255),      # OHC-3 -> blue
]

# Lookup table from class name to its box color.
color_codes = dict(zip(class_names, colors))
|
|
|
|
|
|
|
|
# Paths to the ten bundled demo images shown in the Gradio Examples panel.
example_paths = [f'./examples/images/example ({i}).png' for i in range(1, 11)]
|
|
|
|
|
|
|
|
# Pre-compute a content hash for every bundled example image, mapping
# hash -> file path, presumably so an image can later be matched back to its
# example by pixel content rather than filename.
# NOTE(review): example_hashes is never referenced elsewhere in this file —
# confirm it is still needed. Also, hash() of bytes is salted per process
# (PYTHONHASHSEED), so these values are only meaningful within the current run.
example_hashes = {}
for path in example_paths:
    example_image = cv2.imread(path)
    # cv2.imread returns None for missing/unreadable files; skip those quietly.
    if example_image is not None:
        hash_value = hash(example_image.tobytes())
        example_hashes[hash_value] = path
|
|
|
|
|
|
|
|
def draw_ground_truth(image, annotation_file):
    """Overlay YOLO-format ground-truth boxes from annotation_file onto a copy of image.

    Each annotation line is 'cls_id x_center y_center width height' with all
    coordinates normalised to [0, 1]; lines without exactly five fields are
    skipped. Returns a new image; the input is not modified.
    """
    img_h, img_w = image.shape[:2]

    # Parse every well-formed 5-field line into (class_id, cx, cy, w, h).
    boxes = []
    with open(annotation_file, 'r') as handle:
        for raw in handle:
            fields = raw.strip().split()
            if len(fields) == 5:
                cls, cx, cy, bw, bh = map(float, fields)
                boxes.append((int(cls), cx, cy, bw, bh))

    canvas = image.copy()
    for cls, cx, cy, bw, bh in boxes:
        # Convert normalised centre/size to a top-left pixel corner plus pixel size.
        left = int((cx - bw / 2) * img_w)
        top = int((cy - bh / 2) * img_h)
        box_w = int(bw * img_w)
        box_h = int(bh * img_h)
        # Cycle through the palette in case a class id exceeds the color count.
        cv2.rectangle(canvas, (left, top), (left + box_w, top + box_h),
                      colors[cls % len(colors)], 2)
    return canvas
|
|
|
|
|
|
|
|
def draw_predictions(image):
    """Run the YOLO model on image and return a copy with predicted boxes drawn.

    Box colors are looked up in color_codes by predicted class name; names not
    in the table fall back to white. The input image is not modified.
    """
    canvas = image.copy()
    result = model(image)[0]

    xyxy = result.boxes.xyxy.cpu().numpy()
    cls_ids = result.boxes.cls.cpu().numpy()
    id_to_name = result.names

    for box, cls_id in zip(xyxy, cls_ids):
        label = id_to_name[int(cls_id)]
        color = color_codes.get(label, (255, 255, 255))
        # xyxy boxes are (x1, y1, x2, y2) corner coordinates.
        x1, y1, x2, y2 = (int(v) for v in box[:4])
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
    return canvas
|
|
|
|
|
|
|
|
def predict(input_image):
    """Return (image, ground-truth overlay) for an uploaded or example image.

    The matching annotation file is looked up in ./examples/labels/ by the
    image's base filename; when no label file exists, a plain copy of the
    image is returned in its place.
    NOTE(review): PIL images normally expose .filename rather than .name, so
    the fallback name likely applies to most uploads — confirm against the
    Gradio version in use.
    """
    image = np.array(input_image)

    # Derive '<stem>.txt' from whatever filename information the input carries.
    source_name = getattr(input_image, 'name', 'uploaded_image.png')
    stem = os.path.splitext(os.path.basename(source_name))[0]
    annotation_path = f'./examples/labels/{stem}.txt'

    if os.path.exists(annotation_path):
        image_gt = draw_ground_truth(image, annotation_path)
    else:
        image_gt = image.copy()

    return image, image_gt
|
|
|
|
|
|
|
|
def split_and_predict(input_image):
    """Split the image into four quadrants and run detection on each.

    Returns the annotated quadrants in reading order:
    top-left, top-right, bottom-left, bottom-right.
    """
    image = np.array(input_image)
    h, w = image.shape[:2]
    mid_y, mid_x = h // 2, w // 2

    quadrants = [
        image[:mid_y, :mid_x],   # top-left
        image[:mid_y, mid_x:],   # top-right
        image[mid_y:, :mid_x],   # bottom-left
        image[mid_y:, mid_x:],   # bottom-right
    ]

    return [draw_predictions(quadrant) for quadrant in quadrants]
|
|
|
|
|
|
|
|
def predict_part(input_image):
    """Run detection on a selected image part and attach ground truth if available.

    Returns (prediction image, ground-truth image or None, gr.update toggling
    the ground-truth component's visibility). The label file is looked up in
    ./examples/labels/ by the part's base filename.
    """
    image = np.array(input_image)
    image_pred = draw_predictions(image)

    # Derive '<stem>.txt' from whatever filename information the input carries.
    source_name = getattr(input_image, 'name', 'selected_part.png')
    stem = os.path.splitext(os.path.basename(source_name))[0]
    annotation_path = f'./examples/labels/{stem}.txt'

    # Only show the ground-truth panel when a matching label file exists.
    has_labels = os.path.exists(annotation_path)
    image_gt = draw_ground_truth(image, annotation_path) if has_labels else None
    gt_visibility = gr.update(visible=has_labels)

    return image_pred, image_gt, gt_visibility
|
|
|
|
|
|
|
|
# Build the static HTML legend that maps each class name to its box color.
_legend_items = "".join(
    "<div style='margin-right: 15px; display: flex; align-items: center;'>"
    f"<span style='color: rgb({r},{g},{b}); font-size: 20px;'>█</span>"
    f"<span style='margin-left: 5px;'>{name}</span>"
    "</div>"
    for name, (r, g, b) in zip(class_names, colors)
)
legend_html = (
    "<h3>Color Legend:</h3>"
    "<div style='display: flex; align-items: center;'>"
    + _legend_items
    + "</div>"
)
|
|
|
|
|
|
|
|
# --- Gradio UI wiring -------------------------------------------------------
with gr.Blocks() as interface:
    gr.Markdown("## Advanced Detection of Cochlear Hair Cells Using YOLOv10 in Auditory Diagnostics")

    # Static color legend for the four cell classes, built above.
    gr.HTML(legend_html)

    # Left column: upload area plus clickable example thumbnails.
    # Right column: the ground-truth (labeled) view of the chosen image.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            gr.Examples(
                examples=example_paths,
                inputs=input_image,
                label="Examples"
            )
        with gr.Column():
            output_gt = gr.Image(type="numpy", label="Labeled Image")

    # On any image change, echo the image back and show its ground truth
    # (or a plain copy when no label file exists).
    # NOTE(review): input_image is also listed as an output of its own .change
    # handler — in some Gradio versions writing a component's value back can
    # re-trigger the event; confirm this does not loop.
    input_image.change(
        fn=predict,
        inputs=input_image,
        outputs=[input_image, output_gt],
    )

    # Quadrant predictions laid out 2x2: parts 1-4 are top-left, top-right,
    # bottom-left, bottom-right, matching split_and_predict's return order.
    split_button = gr.Button("Split Image and Show Predictions")
    with gr.Row():
        output_pred1 = gr.Image(type="numpy", label="Prediction Part 1")
        output_pred2 = gr.Image(type="numpy", label="Prediction Part 2")
    with gr.Row():
        output_pred3 = gr.Image(type="numpy", label="Prediction Part 3")
        output_pred4 = gr.Image(type="numpy", label="Prediction Part 4")

    split_button.click(
        fn=split_and_predict,
        inputs=input_image,
        outputs=[output_pred1, output_pred2, output_pred3, output_pred4],
    )

    # Detailed view: the user drops one quadrant here for a close-up
    # prediction, with ground truth revealed only when labels exist.
    selected_part = gr.Image(type="pil", label="Select Image Part for Detailed View")
    part_pred = gr.Image(type="numpy", label="Prediction on Selected Part")
    part_gt = gr.Image(type="numpy", label="Ground Truth on Selected Part", visible=False)

    # predict_part returns (prediction, ground truth, visibility update);
    # part_gt is listed twice so it receives both its value and its visibility.
    # NOTE(review): duplicate components in an outputs list are rejected by
    # some Gradio versions — consider returning a single
    # gr.update(value=..., visible=...) instead; verify on the pinned version.
    selected_part.change(
        fn=predict_part,
        inputs=selected_part,
        outputs=[part_pred, part_gt, part_gt],
    )

interface.launch()