# jsamch's picture
# New approach
# f2b1904 verified
# app.py
import gradio as gr
from ultralytics import YOLO
from PIL import Image, ImageDraw, ImageFont
from huggingface_hub import hf_hub_download
import torch
import os
import requests # Keep for potential future use, though hf_hub_download handles model download
# --- Configuration ---
MODEL_REPO_ID = "biglam/historic-newspaper-illustrations-yolov11"
# Choose 'yolo11n.pt' (nano) or 'yolo11s.pt' (small)
MODEL_FILENAME = "yolo11n.pt"  # Defaulting to the smaller nano model

# --- Model Loading ---
# `model` stays None when loading fails; detect_illustrations() checks for that
# and reports the problem on the returned image instead of crashing.
model = None
# Device the loaded model lives on; stays None if it cannot be determined.
device = None

try:
    # Step 1: Download the specific model weights file from Hugging Face Hub
    print(f"Downloading model weights '{MODEL_FILENAME}' from '{MODEL_REPO_ID}'...")
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
    print(f"Model weights downloaded to: {model_path}")

    # Step 2: Load the YOLO model using the downloaded weights file path
    model = YOLO(model_path)
    print(f"Ultralytics YOLO model loaded successfully from '{model_path}'.")
except Exception as e:
    # Deliberately broad: this is the start-up boundary, and the app should
    # still come up (showing an error to the user) whatever went wrong here.
    print(f"Error loading Ultralytics YOLO model: {e}")
    model = None  # Make sure a half-finished load never leaks out

if model is not None:
    # Purely informational. Kept in its own try block so that a failure to
    # inspect the device is not misreported as a model loading failure.
    try:
        device = next(model.parameters()).device
        print(f"Model is running on device: {device}")
    except Exception as e:
        print(f"Could not determine model device: {e}")
# --- Image Processing Function ---
def _load_label_font(size, announce_fallback=False):
    """Return Arial at ``size``, or Pillow's built-in font if Arial is missing."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except IOError:
        if announce_fallback:
            print("Arial font not found. Using Pillow's default font.")
        return ImageFont.load_default()


def _render_notice(text, fill):
    """Build a small white placeholder image carrying a one-line message."""
    placeholder = Image.new("RGB", (300, 100), color="white")
    ImageDraw.Draw(placeholder).text(
        (10, 10), text, fill=fill, font=_load_label_font(15)
    )
    return placeholder


def _annotated_rgb_copy(image, text, fill="red"):
    """Return an RGB copy of ``image`` with ``text`` written in its top-left corner.

    The caller's image is never modified.
    """
    annotated = image.copy() if image.mode == "RGB" else image.convert("RGB")
    ImageDraw.Draw(annotated).text(
        (10, 10), text, fill=fill, font=_load_label_font(20)
    )
    return annotated


def _text_size(draw, text, font):
    """Return ``(width, height)`` of ``text`` rendered with ``font``."""
    try:
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    except (AttributeError, ValueError):
        # AttributeError: Pillow too old to have ImageDraw.textbbox.
        # ValueError: presumably raised by some older Pillow releases when
        # textbbox is given the bitmap default font - TODO confirm versions.
        return font.getsize(text)
    return right - left, bottom - top


def _draw_detection(draw, font, corners, caption, color):
    """Draw one bounding box and its caption on a filled background strip."""
    x1, y1, x2, y2 = corners
    draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=3)

    text_width, text_height = _text_size(draw, caption, font)
    # Caption sits just above the box; if that would leave the image at the
    # top edge, move it just inside the box instead.
    label_top = y1 - text_height - 4
    if label_top < 0:
        label_top = y1 + 2
    draw.rectangle(
        [(x1, label_top), (x1 + text_width + 4, label_top + text_height + 2)],
        fill=color,
    )
    draw.text((x1 + 2, label_top + 1), caption, fill="white", font=font)


def detect_illustrations(input_image: Image.Image) -> Image.Image:
    """
    Detect illustrations in an image with the module-level YOLO ``model`` and
    draw a labelled bounding box around each detection.

    Args:
        input_image (PIL.Image.Image): The image uploaded by the user. ``None``
            is also accepted (submit clicked without an upload).

    Returns:
        PIL.Image.Image: Always a new image; ``input_image`` is never modified.
            - With detections: an RGB copy with boxes and "<class>: <conf>" labels.
            - Without detections: an RGB copy with a "No illustrations detected." note.
            - If the model is not loaded or inference raises: an RGB copy with
              the error written in red.
            - If ``input_image`` is ``None``: a 300x100 placeholder with a message.
    """
    if model is None:
        print("Model not loaded. Returning original image.")
        if input_image is None:
            return _render_notice(
                "Error: Model not loaded & No image provided.", "red"
            )
        # Annotate a copy so the caller's image stays untouched, and force RGB
        # so the red text renders as intended regardless of the input mode.
        return _annotated_rgb_copy(input_image, "Error: Model could not be loaded.")

    if input_image is None:
        # Submit was clicked without uploading an image.
        return _render_notice("Please upload an image.", "orange")

    # The drawing code below assumes an RGB canvas.
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")

    try:
        # Calling the model returns a list of Results objects, one per image.
        # A confidence threshold can be set here, e.g. model(input_image, conf=0.5).
        results = model(input_image, verbose=False)
        print("Detection results obtained.")
    except Exception as e:
        # Broad on purpose: whatever inference throws is surfaced to the user
        # on the image rather than crashing the request.
        print(f"Error during inference: {e}")
        return _annotated_rgb_copy(input_image, f"Error during detection: {e}")

    # --- Draw Bounding Boxes ---
    output_image = input_image.copy()
    draw = ImageDraw.Draw(output_image)
    label_colors = {"illustration": "red"}
    default_color = "blue"
    font = _load_label_font(18, announce_fallback=True)

    # Only one image was submitted, so only the first Results entry matters.
    boxes = results[0].boxes if results else None
    if boxes is None or len(boxes) == 0:
        print("No objects detected in the results.")
        draw.text((10, 10), "No illustrations detected.", fill="orange", font=font)
        return output_image

    print(f"Found {len(boxes)} potential objects.")
    for box in boxes:
        # Corner coordinates (xyxy), truncated to whole pixels for drawing.
        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
        conf = box.conf[0].item()
        cls = int(box.cls[0].item())
        class_name = model.names[cls] if cls in model.names else f"Class_{cls}"
        print(
            f"Detected '{class_name}' (ID: {cls}) with confidence {conf:.2f} "
            f"at [{x1}, {y1}, {x2}, {y2}]"
        )
        _draw_detection(
            draw,
            font,
            (x1, y1, x2, y2),
            f"{class_name}: {conf:.2f}",
            label_colors.get(class_name, default_color),
        )
    return output_image
# --- Gradio Interface ---
# Text shown around the demo: page heading, intro blurb and footer links.
title = "Historic Newspaper Illustration Detector ({})".format(MODEL_FILENAME)
description = "\n".join(
    (
        "",
        "Upload an image of a historic newspaper page.",
        f"This app uses the `{MODEL_REPO_ID}` model ('{MODEL_FILENAME}' weights via Ultralytics YOLO)",
        "to detect illustrations and draw bounding boxes around them.",
        "Processing might take a moment depending on the image size and server load.",
        "Model loading happens once when the Space starts.",
        "",
    )
)
article = (
    "<p style='text-align: center'>"
    f"<a href='https://huggingface.co/{MODEL_REPO_ID}' target='_blank'>Model Card</a>"
    " | Powered by "
    "<a href='https://github.com/ultralytics/ultralytics' target='_blank'>Ultralytics YOLO</a>"
    "</p>"
)

# Both ends of the pipeline exchange PIL images with detect_illustrations().
image_input = gr.Image(type="pil", label="Upload Newspaper Image")
image_output = gr.Image(type="pil", label="Detected Illustrations")

# Wire the detector into a single-input / single-output interface.
iface = gr.Interface(
    fn=detect_illustrations,
    inputs=image_input,
    outputs=image_output,
    title=title,
    description=description,
    article=article,
    # No bundled examples yet. To add some, list paths relative to the Space
    # repo, e.g. ["example1.jpg"].
    examples=[],
    allow_flagging="never",
)

# --- Launch the App ---
if __name__ == "__main__":
    # share=True would create a public link (handy when testing locally);
    # it is not needed on Hugging Face Spaces.
    iface.launch()