import os
import requests
import zipfile
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
import gradio as gr


# Configuration
class Config:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    image_size = 64
    batch_size = 32
    num_epochs = 50
    learning_rate = 1e-4
    timesteps = 1000
    text_embed_dim = 768
    num_images_options = [1, 4, 6]

    # URLs for COCO dataset download
    coco_images_url = "http://images.cocodataset.org/zips/train2017.zip"
    coco_annotations_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
    data_dir = "./coco_data"
    images_dir = os.path.join(data_dir, "train2017")
    # The dataset class below reads caption strings, so point at the captions file
    # (instances_train2017.json has no 'caption' field)
    annotations_path = os.path.join(data_dir, "annotations/captions_train2017.json")

    def __init__(self):
        os.makedirs(self.data_dir, exist_ok=True)


config = Config()


# Download COCO dataset
def download_and_extract_coco():
    if os.path.exists(config.images_dir) and os.path.exists(config.annotations_path):
        print("COCO dataset already downloaded")
        return

    print("Downloading COCO dataset...")

    # Download images
    images_zip_path = os.path.join(config.data_dir, "train2017.zip")
    if not os.path.exists(images_zip_path):
        response = requests.get(config.coco_images_url, stream=True)
        with open(images_zip_path, "wb") as f:
            for chunk in tqdm(response.iter_content(chunk_size=1024)):
                if chunk:
                    f.write(chunk)

    # Download annotations
    annotations_zip_path = os.path.join(config.data_dir, "annotations_trainval2017.zip")
    if not os.path.exists(annotations_zip_path):
        response = requests.get(config.coco_annotations_url, stream=True)
        with open(annotations_zip_path, "wb") as f:
            for chunk in tqdm(response.iter_content(chunk_size=1024)):
                if chunk:
                    f.write(chunk)

    # Extract files (the downloads are ZIP archives, so use zipfile, not tarfile)
    print("Extracting images...")
    with zipfile.ZipFile(images_zip_path, "r") as zf:
        zf.extractall(config.data_dir)
    print("Extracting annotations...")
    with zipfile.ZipFile(annotations_zip_path, "r") as zf:
        zf.extractall(config.data_dir)
    print("COCO dataset ready")


download_and_extract_coco()


# Text model
class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model = BertModel.from_pretrained('bert-base-uncased')
        # Keep BERT frozen; only the diffusion model is trained
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, texts):
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True,
                                truncation=True, max_length=64)
        inputs = {k: v.to(config.device) for k, v in inputs.items()}
        outputs = self.model(**inputs)
        # Use the [CLS] token as a sentence-level embedding
        return outputs.last_hidden_state[:, 0, :]


text_encoder = TextEncoder().to(config.device)


# Diffusion model
class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

    def forward(self, x):
        return self.conv(x)


class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.up(x)
        return self.conv(x)


class ConditionalUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.down1 = DownBlock(64, 128)
        self.down2 = DownBlock(128, 256)
        self.text_proj = nn.Linear(config.text_embed_dim, 256)
        self.merge = nn.Linear(256 + 256, 256)
        self.up1 = UpBlock(256, 128)
        self.up2 = UpBlock(128, 64)
        self.final = nn.Conv2d(64, 3, kernel_size=3, padding=1)

    def forward(self, x, t, text_emb):
        # NOTE: t is accepted for API symmetry with the sampler but is not embedded;
        # only the text embedding conditions this simplified network.
        x1 = F.relu(self.conv1(x))
        x2 = self.down1(x1)
        x3 = self.down2(x2)

        # Broadcast the projected text embedding over the spatial grid
        text_emb = self.text_proj(text_emb)
        text_emb = text_emb.unsqueeze(-1).unsqueeze(-1)
        text_emb = text_emb.expand(-1, -1, x3.size(2), x3.size(3))
        x = torch.cat([x3, text_emb], dim=1)

        # Mix image and text channels per spatial position with a linear layer
        b, c, h, w = x.shape
        x = x.permute(0, 2, 3, 1).reshape(b * h * w, c)
        x = self.merge(x)
        x = x.reshape(b, h, w, 256).permute(0, 3, 1, 2)

        x = self.up1(x)
        x = self.up2(x)
        return self.final(x)
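
# Optional sanity check -- a minimal sketch, not part of the original pipeline.
# The function name and the batch size of 2 are illustrative; call it manually to
# verify that the UNet returns a tensor with the same shape as its input.
def smoke_test_unet():
    unet = ConditionalUNet().to(config.device)
    x = torch.randn(2, 3, config.image_size, config.image_size, device=config.device)
    t = torch.randint(0, config.timesteps, (2,), device=config.device)
    text_emb = torch.randn(2, config.text_embed_dim, device=config.device)
    out = unet(x, t, text_emb)
    assert out.shape == x.shape, f"unexpected output shape: {tuple(out.shape)}"
    print("UNet output shape:", tuple(out.shape))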
# Diffusion process
def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)


# The schedule must be defined before these constants are built
betas = linear_beta_schedule(config.timesteps).to(config.device)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)


def forward_diffusion_sample(x_0, t, device=config.device):
    # q(x_t | x_0): add noise to a clean image at timestep t
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
    return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise


# COCO Dataset
class CocoDataset(Dataset):
    def __init__(self, root_dir, annotations_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        with open(annotations_file, 'r') as f:
            data = json.load(f)

        self.images = []
        self.captions = []
        image_id_to_captions = {}
        for ann in data['annotations']:
            if ann['image_id'] not in image_id_to_captions:
                image_id_to_captions[ann['image_id']] = []
            image_id_to_captions[ann['image_id']].append(ann['caption'])

        # Keep one caption per image
        for img in data['images']:
            if img['id'] in image_id_to_captions:
                self.images.append(img)
                self.captions.append(image_id_to_captions[img['id']][0])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.images[idx]['file_name'])
        image = Image.open(img_path).convert('RGB')
        caption = self.captions[idx]
        if self.transform:
            image = self.transform(image)
        return image, caption


# Transformations
transform = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Model initialization
model = ConditionalUNet().to(config.device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)


# Training
def train():
    dataset = CocoDataset(config.images_dir, config.annotations_path, transform)
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

    for epoch in range(config.num_epochs):
        for batch_idx, (images, captions) in enumerate(tqdm(dataloader)):
            images = images.to(config.device)

            # Get text embeddings (the DataLoader yields captions as a tuple of strings)
            text_emb = text_encoder(list(captions))

            # Sample random timesteps
            t = torch.randint(0, config.timesteps, (images.size(0),), device=config.device)

            # Forward diffusion
            x_noisy, noise = forward_diffusion_sample(images, t)

            # Predict noise
            pred_noise = model(x_noisy, t, text_emb)

            # Loss and backpropagation
            loss = F.mse_loss(pred_noise, noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

        # Save model
        torch.save(model.state_dict(), f"model_epoch_{epoch}.pth")
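
# Optional helper -- a minimal sketch, not part of the original script. Restores
# weights saved by train() so generate() can be used in a fresh session; the helper
# name is ours, and the checkpoint path is whatever train() produced.
def load_checkpoint(path):
    state = torch.load(path, map_location=config.device)
    model.load_state_dict(state)
    model.eval()
    return model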
# Generation
@torch.no_grad()
def generate(prompt, num_images=1):
    model.eval()
    num_images = int(num_images)
    text_emb = text_encoder([prompt] * num_images)
    x = torch.randn((num_images, 3, config.image_size, config.image_size)).to(config.device)

    # Reverse diffusion: denoise step by step from pure Gaussian noise
    for t in reversed(range(config.timesteps)):
        t_tensor = torch.full((num_images,), t, device=config.device)
        pred_noise = model(x, t_tensor, text_emb)

        alpha_t = alphas[t].view(1, 1, 1, 1)
        alpha_cumprod_t = alphas_cumprod[t].view(1, 1, 1, 1)
        beta_t = betas[t].view(1, 1, 1, 1)

        if t > 0:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)

        x = (1 / torch.sqrt(alpha_t)) * (
            x - ((1 - alpha_t) / torch.sqrt(1 - alpha_cumprod_t)) * pred_noise
        ) + torch.sqrt(beta_t) * noise

    # Map from [-1, 1] back to [0, 1] for display
    x = torch.clamp(x, -1, 1)
    x = (x + 1) / 2

    images = []
    for img in x:
        img = transforms.ToPILImage()(img.cpu())
        images.append(img)
    return images


# GUI
def generate_and_display(prompt, num_images):
    images = generate(prompt, num_images)
    fig, axes = plt.subplots(1, len(images), figsize=(5 * len(images), 5))
    if len(images) == 1:
        axes.imshow(images[0])
        axes.axis('off')
    else:
        for ax, img in zip(axes, images):
            ax.imshow(img)
            ax.axis('off')
    plt.tight_layout()
    return fig


with gr.Blocks() as demo:
    gr.Markdown("## GPUDiff-V1: a diffusion-powered image generator")
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Enter image description...")
        num_select = gr.Dropdown(choices=config.num_images_options, value=1, label="Number of images")
    generate_btn = gr.Button("Generate")
    output = gr.Plot()
    generate_btn.click(
        fn=generate_and_display,
        inputs=[prompt_input, num_select],
        outputs=output
    )


if __name__ == "__main__":
    train()
    demo.launch()
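
# Usage sketch (assumptions: a checkpoint such as "model_epoch_49.pth" exists from a
# previous run of train(); load_checkpoint is the optional helper defined above):
#
#   load_checkpoint("model_epoch_49.pth")
#   imgs = generate("a dog playing on a beach", num_images=4)
#   demo.launch()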