# trinera-pest-detector / yolo_infer.py
# Last commit a5c85d7 by S1-1IVAM: "Update yolo_infer.py" (verified)
from pathlib import Path
import cv2
import numpy as np
from ultralytics import YOLO
import tempfile
# Load the trained detection weights once, at import time.
# "best.pt" is expected to sit in the app folder next to this file. Resolve it
# relative to this module so loading does not depend on the current working
# directory; fall back to the bare relative name (previous behaviour) if the
# file is not found there.
# NOTE(review): the author describes these weights as YOLOv5; the Ultralytics
# YOLO class loads several architectures — TODO confirm which one best.pt is.
_MODEL_PATH = Path(__file__).resolve().parent / "best.pt"
model = YOLO(str(_MODEL_PATH) if _MODEL_PATH.is_file() else "best.pt")
def detect_image(image: np.ndarray) -> np.ndarray:
    """
    Detect objects in an uploaded image and return it with detections drawn.

    Args:
        image: Image as a NumPy array, passed straight to the model.

    Returns:
        A NumPy array: the first result's ``plot()`` output, i.e. the image
        with boxes/labels drawn on it.

    Raises:
        ValueError: If ``image`` is None (e.g. nothing was uploaded).
    """
    # Ultralytics treats a missing source as "use the bundled demo images",
    # which would silently return detections for an unrelated picture.
    if image is None:
        raise ValueError("No image provided.")
    # NOTE(review): Ultralytics interprets NumPy input as BGR; if callers pass
    # RGB arrays (common for web UIs), confirm that is acceptable.
    results = model(image)
    # A single image yields a single result; plot() returns the annotated array.
    return results[0].plot()
def detect_video(video_file: str) -> str:
    """
    Detect objects in every frame of a video and write an annotated copy.

    Args:
        video_file: Path to the input video, readable by ``cv2.VideoCapture``.

    Returns:
        Path to a newly created ``.mp4`` file holding the annotated frames.
        The file is not deleted automatically; the caller owns it.

    Raises:
        ValueError: If the input video cannot be opened.
        RuntimeError: If the output video writer cannot be opened.
    """
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        cap.release()
        raise ValueError("Cannot open uploaded video.")

    # Create the output path only once the input is known to be readable, so a
    # bad upload leaves no orphaned temp file. Close the handle right away:
    # OpenCV reopens the path itself, and an open handle blocks that on Windows.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        tmp_output = tmp.name

    out = None
    completed = False
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 (or NaN) FPS; `not fps > 0` catches both.
        if not fps > 0:
            fps = 20.0
        # NOTE(review): "mp4v"-encoded MP4 is often not playable in browsers;
        # consider an H.264 fourcc if the OpenCV build supports it.
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(tmp_output, fourcc, fps, (width, height))
        if not out.isOpened():
            raise RuntimeError("Cannot open output video writer.")

        while True:
            ok, frame = cap.read()
            if not ok:
                break
            results = model(frame)
            # plot() draws on a copy of the frame, so its size matches the
            # (width, height) the writer was opened with.
            out.write(results[0].plot())
        completed = True
    finally:
        # Release both handles even if inference raises mid-stream.
        cap.release()
        if out is not None:
            out.release()
        # Don't hand back (or leak) a partial output file on failure.
        if not completed:
            Path(tmp_output).unlink(missing_ok=True)
    return tmp_output