"""Gradio web demo for MMSegmentation semantic segmentation models.

On start-up the script installs its OpenMMLab dependencies, downloads three
example images and serves a Gradio interface that segments an uploaded image
with a user-selected MMSegmentation model.
"""
import os
import subprocess
import sys

# Best-effort dependency installation, performed before the third-party
# imports below. ``sys.executable -m pip`` installs into the interpreter that
# runs this script, and an argument list (no shell) keeps the ``<``/``>`` in
# the version specifiers from being treated as redirections. ``check=False``
# preserves the original behaviour of continuing even if an install fails.
for _requirement in ("mmengine>=0.6.0", "mmcv>=2.0.0rc4,<2.1.0", "mmsegmentation"):
    subprocess.run([sys.executable, "-m", "pip", "install", _requirement], check=False)

import fnmatch  # noqa: E402
import functools  # noqa: E402
import shutil  # noqa: E402
import warnings  # noqa: E402

import cv2  # noqa: E402
import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import PIL.Image  # noqa: E402
import torch  # noqa: E402
from mim import download  # noqa: E402
from mmengine import Config  # noqa: E402,F401
from mmengine.model import revert_sync_batchnorm  # noqa: E402
from mmseg.apis import MMSegInferencer, inference_model, init_model, show_result_pyplot  # noqa: E402

warnings.filterwarnings("ignore")

mmseg_models_list = MMSegInferencer.list_models('mmseg')

# Directory holding the config (*.py) and weights (*.pth) of the most recently
# downloaded model.
path = "./checkpoint"
os.makedirs(path, exist_ok=True)

# Example images offered in the Gradio ``examples`` table.
_TEST_IMAGES = {
    'bus.jpg': 'https://user-images.githubusercontent.com/59380685/266264420-21575a83-4057-41cf-8a4a-b3ea6f332d79.jpg',
    'dogs.jpg': 'https://user-images.githubusercontent.com/59380685/266264536-82afdf58-6b9a-4568-b9df-551ee72cb6d9.jpg',
    'zidane.jpg': 'https://user-images.githubusercontent.com/59380685/266264600-9d0c26ca-8ba6-45f2-b53b-4dc98460c43e.jpg',
}


def clear_folder(folder_path):
    """Delete every file, symlink and sub-directory inside ``folder_path``.

    The folder itself is kept. A failure on one entry is reported and does not
    stop the remaining deletions.
    """
    for filename in os.listdir(folder_path):
        file_path = os.path.join(folder_path, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            print(f"Failed to delete {file_path}. Reason: {e}")
    print(f"Clear {folder_path} successfully.")


def save_image(img, img_path):
    """Write ``img`` (a PIL image or RGB array) to ``img_path`` via OpenCV."""
    # Normalise PIL inputs to 3-channel RGB: grayscale (2-D) or palette images
    # would otherwise make cv2.cvtColor(..., COLOR_RGB2BGR) fail.
    if isinstance(img, PIL.Image.Image):
        img = img.convert("RGB")
    # OpenCV expects BGR channel order when writing.
    bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    cv2.imwrite(img_path, bgr)


def download_test_image():
    """Download the example images into the working directory.

    Files that already exist are not downloaded again.
    """
    for filename, url in _TEST_IMAGES.items():
        if not os.path.exists(filename):
            torch.hub.download_url_to_file(url, filename)


def download_cfg_checkpoint_model_name(model_name):
    """Replace the contents of ``path`` with the config and weights of ``model_name``."""
    clear_folder(path)
    download(package='mmsegmentation', configs=[model_name], dest_root=path)


def _find_single_file(pattern):
    """Return the path of the first file in ``path`` whose name matches ``pattern``.

    Raises:
        FileNotFoundError: if nothing matches (for example, the download failed).
    """
    matches = sorted(f for f in os.listdir(path) if fnmatch.fnmatch(f, pattern))
    if not matches:
        raise FileNotFoundError(f"No file matching {pattern!r} found in {path}")
    return os.path.join(path, matches[0])


@functools.lru_cache(maxsize=1)
def _load_model(model_name):
    """Download ``model_name`` and build it on the best available device.

    The most recently built model is cached, so repeated requests for the same
    model skip the download and rebuild. ``maxsize=1`` keeps at most one model
    in memory; ``clear_folder`` deletes the previous model's files anyway.
    """
    download_cfg_checkpoint_model_name(model_name)
    config_path = _find_single_file("*.py")
    checkpoint_path = _find_single_file("*.pth")

    if torch.cuda.is_available():
        device = f"cuda:{torch.cuda.current_device()}"
    else:
        device = 'cpu'

    model = init_model(config_path, checkpoint_path, device=device)
    if device == 'cpu':
        # SyncBN layers cannot run on CPU. Convert them to plain BatchNorm on
        # the built model, which works whatever heads the config defines
        # (editing config sections breaks models with no auxiliary head).
        model = revert_sync_batchnorm(model)
    return model


# Inference function.
def predict(img, model_name):
    """Segment ``img`` with ``model_name`` and return the visualisation.

    Args:
        img: input image from the Gradio image component (PIL, ``type='pil'``).
        model_name: an entry of ``mmseg_models_list``.

    Returns:
        PIL.Image.Image: the input image overlaid with the predicted segmentation.

    Raises:
        gr.Error: if no image or no model was provided.
    """
    if img is None:
        raise gr.Error("Please upload an input image.")
    if not model_name:
        raise gr.Error("Please select a model.")

    # Save the input image so the mmseg APIs can read it from disk.
    img_path = 'input_image.png'
    save_image(img, img_path)

    model = _load_model(model_name)

    # Run inference on the given image.
    result = inference_model(model, img_path)

    # show_result_pyplot returns an RGB array. Convert it in memory rather than
    # round-tripping through cv2.imwrite, which assumes BGR and would swap the
    # red and blue channels.
    vis_image = show_result_pyplot(model, img_path, result, show=False)
    return PIL.Image.fromarray(vis_image)


download_test_image()

# Input and output components.
inputs_img = gr.Image(type='pil', label="Input Image")
model_list = gr.Dropdown(choices=list(mmseg_models_list), label='Model')
outputs_img = gr.Image(type='pil', label="Output Image")

title = "MMSegmentation segmentation web demo"
description = (
    "<p style='text-align: center'>MMSegmentation MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 OpenMMLab 项目的一部分。</p>"
    "<p style='text-align: center'>OpenMMLab Semantic Segmentation Toolbox and Benchmark..</p>"
)
article = (
    "<p style='text-align: center'>MMSegmentation</p>"
    "<p style='text-align: center'>gradio build by gatilin</p>"
)
examples = [
    ["bus.jpg", "deeplabv3_r101-d8_4xb2-40k_cityscapes-512x1024"],
    ["dogs.jpg", "pspnet_r50-d8_4xb2-40k_cityscapes-512x1024"],
    ["zidane.jpg", "fcn_r101-d8_4xb4-80k_ade20k-512x512"],
]

demo = gr.Interface(
    fn=predict,
    inputs=[inputs_img, model_list],
    outputs=outputs_img,
    examples=examples,
    title=title,
    description=description,
    article=article,
)

# Launch the interface.
if __name__ == "__main__":
    demo.launch()