SkillForge45 committed
Commit 672357f · verified · 1 Parent(s): 7b74a37

Create model.py

Files changed (1): model.py +330 -0
model.py ADDED
@@ -0,0 +1,330 @@
import os
import requests
import zipfile  # the COCO downloads are ZIP archives, so zipfile (not tarfile) is needed
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import json
import math
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
import gradio as gr

# Configuration
class Config:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    image_size = 64
    batch_size = 32
    num_epochs = 50
    learning_rate = 1e-4
    timesteps = 1000
    text_embed_dim = 768
    num_images_options = [1, 4, 6]

    # URLs for COCO dataset download
    coco_images_url = "http://images.cocodataset.org/zips/train2017.zip"
    coco_annotations_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
    data_dir = "./coco_data"
    images_dir = os.path.join(data_dir, "train2017")
    # captions_train2017.json holds the image captions read by the dataset below;
    # instances_train2017.json contains detection annotations without a 'caption' field
    annotations_path = os.path.join(data_dir, "annotations/captions_train2017.json")

    def __init__(self):
        os.makedirs(self.data_dir, exist_ok=True)

config = Config()

# Download COCO dataset
def download_and_extract_coco():
    if os.path.exists(config.images_dir) and os.path.exists(config.annotations_path):
        print("COCO dataset already downloaded")
        return

    print("Downloading COCO dataset...")

    # Download images
    images_zip_path = os.path.join(config.data_dir, "train2017.zip")
    if not os.path.exists(images_zip_path):
        response = requests.get(config.coco_images_url, stream=True)
        with open(images_zip_path, "wb") as f:
            for chunk in tqdm(response.iter_content(chunk_size=1024)):
                if chunk:
                    f.write(chunk)

    # Download annotations
    annotations_zip_path = os.path.join(config.data_dir, "annotations_trainval2017.zip")
    if not os.path.exists(annotations_zip_path):
        response = requests.get(config.coco_annotations_url, stream=True)
        with open(annotations_zip_path, "wb") as f:
            for chunk in tqdm(response.iter_content(chunk_size=1024)):
                if chunk:
                    f.write(chunk)

    # Extract files (the archives are ZIPs, so zipfile is used rather than tarfile)
    print("Extracting images...")
    with zipfile.ZipFile(images_zip_path, "r") as zf:
        zf.extractall(config.data_dir)

    print("Extracting annotations...")
    with zipfile.ZipFile(annotations_zip_path, "r") as zf:
        zf.extractall(config.data_dir)

    print("COCO dataset ready")

download_and_extract_coco()

# Text model
class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model = BertModel.from_pretrained('bert-base-uncased')
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, texts):
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=64)
        inputs = {k: v.to(config.device) for k, v in inputs.items()}
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:, 0, :]

text_encoder = TextEncoder().to(config.device)

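# Illustrative note: the encoder above returns the hidden state of the [CLS] token,
# i.e. one 768-dimensional vector per caption, with BERT kept frozen. A minimal check,
# left as a comment so it does not run at import time (captions are hypothetical):
#
#   emb = text_encoder(["a dog on a beach", "a red car"])
#   emb.shape  # torch.Size([2, 768]) == (batch, config.text_embed_dim)
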
# Diffusion model
class ConditionalUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.down1 = DownBlock(64, 128)
        self.down2 = DownBlock(128, 256)

        self.text_proj = nn.Linear(config.text_embed_dim, 256)
        self.merge = nn.Linear(256 + 256, 256)

        self.up1 = UpBlock(256, 128)
        self.up2 = UpBlock(128, 64)
        self.final = nn.Conv2d(64, 3, kernel_size=3, padding=1)

    def forward(self, x, t, text_emb):
        x1 = F.relu(self.conv1(x))
        x2 = self.down1(x1)
        x3 = self.down2(x2)

        # Broadcast the projected text embedding over the bottleneck feature map
        text_emb = self.text_proj(text_emb)
        text_emb = text_emb.unsqueeze(-1).unsqueeze(-1)
        text_emb = text_emb.expand(-1, -1, x3.size(2), x3.size(3))

        # Concatenate along channels and fuse per pixel with a linear layer
        x = torch.cat([x3, text_emb], dim=1)
        b, c, h, w = x.shape
        x = x.permute(0, 2, 3, 1).reshape(b * h * w, c)
        x = self.merge(x)
        x = x.reshape(b, h, w, 256).permute(0, 3, 1, 2)

        x = self.up1(x)
        x = self.up2(x)
        return self.final(x)

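# Illustrative note on the conditioning above: the projected text embedding is spread over
# the bottleneck feature map (a 256-channel 16x16 grid for image_size = 64 after two
# DownBlocks), concatenated along channels, and fused per pixel by the linear `merge` layer.
# The timestep argument `t` is accepted by forward() but never embedded; only the noisy
# image and the text embedding condition the prediction, unlike a standard DDPM U-Net
# that also injects a sinusoidal or learned timestep embedding.
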
class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

    def forward(self, x):
        return self.conv(x)

class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.up(x)
        return self.conv(x)

# Diffusion process
def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

# The schedule must be defined before it is used to build the noise constants below
betas = linear_beta_schedule(config.timesteps).to(config.device)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

def forward_diffusion_sample(x_0, t, device=config.device):
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
    return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise

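# Illustrative note: forward_diffusion_sample is the closed-form DDPM forward process
#     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,   eps ~ N(0, I),
# where alpha_bar_t = alphas_cumprod[t]. It returns both the noised image x_t and the
# sampled noise eps, which serves as the regression target during training.
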
# COCO Dataset
class CocoDataset(Dataset):
    def __init__(self, root_dir, annotations_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        with open(annotations_file, 'r') as f:
            data = json.load(f)

        self.images = []
        self.captions = []

        image_id_to_captions = {}
        for ann in data['annotations']:
            if ann['image_id'] not in image_id_to_captions:
                image_id_to_captions[ann['image_id']] = []
            image_id_to_captions[ann['image_id']].append(ann['caption'])

        for img in data['images']:
            if img['id'] in image_id_to_captions:
                self.images.append(img)
                self.captions.append(image_id_to_captions[img['id']][0])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.images[idx]['file_name'])
        image = Image.open(img_path).convert('RGB')
        caption = self.captions[idx]

        if self.transform:
            image = self.transform(image)

        return image, caption

# Transformations
transform = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Model initialization
model = ConditionalUNet().to(config.device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

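# Illustrative shape check (a minimal sketch, kept as a comment so it does not run at
# import time). Normalize maps pixels to [-1, 1], matching the clamp/rescale in generate(),
# and the U-Net preserves the input resolution:
#
#   x = torch.randn(2, 3, config.image_size, config.image_size, device=config.device)
#   emb = text_encoder(["a cat", "a dog"])                      # hypothetical captions
#   t = torch.zeros(2, dtype=torch.long, device=config.device)
#   model(x, t, emb).shape  # torch.Size([2, 3, 64, 64])
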
# Training
def train():
    dataset = CocoDataset(config.images_dir, config.annotations_path, transform)
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

    for epoch in range(config.num_epochs):
        for batch_idx, (images, captions) in enumerate(tqdm(dataloader)):
            images = images.to(config.device)

            # Get text embeddings
            text_emb = text_encoder(captions)

            # Sample random timesteps
            t = torch.randint(0, config.timesteps, (images.size(0),), device=config.device)

            # Forward diffusion
            x_noisy, noise = forward_diffusion_sample(images, t)

            # Predict noise
            pred_noise = model(x_noisy, t, text_emb)

            # Loss and backpropagation
            loss = F.mse_loss(pred_noise, noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

        # Save model
        torch.save(model.state_dict(), f"model_epoch_{epoch}.pth")

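# Illustrative note: the loop above minimizes the simplified DDPM objective
#     L_simple = E_{x_0, t, eps} || eps - eps_theta(x_t, t, text_emb) ||^2,
# i.e. the network is trained to predict the noise added by forward_diffusion_sample at a
# uniformly sampled timestep. One checkpoint is written per epoch as model_epoch_{epoch}.pth.
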
# Generation
@torch.no_grad()
def generate(prompt, num_images=1):
    model.eval()
    num_images = int(num_images)

    text_emb = text_encoder([prompt] * num_images)
    x = torch.randn((num_images, 3, config.image_size, config.image_size)).to(config.device)

    for t in reversed(range(config.timesteps)):
        t_tensor = torch.full((num_images,), t, device=config.device)
        pred_noise = model(x, t_tensor, text_emb)
        alpha_t = alphas[t].view(1, 1, 1, 1)
        alpha_cumprod_t = alphas_cumprod[t].view(1, 1, 1, 1)
        beta_t = betas[t].view(1, 1, 1, 1)

        if t > 0:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)

        x = (1 / torch.sqrt(alpha_t)) * (
            x - ((1 - alpha_t) / torch.sqrt(1 - alpha_cumprod_t)) * pred_noise
        ) + torch.sqrt(beta_t) * noise

    x = torch.clamp(x, -1, 1)
    x = (x + 1) / 2

    images = []
    for img in x:
        img = transforms.ToPILImage()(img.cpu())
        images.append(img)

    return images

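# Illustrative note: the update inside the loop above is DDPM ancestral sampling with the
# epsilon parameterization,
#     x_{t-1} = 1/sqrt(alpha_t) * ( x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta )
#               + sqrt(beta_t) * z,   z ~ N(0, I) for t > 0 and z = 0 at the final step,
# after which the samples are mapped from [-1, 1] back to [0, 1] for display.
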
# GUI
def generate_and_display(prompt, num_images):
    images = generate(prompt, num_images)

    fig, axes = plt.subplots(1, len(images), figsize=(5 * len(images), 5))
    if len(images) == 1:
        axes.imshow(images[0])
        axes.axis('off')
    else:
        for ax, img in zip(axes, images):
            ax.imshow(img)
            ax.axis('off')
    plt.tight_layout()
    return fig

with gr.Blocks() as demo:
    gr.Markdown("## GPUDiff-V1: a powerful diffusion image generator!")
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Enter image description...")
        num_select = gr.Dropdown(choices=config.num_images_options, value=1, label="Number of images")
    generate_btn = gr.Button("Generate")
    output = gr.Plot()

    generate_btn.click(
        fn=generate_and_display,
        inputs=[prompt_input, num_select],
        outputs=output
    )

if __name__ == "__main__":
    train()
    demo.launch()
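# Usage note (a sketch, assuming the save pattern above): running `python model.py` downloads
# and extracts COCO at import time, trains for config.num_epochs epochs, and only then launches
# the Gradio demo. To serve an existing checkpoint instead, train() could be skipped and the
# weights loaded first, e.g. with the hypothetical last-epoch file name:
#
#   model.load_state_dict(torch.load("model_epoch_49.pth", map_location=config.device))
#   demo.launch()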