syedaoon commited on
Commit
2d074f9
Β·
verified Β·
1 Parent(s): edf4cf6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -10
app.py CHANGED
@@ -6,6 +6,7 @@ import base64
6
  import torch
7
  import torchvision.transforms as transforms
8
  import os
 
9
 
10
  app = Flask(__name__)
11
 
@@ -90,68 +91,84 @@ class ZeroIGProcessor:
90
 
91
  def load_model(self):
92
  try:
 
93
  from model import Finetunemodel, Network
94
 
95
  # Try to load trained weights
96
  model_path = "./weights/model.pt"
97
  if os.path.exists(model_path):
 
98
  model = Finetunemodel(model_path)
99
- print("βœ… Loaded trained ZeroIG model")
100
  else:
 
101
  model = Network()
102
- print("⚠️ Using ZeroIG with random weights (train a model for better results)")
103
 
104
  model.to(self.device)
105
  model.eval()
 
106
  return model
107
 
 
 
 
 
108
  except Exception as e:
109
  print(f"❌ Could not load ZeroIG model: {e}")
110
  return None
111
 
112
  def enhance_image(self, image):
113
- """Enhance image using ZeroIG model"""
114
  try:
115
  if self.model is None:
116
- return self.simple_enhance(image), "Using simple enhancement (ZeroIG model not available)"
117
 
118
  # Resize if too large to prevent memory issues
119
  original_size = image.size
120
- max_size = 1024
121
  if max(image.size) > max_size:
122
  ratio = max_size / max(image.size)
123
  new_size = tuple(int(dim * ratio) for dim in image.size)
124
  image = image.resize(new_size, Image.Resampling.LANCZOS)
 
125
 
126
  # Convert to tensor
127
  transform = transforms.ToTensor()
128
  input_tensor = transform(image).unsqueeze(0).to(self.device)
 
129
 
130
- # Run ZeroIG model
131
  with torch.no_grad():
132
  if hasattr(self.model, 'enhance') and hasattr(self.model, 'denoise_1'):
133
  # Finetunemodel - returns (enhanced, denoised)
134
  enhanced, denoised = self.model(input_tensor)
135
- result_tensor = denoised
136
- status = "βœ… Enhanced with ZeroIG Finetuned model"
 
137
  else:
138
  # Network model - returns multiple outputs
139
  outputs = self.model(input_tensor)
140
  result_tensor = outputs[13] # H3 is the final denoised result
141
  status = "βœ… Enhanced with ZeroIG Network model"
 
142
 
143
  # Convert back to PIL
144
  result_tensor = result_tensor.squeeze(0).cpu().clamp(0, 1)
145
  enhanced_image = transforms.ToPILImage()(result_tensor)
 
146
 
147
  # Resize back to original size if needed
148
- if enhanced_image.size != original_size:
149
  enhanced_image = enhanced_image.resize(original_size, Image.Resampling.LANCZOS)
 
150
 
151
  return enhanced_image, status
152
 
153
  except Exception as e:
154
  print(f"ZeroIG enhancement error: {e}")
 
 
155
  return self.simple_enhance(image), f"⚠️ ZeroIG failed, using simple enhancement: {str(e)}"
156
 
157
  def simple_enhance(self, image):
@@ -191,7 +208,7 @@ def index():
191
  # Store original for comparison
192
  original_image = image_to_base64(image)
193
 
194
- # Enhance with ZeroIG
195
  enhanced_image, enhancement_status = zeroig.enhance_image(image)
196
 
197
  # Convert result to base64
@@ -201,6 +218,8 @@ def index():
201
  except Exception as e:
202
  error = f"Error processing image: {str(e)}"
203
  print(f"Error: {e}")
 
 
204
 
205
  return render_template_string(HTML_TEMPLATE,
206
  original_image=original_image,
 
6
  import torch
7
  import torchvision.transforms as transforms
8
  import os
9
+ import sys
10
 
11
  app = Flask(__name__)
12
 
 
91
 
92
  def load_model(self):
93
  try:
94
+ # Import your uploaded ZeroIG files
95
  from model import Finetunemodel, Network
96
 
97
  # Try to load trained weights
98
  model_path = "./weights/model.pt"
99
  if os.path.exists(model_path):
100
+ print(f"Found model weights at {model_path}")
101
  model = Finetunemodel(model_path)
102
+ print("βœ… Loaded ZeroIG Finetunemodel with trained weights")
103
  else:
104
+ print("No trained weights found, using Network with random initialization")
105
  model = Network()
106
+ print("⚠️ Using ZeroIG Network with random weights")
107
 
108
  model.to(self.device)
109
  model.eval()
110
+ print(f"Model moved to {self.device}")
111
  return model
112
 
113
+ except ImportError as e:
114
+ print(f"❌ Could not import ZeroIG modules: {e}")
115
+ print("Make sure you have uploaded: model.py, loss.py, utils.py")
116
+ return None
117
  except Exception as e:
118
  print(f"❌ Could not load ZeroIG model: {e}")
119
  return None
120
 
121
  def enhance_image(self, image):
122
+ """Enhance image using your ZeroIG model"""
123
  try:
124
  if self.model is None:
125
+ return self.simple_enhance(image), "❌ ZeroIG model not available - using simple enhancement"
126
 
127
  # Resize if too large to prevent memory issues
128
  original_size = image.size
129
+ max_size = 800 # Adjust based on your needs
130
  if max(image.size) > max_size:
131
  ratio = max_size / max(image.size)
132
  new_size = tuple(int(dim * ratio) for dim in image.size)
133
  image = image.resize(new_size, Image.Resampling.LANCZOS)
134
+ print(f"Resized image from {original_size} to {image.size}")
135
 
136
  # Convert to tensor
137
  transform = transforms.ToTensor()
138
  input_tensor = transform(image).unsqueeze(0).to(self.device)
139
+ print(f"Input tensor shape: {input_tensor.shape}")
140
 
141
+ # Run your ZeroIG model
142
  with torch.no_grad():
143
  if hasattr(self.model, 'enhance') and hasattr(self.model, 'denoise_1'):
144
  # Finetunemodel - returns (enhanced, denoised)
145
  enhanced, denoised = self.model(input_tensor)
146
+ result_tensor = denoised # Use denoised output
147
+ status = "βœ… Enhanced with ZeroIG Finetunemodel"
148
+ print("Used Finetunemodel")
149
  else:
150
  # Network model - returns multiple outputs
151
  outputs = self.model(input_tensor)
152
  result_tensor = outputs[13] # H3 is the final denoised result
153
  status = "βœ… Enhanced with ZeroIG Network model"
154
+ print("Used Network model")
155
 
156
  # Convert back to PIL
157
  result_tensor = result_tensor.squeeze(0).cpu().clamp(0, 1)
158
  enhanced_image = transforms.ToPILImage()(result_tensor)
159
+ print(f"Output image size: {enhanced_image.size}")
160
 
161
  # Resize back to original size if needed
162
+ if enhanced_image.size != original_size and original_size != image.size:
163
  enhanced_image = enhanced_image.resize(original_size, Image.Resampling.LANCZOS)
164
+ print(f"Resized back to original size: {enhanced_image.size}")
165
 
166
  return enhanced_image, status
167
 
168
  except Exception as e:
169
  print(f"ZeroIG enhancement error: {e}")
170
+ import traceback
171
+ traceback.print_exc()
172
  return self.simple_enhance(image), f"⚠️ ZeroIG failed, using simple enhancement: {str(e)}"
173
 
174
  def simple_enhance(self, image):
 
208
  # Store original for comparison
209
  original_image = image_to_base64(image)
210
 
211
+ # Enhance with your ZeroIG model
212
  enhanced_image, enhancement_status = zeroig.enhance_image(image)
213
 
214
  # Convert result to base64
 
218
  except Exception as e:
219
  error = f"Error processing image: {str(e)}"
220
  print(f"Error: {e}")
221
+ import traceback
222
+ traceback.print_exc()
223
 
224
  return render_template_string(HTML_TEMPLATE,
225
  original_image=original_image,