syedaoon commited on
Commit
d59f3c4
Β·
verified Β·
1 Parent(s): 5de8ab6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -33
app.py CHANGED
@@ -1,8 +1,11 @@
1
- from flask import Flask, request, render_template_string, send_file
2
  from PIL import Image
3
  import numpy as np
4
  import io
5
  import base64
 
 
 
6
 
7
  app = Flask(__name__)
8
 
@@ -12,76 +15,199 @@ HTML_TEMPLATE = """
12
  <head>
13
  <title>ZeroIG Enhancement</title>
14
  <style>
15
- body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
16
  .container { text-align: center; }
17
- .upload-area { border: 2px dashed #ccc; padding: 40px; margin: 20px 0; }
18
  .result { margin-top: 20px; }
19
- img { max-width: 100%; height: auto; }
 
 
 
 
20
  </style>
21
  </head>
22
  <body>
23
  <div class="container">
24
- <h1>🌟 ZeroIG: Low-Light Enhancement</h1>
25
- <p>Upload a low-light image to enhance it!</p>
26
 
27
  <form method="post" enctype="multipart/form-data">
28
  <div class="upload-area">
29
  <input type="file" name="image" accept="image/*" required>
30
  <br><br>
31
- <button type="submit">Enhance Image</button>
32
  </div>
33
  </form>
34
 
35
- {% if result_image %}
 
 
 
 
 
 
 
 
36
  <div class="result">
37
- <h3>Enhanced Image:</h3>
38
- <img src="data:image/png;base64,{{ result_image }}" alt="Enhanced">
39
- <br><br>
40
- <a href="data:image/png;base64,{{ result_image }}" download="enhanced.png">Download Enhanced Image</a>
 
 
 
 
 
 
 
 
 
 
 
 
41
  </div>
42
  {% endif %}
 
 
 
 
 
 
 
 
 
 
43
  </div>
44
  </body>
45
  </html>
46
  """
47
 
48
- def enhance_image(image):
49
- """Simple image enhancement"""
50
- try:
51
- # Convert to numpy array
52
- arr = np.array(image)
53
-
54
- # Simple brightness enhancement
55
- enhanced = np.clip(arr.astype(np.float32) * 1.8, 0, 255).astype(np.uint8)
56
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  return Image.fromarray(enhanced)
58
- except Exception as e:
59
- print(f"Enhancement error: {e}")
60
- return image
 
 
 
 
 
 
 
 
61
 
62
  @app.route('/', methods=['GET', 'POST'])
63
  def index():
 
64
  result_image = None
 
 
65
 
66
  if request.method == 'POST':
67
  try:
68
  file = request.files['image']
69
  if file:
 
 
70
  # Open and process image
71
  image = Image.open(file.stream).convert('RGB')
72
- enhanced = enhance_image(image)
 
 
 
 
 
 
73
 
74
- # Convert to base64 for display
75
- img_buffer = io.BytesIO()
76
- enhanced.save(img_buffer, format='PNG')
77
- img_str = base64.b64encode(img_buffer.getvalue()).decode()
78
- result_image = img_str
79
 
80
  except Exception as e:
81
- print(f"Error processing image: {e}")
 
82
 
83
- return render_template_string(HTML_TEMPLATE, result_image=result_image)
 
 
 
 
84
 
85
  if __name__ == '__main__':
86
- print("πŸš€ Starting Flask app...")
87
  app.run(host='0.0.0.0', port=7860)
 
1
+ from flask import Flask, request, render_template_string
2
  from PIL import Image
3
  import numpy as np
4
  import io
5
  import base64
6
+ import torch
7
+ import torchvision.transforms as transforms
8
+ import os
9
 
10
  app = Flask(__name__)
11
 
 
15
  <head>
16
  <title>ZeroIG Enhancement</title>
17
  <style>
18
+ body { font-family: Arial, sans-serif; max-width: 1000px; margin: 0 auto; padding: 20px; }
19
  .container { text-align: center; }
20
+ .upload-area { border: 2px dashed #ccc; padding: 40px; margin: 20px 0; border-radius: 10px; }
21
  .result { margin-top: 20px; }
22
+ .comparison { display: flex; justify-content: space-around; flex-wrap: wrap; }
23
+ .image-container { margin: 10px; }
24
+ img { max-width: 400px; height: auto; border: 1px solid #ddd; border-radius: 5px; }
25
+ .status { color: green; font-weight: bold; margin: 10px 0; }
26
+ .error { color: red; }
27
  </style>
28
  </head>
29
  <body>
30
  <div class="container">
31
+ <h1>🌟 ZeroIG: Zero-Shot Low-Light Enhancement</h1>
32
+ <p><strong>CVPR 2024</strong> - Upload a low-light image for professional enhancement!</p>
33
 
34
  <form method="post" enctype="multipart/form-data">
35
  <div class="upload-area">
36
  <input type="file" name="image" accept="image/*" required>
37
  <br><br>
38
+ <button type="submit" style="padding: 10px 20px; font-size: 16px;">πŸš€ Enhance with ZeroIG</button>
39
  </div>
40
  </form>
41
 
42
+ {% if status %}
43
+ <div class="status">{{ status }}</div>
44
+ {% endif %}
45
+
46
+ {% if error %}
47
+ <div class="error">{{ error }}</div>
48
+ {% endif %}
49
+
50
+ {% if original_image and result_image %}
51
  <div class="result">
52
+ <h3>Results:</h3>
53
+ <div class="comparison">
54
+ <div class="image-container">
55
+ <h4>Original (Low-light)</h4>
56
+ <img src="data:image/png;base64,{{ original_image }}" alt="Original">
57
+ </div>
58
+ <div class="image-container">
59
+ <h4>ZeroIG Enhanced</h4>
60
+ <img src="data:image/png;base64,{{ result_image }}" alt="Enhanced">
61
+ <br><br>
62
+ <a href="data:image/png;base64,{{ result_image }}" download="zeroig_enhanced.png"
63
+ style="background: #007bff; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px;">
64
+ πŸ“₯ Download Enhanced Image
65
+ </a>
66
+ </div>
67
+ </div>
68
  </div>
69
  {% endif %}
70
+
71
+ <div style="margin-top: 40px; padding: 20px; background: #f8f9fa; border-radius: 10px;">
72
+ <h3>About ZeroIG</h3>
73
+ <p>Zero-shot illumination-guided joint denoising and adaptive enhancement for low-light images.</p>
74
+ <p><strong>Features:</strong> No training data required β€’ Joint denoising & enhancement β€’ Prevents over-enhancement</p>
75
+ <p>
76
+ πŸ“„ <a href="https://openaccess.thecvf.com/content/CVPR2024/papers/Shi_ZERO-IG_Zero-Shot_Illumination-Guided_Joint_Denoising_and_Adaptive_Enhancement_for_Low-Light_CVPR_2024_paper.pdf">Research Paper</a> |
77
+ πŸ’» <a href="https://github.com/Doyle59217/ZeroIG">Source Code</a>
78
+ </p>
79
+ </div>
80
  </div>
81
  </body>
82
  </html>
83
  """
84
 
85
+ class ZeroIGProcessor:
86
+ def __init__(self):
87
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
88
+ self.model = self.load_model()
89
+ print(f"ZeroIG initialized on {self.device}")
90
+
91
+ def load_model(self):
92
+ try:
93
+ from model import Finetunemodel, Network
94
+
95
+ # Try to load trained weights
96
+ model_path = "./weights/model.pt"
97
+ if os.path.exists(model_path):
98
+ model = Finetunemodel(model_path)
99
+ print("βœ… Loaded trained ZeroIG model")
100
+ else:
101
+ model = Network()
102
+ print("⚠️ Using ZeroIG with random weights (train a model for better results)")
103
+
104
+ model.to(self.device)
105
+ model.eval()
106
+ return model
107
+
108
+ except Exception as e:
109
+ print(f"❌ Could not load ZeroIG model: {e}")
110
+ return None
111
+
112
+ def enhance_image(self, image):
113
+ """Enhance image using ZeroIG model"""
114
+ try:
115
+ if self.model is None:
116
+ return self.simple_enhance(image), "Using simple enhancement (ZeroIG model not available)"
117
+
118
+ # Resize if too large to prevent memory issues
119
+ original_size = image.size
120
+ max_size = 1024
121
+ if max(image.size) > max_size:
122
+ ratio = max_size / max(image.size)
123
+ new_size = tuple(int(dim * ratio) for dim in image.size)
124
+ image = image.resize(new_size, Image.Resampling.LANCZOS)
125
+
126
+ # Convert to tensor
127
+ transform = transforms.ToTensor()
128
+ input_tensor = transform(image).unsqueeze(0).to(self.device)
129
+
130
+ # Run ZeroIG model
131
+ with torch.no_grad():
132
+ if hasattr(self.model, 'enhance') and hasattr(self.model, 'denoise_1'):
133
+ # Finetunemodel - returns (enhanced, denoised)
134
+ enhanced, denoised = self.model(input_tensor)
135
+ result_tensor = denoised
136
+ status = "βœ… Enhanced with ZeroIG Finetuned model"
137
+ else:
138
+ # Network model - returns multiple outputs
139
+ outputs = self.model(input_tensor)
140
+ result_tensor = outputs[13] # H3 is the final denoised result
141
+ status = "βœ… Enhanced with ZeroIG Network model"
142
+
143
+ # Convert back to PIL
144
+ result_tensor = result_tensor.squeeze(0).cpu().clamp(0, 1)
145
+ enhanced_image = transforms.ToPILImage()(result_tensor)
146
+
147
+ # Resize back to original size if needed
148
+ if enhanced_image.size != original_size:
149
+ enhanced_image = enhanced_image.resize(original_size, Image.Resampling.LANCZOS)
150
+
151
+ return enhanced_image, status
152
+
153
+ except Exception as e:
154
+ print(f"ZeroIG enhancement error: {e}")
155
+ return self.simple_enhance(image), f"⚠️ ZeroIG failed, using simple enhancement: {str(e)}"
156
+
157
+ def simple_enhance(self, image):
158
+ """Fallback simple enhancement"""
159
+ arr = np.array(image).astype(np.float32)
160
+ enhanced = np.clip(arr * 1.8, 0, 255).astype(np.uint8)
161
  return Image.fromarray(enhanced)
162
+
163
+ # Initialize ZeroIG processor
164
+ print("πŸš€ Loading ZeroIG processor...")
165
+ zeroig = ZeroIGProcessor()
166
+
167
+ def image_to_base64(image):
168
+ """Convert PIL image to base64 string"""
169
+ img_buffer = io.BytesIO()
170
+ image.save(img_buffer, format='PNG')
171
+ img_str = base64.b64encode(img_buffer.getvalue()).decode()
172
+ return img_str
173
 
174
  @app.route('/', methods=['GET', 'POST'])
175
  def index():
176
+ original_image = None
177
  result_image = None
178
+ status = None
179
+ error = None
180
 
181
  if request.method == 'POST':
182
  try:
183
  file = request.files['image']
184
  if file:
185
+ print(f"Processing uploaded image: {file.filename}")
186
+
187
  # Open and process image
188
  image = Image.open(file.stream).convert('RGB')
189
+ print(f"Image size: {image.size}")
190
+
191
+ # Store original for comparison
192
+ original_image = image_to_base64(image)
193
+
194
+ # Enhance with ZeroIG
195
+ enhanced_image, enhancement_status = zeroig.enhance_image(image)
196
 
197
+ # Convert result to base64
198
+ result_image = image_to_base64(enhanced_image)
199
+ status = enhancement_status
 
 
200
 
201
  except Exception as e:
202
+ error = f"Error processing image: {str(e)}"
203
+ print(f"Error: {e}")
204
 
205
+ return render_template_string(HTML_TEMPLATE,
206
+ original_image=original_image,
207
+ result_image=result_image,
208
+ status=status,
209
+ error=error)
210
 
211
  if __name__ == '__main__':
212
+ print("πŸš€ Starting ZeroIG Flask app...")
213
  app.run(host='0.0.0.0', port=7860)