Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import torchvision | |
| from torchvision import transforms | |
| from PIL import Image | |
| import io | |
# Define the function to load the model
def load_model(model_path, device):
    """Build the DenseNet-201 pneumonia classifier and load its fine-tuned weights.

    The architecture is torchvision's DenseNet-201 with the classifier head
    replaced by ``Dropout(p=0.2) -> Linear(in_features, 2)``, which must match
    the architecture the checkpoint at ``model_path`` was saved from.

    Args:
        model_path: Path to a ``state_dict`` saved from that architecture.
        device: Device the model is placed on (e.g. ``"cpu"`` or ``"cuda"``);
            checkpoint tensors are mapped to it as they are loaded.

    Returns:
        The model on ``device``, switched to eval mode.
    """
    # No pretrained weights: the strict load_state_dict below overwrites every
    # parameter and buffer anyway, so downloading the ImageNet weights first
    # only cost network I/O and startup time.
    model = torchvision.models.densenet201(weights=None)
    # Read the feature width from the stock head (1920 for DenseNet-201)
    # instead of hard-coding it.
    in_features = model.classifier.in_features
    model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=0.2, inplace=True),
        torch.nn.Linear(in_features=in_features, out_features=2, bias=True),
    )
    model.to(device)
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints; on torch >= 1.13 weights_only=True is a safer option for a
    # plain state_dict.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model
# Define the function for preprocessing the image
def preprocess_image(image):
    """Turn a PIL image into the float tensor the classifier consumes.

    The shorter side is resized to 64 pixels (aspect ratio preserved), then the
    image is converted with ``ToTensor`` to a ``(C, H, W)`` float tensor scaled
    to [0, 1]. No mean/std normalization is applied.
    """
    # NOTE(review): an int size keeps the aspect ratio, so non-square uploads
    # yield non-square tensors; confirm this matches the training pipeline.
    steps = [transforms.Resize(64), transforms.ToTensor()]
    pipeline = transforms.Compose(steps)
    tensor = pipeline(image)
    return tensor
# Define the function for getting predictions
def get_prediction(model, image, device, *, class_names=("normal", "pneumonia")):
    """Classify a single preprocessed image.

    Args:
        model: Callable mapping a batch of one image to logits of shape
            ``(1, len(class_names))``.
        image: One image tensor without a batch dimension.
        device: Device the batch is moved to before inference.
        class_names: Labels indexed by logit position. Defaults to the two
            classes of the bundled checkpoint.

    Returns:
        ``(label, probability)``: the predicted class name and its softmax
        probability as a Python float.

    Raises:
        ValueError: if the number of logits differs from ``len(class_names)``.
    """
    batch = image.unsqueeze(0).to(device)  # add batch dimension and move to device
    with torch.no_grad():
        logits = model(batch)
    # Fail loudly on a model/label mismatch instead of an IndexError or a
    # silently wrong label.
    if logits.shape[-1] != len(class_names):
        raise ValueError(
            f"model produced {logits.shape[-1]} logits but "
            f"{len(class_names)} class names were given"
        )
    probs = torch.softmax(logits, dim=1)
    # Per-row max: the reported probability is exactly that of the predicted class.
    top_prob, top_idx = probs.max(dim=1)
    return class_names[top_idx.item()], top_prob.item()
# Streamlit app starts here
st.title("Chest X-ray Pneumonia Checking App")
uploaded_file = st.file_uploader("Upload an image of a chest x-ray", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Decode the uploaded bytes with PIL and convert to RGB so the input has
    # three channels like the training set. A corrupt or non-image upload makes
    # PIL raise an OSError subclass (UnidentifiedImageError, truncated file);
    # report it in the UI instead of crashing the script with a traceback.
    try:
        image = Image.open(io.BytesIO(uploaded_file.getvalue())).convert('RGB')
    except OSError:
        st.error("Could not read the uploaded file as an image.")
    else:
        # Display the uploaded image
        st.image(image, caption='Uploaded Image.', use_column_width=True)
        # Predict button
        if st.button('Predict'):
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model_path = 'densenetxray.pth'  # Fixed model path
            # Streamlit re-executes this whole script on every interaction, so an
            # unconditional load_model() here rebuilt DenseNet-201 and re-read the
            # checkpoint on every click. Keep one loaded model per session.
            if "model" not in st.session_state:
                st.session_state["model"] = load_model(model_path, device)
            model = st.session_state["model"]
            # Preprocess the image and predict
            preprocessed_image = preprocess_image(image)
            prediction, probability = get_prediction(model, preprocessed_image, device)
            # Display the prediction
            st.write(f"Prediction: {prediction}, Probability: {probability:.3f}")