Create model.py
model.py
ADDED
@@ -0,0 +1,330 @@
import os
import requests
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import json
import math
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
import gradio as gr

# Configuration
class Config:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    image_size = 64
    batch_size = 32
    num_epochs = 50
    learning_rate = 1e-4
    timesteps = 1000
    text_embed_dim = 768
    num_images_options = [1, 4, 6]

    # URLs for COCO dataset download
    coco_images_url = "http://images.cocodataset.org/zips/train2017.zip"
    coco_annotations_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
    data_dir = "./coco_data"
    images_dir = os.path.join(data_dir, "train2017")
    # The dataset below reads image captions, so point at the captions file
    # (instances_train2017.json holds object annotations and has no 'caption' field).
    annotations_path = os.path.join(data_dir, "annotations/captions_train2017.json")

    def __init__(self):
        os.makedirs(self.data_dir, exist_ok=True)

config = Config()

# Download COCO dataset
def download_and_extract_coco():
    if os.path.exists(config.images_dir) and os.path.exists(config.annotations_path):
        print("COCO dataset already downloaded")
        return

    print("Downloading COCO dataset...")

    # Download images
    images_zip_path = os.path.join(config.data_dir, "train2017.zip")
    if not os.path.exists(images_zip_path):
        response = requests.get(config.coco_images_url, stream=True)
        with open(images_zip_path, "wb") as f:
            for chunk in tqdm(response.iter_content(chunk_size=1024)):
                if chunk:
                    f.write(chunk)

    # Download annotations
    annotations_zip_path = os.path.join(config.data_dir, "annotations_trainval2017.zip")
    if not os.path.exists(annotations_zip_path):
        response = requests.get(config.coco_annotations_url, stream=True)
        with open(annotations_zip_path, "wb") as f:
            for chunk in tqdm(response.iter_content(chunk_size=1024)):
                if chunk:
                    f.write(chunk)

    # Extract files (both archives are ZIP files, so use zipfile rather than tarfile)
    print("Extracting images...")
    with zipfile.ZipFile(images_zip_path, "r") as zf:
        zf.extractall(config.data_dir)

    print("Extracting annotations...")
    with zipfile.ZipFile(annotations_zip_path, "r") as zf:
        zf.extractall(config.data_dir)

    print("COCO dataset ready")

download_and_extract_coco()

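# Note: train2017.zip is large (roughly 18 GB of images); annotations_trainval2017.zip
# also contains captions_train2017.json, which is the file the dataset below reads.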
# Text model
class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model = BertModel.from_pretrained('bert-base-uncased')
        # Freeze BERT; only the diffusion model is trained
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, texts):
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=64)
        inputs = {k: v.to(config.device) for k, v in inputs.items()}
        outputs = self.model(**inputs)
        # Use the [CLS] token hidden state as a single 768-d embedding per caption
        return outputs.last_hidden_state[:, 0, :]

text_encoder = TextEncoder().to(config.device)

# Diffusion model
class ConditionalUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.down1 = DownBlock(64, 128)
        self.down2 = DownBlock(128, 256)

        self.text_proj = nn.Linear(config.text_embed_dim, 256)
        self.merge = nn.Linear(256 + 256, 256)

        self.up1 = UpBlock(256, 128)
        self.up2 = UpBlock(128, 64)
        self.final = nn.Conv2d(64, 3, kernel_size=3, padding=1)

    def forward(self, x, t, text_emb):
        x1 = F.relu(self.conv1(x))
        x2 = self.down1(x1)
        x3 = self.down2(x2)

        # Broadcast the projected text embedding over the spatial grid of the bottleneck
        text_emb = self.text_proj(text_emb)
        text_emb = text_emb.unsqueeze(-1).unsqueeze(-1)
        text_emb = text_emb.expand(-1, -1, x3.size(2), x3.size(3))

        # Fuse image and text features with a per-pixel linear layer
        x = torch.cat([x3, text_emb], dim=1)
        b, c, h, w = x.shape
        x = x.permute(0, 2, 3, 1).reshape(b * h * w, c)
        x = self.merge(x)
        x = x.reshape(b, h, w, 256).permute(0, 3, 1, 2)

        x = self.up1(x)
        x = self.up2(x)
        return self.final(x)

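# Note: the timestep argument `t` is accepted above but not used by the network.
# A common extension (not part of this code; shown only as a sketch) is a sinusoidal
# timestep embedding that could be fused into the bottleneck features:
#
#   def timestep_embedding(t, dim=256):
#       half = dim // 2
#       freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device) / half)
#       angles = t.float()[:, None] * freqs[None, :]
#       return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (batch, dim)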
class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

    def forward(self, x):
        return self.conv(x)

class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.up(x)
        return self.conv(x)

# Diffusion process
def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

# Precompute noise-schedule constants (the schedule must be defined before it is used here)
betas = linear_beta_schedule(config.timesteps).to(config.device)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

def forward_diffusion_sample(x_0, t, device=config.device):
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
    return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise

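# forward_diffusion_sample implements the closed-form DDPM forward process:
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,   eps ~ N(0, I)
# A minimal sanity check with a dummy batch (hypothetical, for illustration only):
#   x0 = torch.zeros(2, 3, config.image_size, config.image_size, device=config.device)
#   t = torch.randint(0, config.timesteps, (2,), device=config.device)
#   x_t, eps = forward_diffusion_sample(x0, t)  # both have the same shape as x0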
# COCO Dataset
class CocoDataset(Dataset):
    def __init__(self, root_dir, annotations_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        with open(annotations_file, 'r') as f:
            data = json.load(f)

        self.images = []
        self.captions = []

        # Group captions by image id, then keep the first caption per image
        image_id_to_captions = {}
        for ann in data['annotations']:
            if ann['image_id'] not in image_id_to_captions:
                image_id_to_captions[ann['image_id']] = []
            image_id_to_captions[ann['image_id']].append(ann['caption'])

        for img in data['images']:
            if img['id'] in image_id_to_captions:
                self.images.append(img)
                self.captions.append(image_id_to_captions[img['id']][0])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.images[idx]['file_name'])
        image = Image.open(img_path).convert('RGB')
        caption = self.captions[idx]

        if self.transform:
            image = self.transform(image)

        return image, caption

# Transformations
transform = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Model initialization
model = ConditionalUNet().to(config.device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

# Training
def train():
    dataset = CocoDataset(config.images_dir, config.annotations_path, transform)
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

    model.train()
    for epoch in range(config.num_epochs):
        for batch_idx, (images, captions) in enumerate(tqdm(dataloader)):
            images = images.to(config.device)

            # Get text embeddings
            text_emb = text_encoder(captions)

            # Sample random timesteps
            t = torch.randint(0, config.timesteps, (images.size(0),), device=config.device)

            # Forward diffusion
            x_noisy, noise = forward_diffusion_sample(images, t)

            # Predict noise
            pred_noise = model(x_noisy, t, text_emb)

            # Loss and backpropagation
            loss = F.mse_loss(pred_noise, noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

        # Save model
        torch.save(model.state_dict(), f"model_epoch_{epoch}.pth")

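# The MSE above is the simplified DDPM training objective:
#   L_simple = E_{x_0, t, eps} || eps - eps_theta(x_t, t, caption) ||^2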
# Generation
@torch.no_grad()
def generate(prompt, num_images=1):
    model.eval()
    num_images = int(num_images)

    text_emb = text_encoder([prompt] * num_images)
    x = torch.randn((num_images, 3, config.image_size, config.image_size)).to(config.device)

    # Reverse diffusion: denoise step by step from pure noise
    for t in reversed(range(config.timesteps)):
        t_tensor = torch.full((num_images,), t, device=config.device)
        pred_noise = model(x, t_tensor, text_emb)
        alpha_t = alphas[t].view(1, 1, 1, 1)
        alpha_cumprod_t = alphas_cumprod[t].view(1, 1, 1, 1)
        beta_t = betas[t].view(1, 1, 1, 1)

        if t > 0:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)

        x = (1 / torch.sqrt(alpha_t)) * (
            x - ((1 - alpha_t) / torch.sqrt(1 - alpha_cumprod_t)) * pred_noise
        ) + torch.sqrt(beta_t) * noise

    # Map from [-1, 1] back to [0, 1] for display
    x = torch.clamp(x, -1, 1)
    x = (x + 1) / 2

    images = []
    for img in x:
        img = transforms.ToPILImage()(img.cpu())
        images.append(img)

    return images

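# Example of standalone use (hypothetical prompt, after training):
#   imgs = generate("a dog playing on a beach", num_images=4)
#   imgs[0].save("sample_0.png")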
# GUI
def generate_and_display(prompt, num_images):
    images = generate(prompt, num_images)

    fig, axes = plt.subplots(1, len(images), figsize=(5 * len(images), 5))
    if len(images) == 1:
        axes.imshow(images[0])
        axes.axis('off')
    else:
        for ax, img in zip(axes, images):
            ax.imshow(img)
            ax.axis('off')
    plt.tight_layout()
    return fig

with gr.Blocks() as demo:
    gr.Markdown("## GPUDiff-V1: a diffusion-powered image generator!")
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Enter image description...")
        num_select = gr.Dropdown(choices=config.num_images_options, value=1, label="Number of images")
    generate_btn = gr.Button("Generate")
    output = gr.Plot()

    generate_btn.click(
        fn=generate_and_display,
        inputs=[prompt_input, num_select],
        outputs=output
    )

if __name__ == "__main__":
    train()
    demo.launch()
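# Running `python model.py` first trains for config.num_epochs epochs and then launches
# the Gradio demo locally (Gradio's default address is http://127.0.0.1:7860).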