Spaces:
Build error
Build error
import os
import sys

import cv2
import torch
import torchvision.transforms as T
from tqdm import tqdm

# Make the DenseMammogram checkout importable; this must run before the
# project imports below, which resolve against that directory.
sys.path.append('DenseMammogram')

import detection.transforms as transforms  # noqa: E402
from dataloaders import get_direction  # noqa: E402
from models import get_FRCNN_model, Bilateral_model  # noqa: E402

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

FRCNN_PATH = 'pretrained_models/frcnn/frcnn_models/frcnn_model.pth'
BILAR_PATH = 'pretrained_models/BILATERAL/bilateral_models/bilateral_model.pth'

# The bilateral model is built around the FRCNN model instance created here.
frcnn_model = get_FRCNN_model().to(device)
bilat_model = Bilateral_model(frcnn_model).to(device)

# NOTE: torch.load unpickles the checkpoint files; only load trusted weights.
frcnn_model.load_state_dict(torch.load(FRCNN_PATH, map_location=device))
bilat_model.load_state_dict(torch.load(BILAR_PATH, map_location=device))
def _read_image(path):
    """Read *path* with ``cv2.imread``, raising if OpenCV cannot read it."""
    image = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return image


def _draw_detections(image, output, threshold):
    """Draw label-1 detections scoring above *threshold* onto *image* in place.

    Each kept detection gets a green rectangle and a "Cancer: NN.N%" label.
    Returns the same *image* array.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale, thickness = 3.6, 6
    for box, score, label in zip(output['boxes'], output['scores'], output['labels']):
        if not (label == 1 and score > threshold):
            continue
        coords = box.detach().cpu().numpy().astype(int)
        # Plain Python ints: OpenCV's point parser is strict about element types.
        x1, y1, x2, y2 = (int(v) for v in coords[:4])
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        text = 'Cancer: ' + str(round(round(score.item(), 2) * 100, 1)) + '%'
        (_, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
        # The label normally sits 40 px above the box; clamp it so it stays
        # inside the image when the box touches the top edge.
        text_y = max(y1 - 40, text_h)
        cv2.putText(image, text, (x1, text_y), font, font_scale, (36, 255, 12), thickness)
    return image


def predict(left_file, right_file, threshold = 0.80, baseIsLeft = True):
    """Run the bilateral detector on a left/right image pair and draw results.

    Both files are read with ``cv2.imread`` and converted to tensors. Before
    inference, the left image is horizontally flipped when ``baseIsLeft`` is
    true; otherwise the right image is flipped. The pair ``[left, right]`` is
    passed to ``bilat_model`` as a batch of one, and the first output is used.

    Args:
        left_file: Path to the left image; detections are drawn on this image.
        right_file: Path to the right image.
        threshold: Detections with label 1 are drawn only if their score is
            strictly greater than this value.
        baseIsLeft: Selects which image is flipped before inference. When
            true, the predicted boxes are flipped back so they line up with
            the unflipped left image.

    Returns:
        The left image array, as returned by ``cv2.imread``, with detection
        boxes and score labels drawn on it.

    Raises:
        FileNotFoundError: If OpenCV cannot read either image file.
    """
    model = bilat_model
    flip = transforms.RandomHorizontalFlip(1.0)  # p=1.0: always flip
    left_image = _read_image(left_file)
    right_image = _read_image(right_file)
    with torch.no_grad():
        transform = T.Compose([T.ToPILImage(), T.ToTensor()])
        model.eval()
        # First is left, then right
        img1 = transform(left_image)
        img2 = transform(right_image)
        if baseIsLeft:
            img1, _ = flip(img1)
        else:
            img2, _ = flip(img2)
        images = [img1.to(device), img2.to(device)]
        output = model([images])[0]
        if baseIsLeft:
            # Undo the flip on the predicted boxes so they match the
            # unflipped left image that is drawn on below.
            img1, output = flip(img1, output)
    # The ToTensor step copied the pixel data, so drawing on left_image is
    # equivalent to drawing on a freshly re-read copy of left_file.
    return _draw_detections(left_image, output, threshold)