# app.py
import gradio as gr
from ultralytics import YOLO
from PIL import Image, ImageDraw, ImageFont
from huggingface_hub import hf_hub_download
import torch
import os
import requests  # Keep for potential future use, though hf_hub_download handles the model download

# --- Configuration ---
MODEL_REPO_ID = "biglam/historic-newspaper-illustrations-yolov11"
# Choose 'yolo11n.pt' (nano) or 'yolo11s.pt' (small)
MODEL_FILENAME = "yolo11n.pt"  # Defaulting to the smaller nano model
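
# Optional sketch (not part of the original model repo): allow overriding the
# weights filename at deploy time via an environment variable. The variable name
# MODEL_FILENAME is an assumption; drop this line if you don't need the override.
MODEL_FILENAME = os.environ.get("MODEL_FILENAME", MODEL_FILENAME)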

# --- Model Loading ---
model = None  # Initialize model variable

# Use a try/except block for robust model loading
try:
    # Step 1: Download the specific model weights file from the Hugging Face Hub
    print(f"Downloading model weights '{MODEL_FILENAME}' from '{MODEL_REPO_ID}'...")
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
    print(f"Model weights downloaded to: {model_path}")

    # Step 2: Load the YOLO model using the downloaded weights file path
    model = YOLO(model_path)
    print(f"Ultralytics YOLO model loaded successfully from '{model_path}'.")

    # You can check the device it's running on (YOLO usually auto-detects)
    device = next(model.parameters()).device
    print(f"Model is running on device: {device}")
except Exception as e:
    print(f"Error loading Ultralytics YOLO model: {e}")
    # If downloading or loading fails, model stays None and the app reports the error in the UI
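
# Optional sketch: Ultralytics normally selects the device automatically, but you
# could move the model to the GPU explicitly when one is available. Left commented
# out because auto-detection is usually sufficient on Spaces.
# if model is not None and torch.cuda.is_available():
#     model.to("cuda")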

# --- Image Processing Function ---
def detect_illustrations(input_image: Image.Image) -> Image.Image:
    """
    Detects illustrations in the input image using the loaded Ultralytics YOLO model
    and draws bounding boxes around them.

    Args:
        input_image (PIL.Image.Image): The image uploaded by the user.

    Returns:
        PIL.Image.Image: A copy of the image with bounding boxes drawn around detected
            illustrations, or an annotated image (or placeholder) if the model failed
            to load, no image was provided, or nothing was detected.
    """
    if model is None:
        print("Model not loaded. Returning original image.")
        if input_image is None:
            # Handle the case where the user clicks submit without uploading an image:
            # create a placeholder image rather than returning None
            placeholder = Image.new('RGB', (300, 100), color='white')
            d = ImageDraw.Draw(placeholder)
            try:
                font = ImageFont.truetype("arial.ttf", 15)
            except IOError:
                font = ImageFont.load_default()
            d.text((10, 10), "Error: Model not loaded & No image provided.", fill="red", font=font)
            return placeholder
        # If an image exists but the model failed to load, add error text to the image
        draw = ImageDraw.Draw(input_image)
        try:
            font = ImageFont.truetype("arial.ttf", 20)  # Adjust font/size as needed
        except IOError:
            font = ImageFont.load_default()
        draw.text((10, 10), "Error: Model could not be loaded.", fill="red", font=font)
        return input_image

    if input_image is None:
        # Handle the case where the user clicks submit without uploading an image
        # after the model has loaded
        placeholder = Image.new('RGB', (300, 100), color='white')
        d = ImageDraw.Draw(placeholder)
        try:
            font = ImageFont.truetype("arial.ttf", 15)
        except IOError:
            font = ImageFont.load_default()
        d.text((10, 10), "Please upload an image.", fill="orange", font=font)
        return placeholder

    # Convert the image to RGB if it's not already
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")

    # Perform object detection using the Ultralytics model
    try:
        # results is a list of Results objects
        # Set a confidence threshold if desired, e.g. model(input_image, conf=0.5)
        results = model(input_image, verbose=False)  # Set verbose=True for more detailed logs
        print("Detection results obtained.")  # Log for debugging
    except Exception as e:
        print(f"Error during inference: {e}")
        # Handle inference errors by returning a copy of the image with an error note
        output_image = input_image.copy()
        draw = ImageDraw.Draw(output_image)
        try:
            font = ImageFont.truetype("arial.ttf", 20)
        except IOError:
            font = ImageFont.load_default()
        draw.text((10, 10), f"Error during detection: {e}", fill="red", font=font)
        return output_image

    # --- Draw Bounding Boxes ---
    output_image = input_image.copy()
    draw = ImageDraw.Draw(output_image)

    # Define colors per label
    label_colors = {"illustration": "red"}
    default_color = "blue"

    # Load a font for the labels
    try:
        # Using a slightly larger font
        font = ImageFont.truetype("arial.ttf", 18)
    except IOError:
        print("Arial font not found. Using Pillow's default font.")
        font = ImageFont.load_default()
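    # Note: on a typical Hugging Face Space (Linux), "arial.ttf" is usually not
    # installed, so the default bitmap font is used. A possible alternative is a
    # DejaVu font; the path below is a common location but is not guaranteed:
    # font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 18)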

    # Process results (Ultralytics returns one Results object per image; here we have one)
    if results and results[0].boxes:
        boxes = results[0].boxes  # Access the Boxes object
        print(f"Found {len(boxes)} potential objects.")
        for box in boxes:
            # Extract coordinates (xyxy format) as integers
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
            # Get the confidence score and class ID
            conf = box.conf[0].item()
            cls = int(box.cls[0].item())
            # Get the label name from the model's names dictionary
            class_name = model.names[cls] if cls in model.names else f"Class_{cls}"
            print(f"Detected '{class_name}' (ID: {cls}) with confidence {conf:.2f} at [{x1}, {y1}, {x2}, {y2}]")

            # Choose a color based on the label
            color = label_colors.get(class_name, default_color)

            # Draw the bounding box
            draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=3)

            # Prepare the label text
            label_text = f"{class_name}: {conf:.2f}"

            # Calculate the text size and position using textbbox (Pillow >= 8.0)
            try:
                text_bbox = draw.textbbox((0, 0), label_text, font=font)  # Use (0, 0) for size calculation
                text_width = text_bbox[2] - text_bbox[0]
                text_height = text_bbox[3] - text_bbox[1]
            except AttributeError:  # Fallback for older Pillow versions
                text_width, text_height = font.getsize(label_text)

            # Draw a background rectangle behind the text for better visibility
            text_bg_y = y1 - text_height - 4  # Position above the box
            if text_bg_y < 0:
                text_bg_y = y1 + 2  # Adjust if too close to the top edge
            draw.rectangle(
                [(x1, text_bg_y), (x1 + text_width + 4, text_bg_y + text_height + 2)],
                fill=color,
            )
            # Draw the label text
            draw.text((x1 + 2, text_bg_y + 1), label_text, fill="white", font=font)
    else:
        print("No objects detected in the results.")
        # Optionally add text indicating no detections
        draw.text((10, 10), "No illustrations detected.", fill="orange", font=font)

    return output_image
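
# Minimal local smoke test (assumes a file named 'sample_page.jpg' in the working
# directory; the filename is a placeholder, not part of this repo). Uncomment to
# run the detector once without starting the Gradio UI.
# if __name__ == "__main__":
#     annotated = detect_illustrations(Image.open("sample_page.jpg"))
#     annotated.save("sample_page_annotated.jpg")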

# --- Gradio Interface ---
# Define the input and output components
image_input = gr.Image(type="pil", label="Upload Newspaper Image")
image_output = gr.Image(type="pil", label="Detected Illustrations")

# Define the title and description for the Gradio app
title = f"Historic Newspaper Illustration Detector ({MODEL_FILENAME})"
description = f"""
Upload an image of a historic newspaper page.
This app uses the `{MODEL_REPO_ID}` model ('{MODEL_FILENAME}' weights via Ultralytics YOLO)
to detect illustrations and draw bounding boxes around them.
Processing might take a moment depending on the image size and server load.
Model loading happens once when the Space starts.
"""
article = f"<p style='text-align: center'><a href='https://huggingface.co/{MODEL_REPO_ID}' target='_blank'>Model Card</a> | Powered by <a href='https://github.com/ultralytics/ultralytics' target='_blank'>Ultralytics YOLO</a></p>"

# Create the Gradio interface
iface = gr.Interface(
    fn=detect_illustrations,
    inputs=image_input,
    outputs=image_output,
    title=title,
    description=description,
    article=article,
    examples=[
        # Add relative paths if you include example images in your Space repo,
        # e.g. ["example1.jpg"]
    ],
    allow_flagging='never',
)
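
# Optional: enable Gradio's request queue so long-running detections on a busy
# Space are processed in order. queue() is standard Gradio API, though its
# defaults differ across Gradio versions; left commented out here.
# iface.queue()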

# --- Launch the App ---
if __name__ == "__main__":
    # Launch the app. share=True creates a public link (useful when testing locally);
    # on Hugging Face Spaces, share=True is not needed.
    iface.launch()