TK561 committed on
Commit
7d2a98b
·
1 Parent(s): 7ef81f1

fix: 実際の深度推定機能を実装

Browse files

- DepthAnything V2モデルの統合
- Base64画像入力のサポート
- API エンドポイント /api/predict の追加
- 必要な依存関係の追加
- メモリ管理の改善

Files changed (2) hide show
  1. app.py +86 -18
  2. requirements.txt +7 -1
app.py CHANGED
@@ -1,23 +1,91 @@
1
  import gradio as gr
 
 
 
 
 
 
 
2
 
3
- def depth_estimation(image):
4
- """最もシンプルな深度推定テスト"""
5
- if image is None:
6
- return None, None
 
 
 
 
 
 
 
7
 
8
- # まずは画像をそのまま返すテスト
9
- return image, image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- # 最小限のGradio Interface
12
- demo = gr.Interface(
13
- fn=depth_estimation,
14
- inputs=gr.Image(type="pil"),
15
- outputs=[
16
- gr.Image(label="元画像"),
17
- gr.Image(label="深度マップ")
18
- ],
19
- title="深度推定 API",
20
- description="テスト中"
21
- )
22
 
23
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import AutoImageProcessor, AutoModelForDepthEstimation
4
+ import numpy as np
5
+ from PIL import Image
6
+ import cv2
7
+ import base64
8
+ import io
9
 
10
+ class DepthEstimationAPI:
11
+ def __init__(self):
12
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ print(f"Using device: {self.device}")
14
+
15
+ model_name = "depth-anything/Depth-Anything-V2-Small-hf"
16
+ self.processor = AutoImageProcessor.from_pretrained(model_name)
17
+ self.model = AutoModelForDepthEstimation.from_pretrained(model_name)
18
+ self.model.to(self.device)
19
+ self.model.eval()
20
+ print("Model loaded successfully")
21
 
22
+ def predict(self, image_input):
23
+ """Process image and return depth map"""
24
+ try:
25
+ # Handle different input types
26
+ if isinstance(image_input, str):
27
+ # Base64 encoded image
28
+ if image_input.startswith('data:image'):
29
+ header, encoded = image_input.split(',', 1)
30
+ image_bytes = base64.b64decode(encoded)
31
+ image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
32
+ else:
33
+ # File path
34
+ image = Image.open(image_input).convert('RGB')
35
+ else:
36
+ # PIL Image
37
+ image = image_input.convert('RGB') if hasattr(image_input, 'convert') else image_input
38
+
39
+ # Process image
40
+ inputs = self.processor(images=image, return_tensors="pt")
41
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
42
+
43
+ with torch.no_grad():
44
+ outputs = self.model(**inputs)
45
+ depth = outputs.predicted_depth.squeeze().cpu().numpy()
46
+
47
+ # Create depth visualization
48
+ depth_normalized = ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8)
49
+ depth_colored = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_VIRIDIS)
50
+ depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
51
+ depth_image = Image.fromarray(depth_colored)
52
+
53
+ # Clean up
54
+ del inputs, outputs, depth, depth_normalized, depth_colored
55
+ if torch.cuda.is_available():
56
+ torch.cuda.empty_cache()
57
+
58
+ return [image, depth_image]
59
+
60
+ except Exception as e:
61
+ print(f"Error in prediction: {e}")
62
+ return [None, None]
63
 
64
+ # Initialize API
65
+ api = DepthEstimationAPI()
 
 
 
 
 
 
 
 
 
66
 
67
+ # Create Gradio interface with API support
68
+ with gr.Blocks() as demo:
69
+ gr.Markdown("# Depth Estimation API")
70
+ gr.Markdown("AI-powered depth estimation using DepthAnything V2")
71
+
72
+ with gr.Row():
73
+ with gr.Column():
74
+ input_image = gr.Image(type="pil", label="Upload Image")
75
+ submit_btn = gr.Button("Generate Depth Map", variant="primary")
76
+
77
+ with gr.Column():
78
+ output_original = gr.Image(type="pil", label="Original Image")
79
+ output_depth = gr.Image(type="pil", label="Depth Map")
80
+
81
+ # Define the API endpoint
82
+ submit_btn.click(
83
+ fn=api.predict,
84
+ inputs=input_image,
85
+ outputs=[output_original, output_depth],
86
+ api_name="predict" # This creates the /api/predict endpoint
87
+ )
88
+
89
+ # Launch with proper settings for Hugging Face Spaces
90
+ if __name__ == "__main__":
91
+ demo.launch(server_name="0.0.0.0", server_port=7860)
requirements.txt CHANGED
@@ -1 +1,7 @@
1
- gradio
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ transformers
4
+ gradio
5
+ numpy
6
+ opencv-python-headless
7
+ Pillow