AI-RESEARCHER-2024's picture
Update app.py
6f04e08 verified
raw
history blame
6.51 kB
import gradio as gr
import cv2
import numpy as np
import os
from ultralytics import YOLO
# Load the trained YOLO weights once, at import time; draw_predictions() calls
# this single shared instance for every inference.
# NOTE(review): 'best.pt' is a relative path -- presumably resolved against the
# directory the app is started from; confirm for the deployment environment.
model = YOLO('best.pt')
# Class labels and the colour used for each one's boxes and legend swatch.
# Colours are (R, G, B) triples; the legend renders them with CSS rgb().
color_codes = {
    'IHC': (255, 255, 255),  # white
    'OHC-1': (255, 0, 0),    # red
    'OHC-2': (0, 255, 0),    # green
    'OHC-3': (0, 0, 255),    # blue
}
# Parallel lists kept for code that indexes by class id.
class_names = list(color_codes)
colors = list(color_codes.values())
# Bundled example images, "example (1).png" through "example (10).png",
# offered as click-to-load samples in the UI.
example_paths = [
    f'./examples/images/example ({index}).png' for index in range(1, 11)
]
# Precompute a lookup from hash(raw pixel bytes) to the example's path, for
# every example image that could be read.
# NOTE(review): nothing in the visible part of this file reads example_hashes;
# confirm whether it is still needed. The bytes hashed here come from
# cv2.imread, so a lookup would have to hash pixels loaded the same way.
example_hashes = {}
for path in example_paths:
    example_image = cv2.imread(path)
    if example_image is None:
        # The image could not be loaded; leave it out of the table.
        continue
    example_hashes[hash(example_image.tobytes())] = path
# Helper: parse a YOLO-style label file into (class_id, xc, yc, w, h) tuples.
def _read_yolo_annotations(annotation_file):
    """Return the usable annotation rows of *annotation_file*.

    A usable row has exactly five whitespace-separated, finite numeric fields:
    ``class_id x_center y_center width height``. Every other line (blank,
    wrong field count, non-numeric, NaN/inf) is skipped so that one bad line
    does not prevent the remaining boxes from being drawn.
    """
    annotations = []
    # errors='replace' turns undecodable bytes into characters that fail the
    # numeric parse below, so such lines are skipped like any other bad line.
    with open(annotation_file, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            parts = line.split()
            if len(parts) != 5:
                continue
            try:
                values = [float(part) for part in parts]
            except ValueError:
                continue
            if not np.isfinite(values).all():
                # NaN/inf cannot be converted to pixel coordinates.
                continue
            cls_id, x_center, y_center, width, height = values
            annotations.append((int(cls_id), x_center, y_center, width, height))
    return annotations


# Function to draw ground truth boxes
def draw_ground_truth(image, annotation_file):
    """Return a copy of *image* with the boxes from a label file drawn on it.

    The label file holds one box per line as
    ``class_id x_center y_center width height``; the four geometry values are
    fractions of the image size and are scaled by the image width/height here.
    Malformed lines are ignored.

    Args:
        image: Array whose first two dimensions are (height, width).
        annotation_file: Path of the label file to read.

    Returns:
        A copy of *image* with one 2-pixel-thick rectangle per annotation.
        The input array is not modified.

    Raises:
        OSError: If *annotation_file* cannot be opened.
    """
    image_height, image_width = image.shape[:2]
    image_gt = image.copy()
    for cls_id, x_center, y_center, width, height in _read_yolo_annotations(annotation_file):
        # Convert centre/size fractions into a top-left corner and pixel size.
        x = int((x_center - width / 2) * image_width)
        y = int((y_center - height / 2) * image_height)
        w = int(width * image_width)
        h = int(height * image_height)
        # Wrap the class id so an unexpected id still maps to some colour.
        color = colors[cls_id % len(colors)]
        cv2.rectangle(image_gt, (x, y), (x + w, y + h), color, 2)
    return image_gt
# Function to draw prediction boxes
def draw_predictions(image):
    """Run the detector on *image* and return a copy with its boxes drawn.

    Args:
        image: Image array in RGB channel order (the callers in this file
            build it from a PIL image with ``np.array``).

    Returns:
        A copy of *image* with one 2-pixel-thick rectangle per detection,
        coloured by class name via ``color_codes`` (white when the name is
        not in the table). The input array is not modified.
    """
    image_pred = image.copy()

    # Ultralytics interprets NumPy inputs as BGR, whereas this app works in
    # RGB. Flip the channels for inference only; drawing stays on the RGB copy
    # so the colours shown match the legend.
    if image.ndim == 3 and image.shape[2] == 3:
        model_input = np.ascontiguousarray(image[:, :, ::-1])
    else:
        model_input = image

    detections = model(model_input)[0]
    boxes = detections.boxes.xyxy.cpu().numpy()
    classes = detections.boxes.cls.cpu().numpy()
    names = detections.names

    for (x1, y1, x2, y2), class_id in zip(boxes, classes):
        class_name = names[int(class_id)]
        color = color_codes.get(class_name, (255, 255, 255))
        cv2.rectangle(
            image_pred,
            (int(x1), int(y1)),
            (int(x2), int(y2)),
            color,
            2,
        )
    return image_pred
# Prediction function for Step 1
def predict(input_image):
    """Return the input image and its ground-truth overlay (Step 1).

    The label file is looked up as ``./examples/labels/<stem>.txt`` where
    ``<stem>`` is the base name of ``input_image.name`` without extension, or
    ``uploaded_image`` when the object has no ``name`` attribute.
    NOTE(review): PIL images normally expose ``filename`` rather than
    ``name`` -- confirm whether a real file name ever reaches this lookup.

    Args:
        input_image: The image supplied by the UI, or ``None`` when the
            image component is empty.

    Returns:
        ``(image, image_gt)``: the image as an array and either the same
        image with ground-truth boxes drawn or, when no label file exists,
        a plain copy. ``(None, None)`` when *input_image* is ``None``.
    """
    if input_image is None:
        # np.array(None) would yield a 0-d object array that the image
        # outputs cannot display; clear both outputs instead.
        return None, None

    image = np.array(input_image)
    image_name = getattr(input_image, 'name', 'uploaded_image.png')
    label_stem = os.path.splitext(os.path.basename(image_name))[0]
    annotation_path = f'./examples/labels/{label_stem}.txt'

    if os.path.exists(annotation_path):
        image_gt = draw_ground_truth(image, annotation_path)
    else:
        image_gt = image.copy()
    return image, image_gt
# Function for Step 2
def split_and_predict(input_image):
    """Cut the image into four quadrants and run detection on each (Step 2).

    The image is split at ``height // 2`` and ``width // 2``, so for odd
    sizes the bottom/right quadrants are one pixel larger.

    Args:
        input_image: The image supplied by the UI, or ``None`` when the
            image component is empty.

    Returns:
        A list of four annotated arrays in the order top-left, top-right,
        bottom-left, bottom-right; four ``None`` values when *input_image*
        is ``None``.
    """
    if input_image is None:
        # Button pressed with no image loaded: clear the four panes rather
        # than failing on the shape unpacking below.
        return [None, None, None, None]

    image = np.array(input_image)
    height, width = image.shape[:2]
    mid_y, mid_x = height // 2, width // 2
    quadrants = (
        image[:mid_y, :mid_x],  # top-left
        image[:mid_y, mid_x:],  # top-right
        image[mid_y:, :mid_x],  # bottom-left
        image[mid_y:, mid_x:],  # bottom-right
    )
    return [draw_predictions(quadrant) for quadrant in quadrants]
# Prediction function for Step 3
def predict_part(input_image):
    """Run detection on a selected image part and show its labels (Step 3).

    The label file is looked up as ``./examples/labels/<stem>.txt`` where
    ``<stem>`` is the base name of ``input_image.name`` without extension, or
    ``selected_part`` when the object has no ``name`` attribute.

    Args:
        input_image: The image supplied by the UI, or ``None`` when the
            image component is empty.

    Returns:
        ``(image_pred, image_gt, gt_visibility)``: the image with predicted
        boxes, the image with ground-truth boxes (``None`` when no label file
        exists), and a ``gr.update`` that shows the ground-truth pane only
        when a label file was found. For a ``None`` input both images are
        ``None`` and the pane is hidden.
    """
    if input_image is None:
        # Component was cleared: do not push an empty array through the
        # model; blank the outputs and hide the ground-truth pane.
        return None, None, gr.update(visible=False)

    image = np.array(input_image)
    image_pred = draw_predictions(image)

    image_name = getattr(input_image, 'name', 'selected_part.png')
    label_stem = os.path.splitext(os.path.basename(image_name))[0]
    annotation_path = f'./examples/labels/{label_stem}.txt'

    if not os.path.exists(annotation_path):
        return image_pred, None, gr.update(visible=False)

    image_gt = draw_ground_truth(image, annotation_path)
    return image_pred, image_gt, gr.update(visible=True)
# Build the HTML colour legend: one coloured block character plus the class
# label per entry, laid out in a single flex row.
legend_items = [
    "<div style='margin-right: 15px; display: flex; align-items: center;'>"
    f"<span style='color: rgb({red},{green},{blue}); font-size: 20px;'>&#9608;</span>"
    f"<span style='margin-left: 5px;'>{label}</span>"
    "</div>"
    for label, (red, green, blue) in zip(class_names, colors)
]
legend_html = (
    "<h3>Color Legend:</h3><div style='display: flex; align-items: center;'>"
    + "".join(legend_items)
    + "</div>"
)
# Create Gradio interface
# NOTE(review): the container nesting below was reconstructed from a copy of
# this file that had lost its indentation -- confirm it against the intended
# page layout (in particular which components sit inside the rows/columns).
with gr.Blocks() as interface:
    gr.Markdown("## Advanced Detection of Cochlear Hair Cells Using YOLOv10 in Auditory Diagnostics")
    # Colour legend shared by every image pane on the page.
    gr.HTML(legend_html)

    # Step 1: input image (with bundled examples) next to its labelled version.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            gr.Examples(examples=example_paths, inputs=input_image, label="Examples")
        with gr.Column():
            output_gt = gr.Image(type="numpy", label="Labeled Image")

    # NOTE(review): input_image is both the trigger and an output of this
    # listener -- confirm that writing it back does not re-fire the event.
    input_image.change(fn=predict, inputs=input_image, outputs=[input_image, output_gt])

    # Step 2: split the input into four parts and show predictions as a 2x2 grid.
    split_button = gr.Button("Split Image and Show Predictions")
    with gr.Row():
        output_pred1 = gr.Image(type="numpy", label="Prediction Part 1")
        output_pred2 = gr.Image(type="numpy", label="Prediction Part 2")
    with gr.Row():
        output_pred3 = gr.Image(type="numpy", label="Prediction Part 3")
        output_pred4 = gr.Image(type="numpy", label="Prediction Part 4")
    quadrant_outputs = [output_pred1, output_pred2, output_pred3, output_pred4]
    split_button.click(fn=split_and_predict, inputs=input_image, outputs=quadrant_outputs)

    # Step 3: detailed view of one part. The ground-truth pane starts hidden.
    selected_part = gr.Image(type="pil", label="Select Image Part for Detailed View")
    part_pred = gr.Image(type="numpy", label="Prediction on Selected Part")
    part_gt = gr.Image(type="numpy", label="Ground Truth on Selected Part", visible=False)
    # part_gt is listed twice because predict_part returns both an image value
    # and a visibility update for it.
    selected_part.change(
        fn=predict_part,
        inputs=selected_part,
        outputs=[part_pred, part_gt, part_gt],
    )

interface.launch()