Update app.py
Browse files
app.py
CHANGED
@@ -1,8 +1,11 @@
|
|
1 |
-
from flask import Flask, request, render_template_string
|
2 |
from PIL import Image
|
3 |
import numpy as np
|
4 |
import io
|
5 |
import base64
|
|
|
|
|
|
|
6 |
|
7 |
app = Flask(__name__)
|
8 |
|
@@ -12,76 +15,199 @@ HTML_TEMPLATE = """
|
|
12 |
<head>
|
13 |
<title>ZeroIG Enhancement</title>
|
14 |
<style>
|
15 |
-
body { font-family: Arial, sans-serif; max-width:
|
16 |
.container { text-align: center; }
|
17 |
-
.upload-area { border: 2px dashed #ccc; padding: 40px; margin: 20px 0; }
|
18 |
.result { margin-top: 20px; }
|
19 |
-
|
|
|
|
|
|
|
|
|
20 |
</style>
|
21 |
</head>
|
22 |
<body>
|
23 |
<div class="container">
|
24 |
-
<h1>π ZeroIG: Low-Light Enhancement</h1>
|
25 |
-
<p>Upload a low-light image
|
26 |
|
27 |
<form method="post" enctype="multipart/form-data">
|
28 |
<div class="upload-area">
|
29 |
<input type="file" name="image" accept="image/*" required>
|
30 |
<br><br>
|
31 |
-
<button type="submit"
|
32 |
</div>
|
33 |
</form>
|
34 |
|
35 |
-
{% if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
<div class="result">
|
37 |
-
<h3>
|
38 |
-
<
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
</div>
|
42 |
{% endif %}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
</div>
|
44 |
</body>
|
45 |
</html>
|
46 |
"""
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
return Image.fromarray(enhanced)
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
@app.route('/', methods=['GET', 'POST'])
|
63 |
def index():
|
|
|
64 |
result_image = None
|
|
|
|
|
65 |
|
66 |
if request.method == 'POST':
|
67 |
try:
|
68 |
file = request.files['image']
|
69 |
if file:
|
|
|
|
|
70 |
# Open and process image
|
71 |
image = Image.open(file.stream).convert('RGB')
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
-
# Convert to base64
|
75 |
-
|
76 |
-
|
77 |
-
img_str = base64.b64encode(img_buffer.getvalue()).decode()
|
78 |
-
result_image = img_str
|
79 |
|
80 |
except Exception as e:
|
81 |
-
|
|
|
82 |
|
83 |
-
return render_template_string(HTML_TEMPLATE,
|
|
|
|
|
|
|
|
|
84 |
|
85 |
if __name__ == '__main__':
|
86 |
-
print("π Starting Flask app...")
|
87 |
app.run(host='0.0.0.0', port=7860)
|
|
|
1 |
+
from flask import Flask, request, render_template_string
|
2 |
from PIL import Image
|
3 |
import numpy as np
|
4 |
import io
|
5 |
import base64
|
6 |
+
import torch
|
7 |
+
import torchvision.transforms as transforms
|
8 |
+
import os
|
9 |
|
10 |
app = Flask(__name__)
|
11 |
|
|
|
15 |
<head>
|
16 |
<title>ZeroIG Enhancement</title>
|
17 |
<style>
|
18 |
+
body { font-family: Arial, sans-serif; max-width: 1000px; margin: 0 auto; padding: 20px; }
|
19 |
.container { text-align: center; }
|
20 |
+
.upload-area { border: 2px dashed #ccc; padding: 40px; margin: 20px 0; border-radius: 10px; }
|
21 |
.result { margin-top: 20px; }
|
22 |
+
.comparison { display: flex; justify-content: space-around; flex-wrap: wrap; }
|
23 |
+
.image-container { margin: 10px; }
|
24 |
+
img { max-width: 400px; height: auto; border: 1px solid #ddd; border-radius: 5px; }
|
25 |
+
.status { color: green; font-weight: bold; margin: 10px 0; }
|
26 |
+
.error { color: red; }
|
27 |
</style>
|
28 |
</head>
|
29 |
<body>
|
30 |
<div class="container">
|
31 |
+
<h1>π ZeroIG: Zero-Shot Low-Light Enhancement</h1>
|
32 |
+
<p><strong>CVPR 2024</strong> - Upload a low-light image for professional enhancement!</p>
|
33 |
|
34 |
<form method="post" enctype="multipart/form-data">
|
35 |
<div class="upload-area">
|
36 |
<input type="file" name="image" accept="image/*" required>
|
37 |
<br><br>
|
38 |
+
<button type="submit" style="padding: 10px 20px; font-size: 16px;">π Enhance with ZeroIG</button>
|
39 |
</div>
|
40 |
</form>
|
41 |
|
42 |
+
{% if status %}
|
43 |
+
<div class="status">{{ status }}</div>
|
44 |
+
{% endif %}
|
45 |
+
|
46 |
+
{% if error %}
|
47 |
+
<div class="error">{{ error }}</div>
|
48 |
+
{% endif %}
|
49 |
+
|
50 |
+
{% if original_image and result_image %}
|
51 |
<div class="result">
|
52 |
+
<h3>Results:</h3>
|
53 |
+
<div class="comparison">
|
54 |
+
<div class="image-container">
|
55 |
+
<h4>Original (Low-light)</h4>
|
56 |
+
<img src="data:image/png;base64,{{ original_image }}" alt="Original">
|
57 |
+
</div>
|
58 |
+
<div class="image-container">
|
59 |
+
<h4>ZeroIG Enhanced</h4>
|
60 |
+
<img src="data:image/png;base64,{{ result_image }}" alt="Enhanced">
|
61 |
+
<br><br>
|
62 |
+
<a href="data:image/png;base64,{{ result_image }}" download="zeroig_enhanced.png"
|
63 |
+
style="background: #007bff; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px;">
|
64 |
+
π₯ Download Enhanced Image
|
65 |
+
</a>
|
66 |
+
</div>
|
67 |
+
</div>
|
68 |
</div>
|
69 |
{% endif %}
|
70 |
+
|
71 |
+
<div style="margin-top: 40px; padding: 20px; background: #f8f9fa; border-radius: 10px;">
|
72 |
+
<h3>About ZeroIG</h3>
|
73 |
+
<p>Zero-shot illumination-guided joint denoising and adaptive enhancement for low-light images.</p>
|
74 |
+
<p><strong>Features:</strong> No training data required β’ Joint denoising & enhancement β’ Prevents over-enhancement</p>
|
75 |
+
<p>
|
76 |
+
π <a href="https://openaccess.thecvf.com/content/CVPR2024/papers/Shi_ZERO-IG_Zero-Shot_Illumination-Guided_Joint_Denoising_and_Adaptive_Enhancement_for_Low-Light_CVPR_2024_paper.pdf">Research Paper</a> |
|
77 |
+
π» <a href="https://github.com/Doyle59217/ZeroIG">Source Code</a>
|
78 |
+
</p>
|
79 |
+
</div>
|
80 |
</div>
|
81 |
</body>
|
82 |
</html>
|
83 |
"""
|
84 |
|
85 |
+
class ZeroIGProcessor:
|
86 |
+
def __init__(self):
|
87 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
88 |
+
self.model = self.load_model()
|
89 |
+
print(f"ZeroIG initialized on {self.device}")
|
90 |
+
|
91 |
+
def load_model(self):
|
92 |
+
try:
|
93 |
+
from model import Finetunemodel, Network
|
94 |
+
|
95 |
+
# Try to load trained weights
|
96 |
+
model_path = "./weights/model.pt"
|
97 |
+
if os.path.exists(model_path):
|
98 |
+
model = Finetunemodel(model_path)
|
99 |
+
print("β
Loaded trained ZeroIG model")
|
100 |
+
else:
|
101 |
+
model = Network()
|
102 |
+
print("β οΈ Using ZeroIG with random weights (train a model for better results)")
|
103 |
+
|
104 |
+
model.to(self.device)
|
105 |
+
model.eval()
|
106 |
+
return model
|
107 |
+
|
108 |
+
except Exception as e:
|
109 |
+
print(f"β Could not load ZeroIG model: {e}")
|
110 |
+
return None
|
111 |
+
|
112 |
+
def enhance_image(self, image):
|
113 |
+
"""Enhance image using ZeroIG model"""
|
114 |
+
try:
|
115 |
+
if self.model is None:
|
116 |
+
return self.simple_enhance(image), "Using simple enhancement (ZeroIG model not available)"
|
117 |
+
|
118 |
+
# Resize if too large to prevent memory issues
|
119 |
+
original_size = image.size
|
120 |
+
max_size = 1024
|
121 |
+
if max(image.size) > max_size:
|
122 |
+
ratio = max_size / max(image.size)
|
123 |
+
new_size = tuple(int(dim * ratio) for dim in image.size)
|
124 |
+
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
125 |
+
|
126 |
+
# Convert to tensor
|
127 |
+
transform = transforms.ToTensor()
|
128 |
+
input_tensor = transform(image).unsqueeze(0).to(self.device)
|
129 |
+
|
130 |
+
# Run ZeroIG model
|
131 |
+
with torch.no_grad():
|
132 |
+
if hasattr(self.model, 'enhance') and hasattr(self.model, 'denoise_1'):
|
133 |
+
# Finetunemodel - returns (enhanced, denoised)
|
134 |
+
enhanced, denoised = self.model(input_tensor)
|
135 |
+
result_tensor = denoised
|
136 |
+
status = "β
Enhanced with ZeroIG Finetuned model"
|
137 |
+
else:
|
138 |
+
# Network model - returns multiple outputs
|
139 |
+
outputs = self.model(input_tensor)
|
140 |
+
result_tensor = outputs[13] # H3 is the final denoised result
|
141 |
+
status = "β
Enhanced with ZeroIG Network model"
|
142 |
+
|
143 |
+
# Convert back to PIL
|
144 |
+
result_tensor = result_tensor.squeeze(0).cpu().clamp(0, 1)
|
145 |
+
enhanced_image = transforms.ToPILImage()(result_tensor)
|
146 |
+
|
147 |
+
# Resize back to original size if needed
|
148 |
+
if enhanced_image.size != original_size:
|
149 |
+
enhanced_image = enhanced_image.resize(original_size, Image.Resampling.LANCZOS)
|
150 |
+
|
151 |
+
return enhanced_image, status
|
152 |
+
|
153 |
+
except Exception as e:
|
154 |
+
print(f"ZeroIG enhancement error: {e}")
|
155 |
+
return self.simple_enhance(image), f"β οΈ ZeroIG failed, using simple enhancement: {str(e)}"
|
156 |
+
|
157 |
+
def simple_enhance(self, image):
|
158 |
+
"""Fallback simple enhancement"""
|
159 |
+
arr = np.array(image).astype(np.float32)
|
160 |
+
enhanced = np.clip(arr * 1.8, 0, 255).astype(np.uint8)
|
161 |
return Image.fromarray(enhanced)
|
162 |
+
|
163 |
+
# Initialize ZeroIG processor
|
164 |
+
print("π Loading ZeroIG processor...")
|
165 |
+
zeroig = ZeroIGProcessor()
|
166 |
+
|
167 |
+
def image_to_base64(image):
|
168 |
+
"""Convert PIL image to base64 string"""
|
169 |
+
img_buffer = io.BytesIO()
|
170 |
+
image.save(img_buffer, format='PNG')
|
171 |
+
img_str = base64.b64encode(img_buffer.getvalue()).decode()
|
172 |
+
return img_str
|
173 |
|
174 |
@app.route('/', methods=['GET', 'POST'])
|
175 |
def index():
|
176 |
+
original_image = None
|
177 |
result_image = None
|
178 |
+
status = None
|
179 |
+
error = None
|
180 |
|
181 |
if request.method == 'POST':
|
182 |
try:
|
183 |
file = request.files['image']
|
184 |
if file:
|
185 |
+
print(f"Processing uploaded image: {file.filename}")
|
186 |
+
|
187 |
# Open and process image
|
188 |
image = Image.open(file.stream).convert('RGB')
|
189 |
+
print(f"Image size: {image.size}")
|
190 |
+
|
191 |
+
# Store original for comparison
|
192 |
+
original_image = image_to_base64(image)
|
193 |
+
|
194 |
+
# Enhance with ZeroIG
|
195 |
+
enhanced_image, enhancement_status = zeroig.enhance_image(image)
|
196 |
|
197 |
+
# Convert result to base64
|
198 |
+
result_image = image_to_base64(enhanced_image)
|
199 |
+
status = enhancement_status
|
|
|
|
|
200 |
|
201 |
except Exception as e:
|
202 |
+
error = f"Error processing image: {str(e)}"
|
203 |
+
print(f"Error: {e}")
|
204 |
|
205 |
+
return render_template_string(HTML_TEMPLATE,
|
206 |
+
original_image=original_image,
|
207 |
+
result_image=result_image,
|
208 |
+
status=status,
|
209 |
+
error=error)
|
210 |
|
211 |
if __name__ == '__main__':
|
212 |
+
print("π Starting ZeroIG Flask app...")
|
213 |
app.run(host='0.0.0.0', port=7860)
|