"""Gradio app that turns a photo into line art and applies a creative filter.

A small convolutional ``Generator`` produces a single-channel line drawing
from an RGB image; ``apply_filter`` then post-processes that drawing with one
of the Pillow-based effects listed in ``FILTERS``.
"""

import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image, ImageChops, ImageFilter, ImageOps

# Filter name -> emoji prefix shown in the dropdown.
# NOTE(review): several of these values look mis-encoded (UTF-8 emoji decoded
# with the wrong codec). They are reproduced unchanged here because they are
# runtime strings; re-enter the intended emoji if the source encoding is fixed.
FILTERS = {
    "Standard": "πŸ“„",
    "Invert": "⚫βšͺ",
    "Blur": "🌫️",
    "Sharpen": "πŸ”ͺ",
    "Contour": "πŸ—ΊοΈ",
    "Detail": "πŸ”",
    "EdgeEnhance": "πŸ“",
    "EdgeEnhanceMore": "πŸ“",
    "Emboss": "🏞️",
    "FindEdges": "πŸ•΅οΈ",
    "Smooth": "🌊",
    "SmoothMore": "πŸ’§",
    "Solarize": "β˜€οΈ",
    "Posterize1": "πŸ–ΌοΈ1",
    "Posterize2": "πŸ–ΌοΈ2",
    "Posterize3": "πŸ–ΌοΈ3",
    "Posterize4": "πŸ–ΌοΈ4",
    "Equalize": "βš–οΈ",
    "AutoContrast": "πŸ”§",
    "Thick1": "πŸ’ͺ1",
    "Thick2": "πŸ’ͺ2",
    "Thick3": "πŸ’ͺ3",
    "Thin1": "πŸƒ1",
    "Thin2": "πŸƒ2",
    "Thin3": "πŸƒ3",
    "RedOnWhite": "πŸ”΄",
    "OrangeOnWhite": "🟠",
    "YellowOnWhite": "🟑",
    "GreenOnWhite": "🟒",
    "BlueOnWhite": "πŸ”΅",
    "PurpleOnWhite": "🟣",
    "PinkOnWhite": "🌸",
    "CyanOnWhite": "🩡",
    "MagentaOnWhite": "πŸŸͺ",
    "BrownOnWhite": "🀎",
    "GrayOnWhite": "🩢",
    "WhiteOnBlack": "βšͺ",
    "RedOnBlack": "πŸ”΄βš«",
    "OrangeOnBlack": "🟠⚫",
    "YellowOnBlack": "🟑⚫",
    "GreenOnBlack": "🟒⚫",
    "BlueOnBlack": "πŸ”΅βš«",
    "PurpleOnBlack": "🟣⚫",
    "PinkOnBlack": "🌸⚫",
    "CyanOnBlack": "🩡⚫",
    "MagentaOnBlack": "πŸŸͺ⚫",
    "BrownOnBlack": "🀎⚫",
    "GrayOnBlack": "🩢⚫",
    "Multiply": "βœ–οΈ",
    "Screen": "πŸ–₯️",
    "Overlay": "πŸ”²",
    "Add": "βž•",
    "Subtract": "βž–",
    "Difference": "β‰ ",
    "Darker": "πŸŒ‘",
    "Lighter": "πŸŒ•",
    "SoftLight": "πŸ’‘",
    "HardLight": "πŸ”¦",
    "Binary": "πŸŒ“",
    "Noise": "❄️",
}

# Normalisation layer used throughout the generator.
norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions with a skip connection."""

    def __init__(self, in_features):
        """Build the block; input and output both have ``in_features`` channels."""
        super().__init__()
        # Layer order is kept exactly as-is so saved state-dict keys still match.
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        return x + self.conv_block(x)


class Generator(nn.Module):
    """Generator model for creating line drawings.

    Structure: a 7x7 stem convolution, two stride-2 downsampling convolutions,
    ``n_residual_blocks`` residual blocks, two stride-2 transposed-convolution
    upsampling stages, and a 7x7 output convolution optionally followed by a
    sigmoid.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        """Assemble the network.

        Args:
            input_nc: Number of input channels.
            output_nc: Number of output channels.
            n_residual_blocks: How many ``ResidualBlock`` modules to stack.
            sigmoid: Whether to append ``nn.Sigmoid`` to the output layer.
        """
        super().__init__()

        # Attribute names model0..model4 and the layer order inside each
        # Sequential determine the state-dict keys; do not rename or reorder.

        # Initial convolution block.
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True),
        )

        # Downsampling: channels double at each of the two stages.
        down_layers = []
        in_features = 64
        for _ in range(2):
            out_features = in_features * 2
            down_layers += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                norm_layer(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
        self.model1 = nn.Sequential(*down_layers)

        # Residual blocks at the bottleneck width.
        self.model2 = nn.Sequential(
            *[ResidualBlock(in_features) for _ in range(n_residual_blocks)]
        )

        # Upsampling: channels halve at each of the two stages.
        up_layers = []
        for _ in range(2):
            out_features = in_features // 2
            up_layers += [
                nn.ConvTranspose2d(
                    in_features, out_features, 3,
                    stride=2, padding=1, output_padding=1,
                ),
                norm_layer(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
        self.model3 = nn.Sequential(*up_layers)

        # Output layer.
        head_layers = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            head_layers.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head_layers)

    def forward(self, x, cond=None):
        """Run ``x`` through all five stages in order.

        ``cond`` is accepted for interface compatibility but is not used.
        """
        out = self.model0(x)
        out = self.model1(out)
        out = self.model2(out)
        out = self.model3(out)
        return self.model4(out)


def _load_generator(weights_path):
    """Create a 3-in/1-out generator with 3 residual blocks and load weights on CPU."""
    model = Generator(3, 1, 3)
    model.load_state_dict(
        torch.load(weights_path, map_location=torch.device('cpu'))
    )
    model.eval()
    return model


# Load the models.
# Make sure you have 'model.pth' and 'model2.pth' in the same directory.
try:
    model1 = _load_generator('model.pth')
    model2 = _load_generator('model2.pth')
except FileNotFoundError:
    print("Warning: Model files 'model.pth' or 'model2.pth' not found. The application will not run correctly.")
    model1, model2 = None, None


# --- Dispatch tables used by apply_filter -----------------------------------

# Kernel filters applied to the line drawing in its incoming mode.
_SOURCE_KERNEL_FILTERS = {
    "Blur": ImageFilter.GaussianBlur(radius=3),
    "Sharpen": ImageFilter.SHARPEN,
    "Detail": ImageFilter.DETAIL,
    "Smooth": ImageFilter.SMOOTH,
    "SmoothMore": ImageFilter.SMOOTH_MORE,
}

# Kernel and rank filters applied to the grayscale ('L') version.
# MinFilter selects the darkest pixel in its window (the "Thick" variants),
# MaxFilter the brightest (the "Thin" variants).
_GRAY_KERNEL_FILTERS = {
    "Contour": ImageFilter.CONTOUR,
    "EdgeEnhance": ImageFilter.EDGE_ENHANCE,
    "EdgeEnhanceMore": ImageFilter.EDGE_ENHANCE_MORE,
    "Emboss": ImageFilter.EMBOSS,
    "FindEdges": ImageFilter.FIND_EDGES,
    "Thick1": ImageFilter.MinFilter(3),
    "Thick2": ImageFilter.MinFilter(5),
    "Thick3": ImageFilter.MinFilter(7),
    "Thin1": ImageFilter.MaxFilter(3),
    "Thin2": ImageFilter.MaxFilter(5),
    "Thin3": ImageFilter.MaxFilter(7),
}

# Single-argument tonal operations on the grayscale version.
_GRAY_TONAL_OPS = {
    "Invert": ImageOps.invert,
    "Solarize": ImageOps.solarize,
    "Equalize": ImageOps.equalize,
    "AutoContrast": ImageOps.autocontrast,
}

# Posterize variants -> number of bits kept per channel.
_POSTERIZE_BITS = {
    "Posterize1": 1,
    "Posterize2": 2,
    "Posterize3": 3,
    "Posterize4": 4,
}

# Colorize variants -> colour that black pixels of the drawing are mapped to.
_COLORS_ON_WHITE = {
    "RedOnWhite": "red",
    "OrangeOnWhite": "orange",
    "YellowOnWhite": "yellow",
    "GreenOnWhite": "green",
    "BlueOnWhite": "blue",
    "PurpleOnWhite": "purple",
    "PinkOnWhite": "pink",
    "CyanOnWhite": "cyan",
    "MagentaOnWhite": "magenta",
    "BrownOnWhite": "brown",
    "GrayOnWhite": "gray",
}
_COLORS_ON_BLACK = {
    "WhiteOnBlack": "white",
    "RedOnBlack": "red",
    "OrangeOnBlack": "orange",
    "YellowOnBlack": "yellow",
    "GreenOnBlack": "green",
    "BlueOnBlack": "blue",
    "PurpleOnBlack": "purple",
    "PinkOnBlack": "pink",
    "CyanOnBlack": "cyan",
    "MagentaOnBlack": "magenta",
    "BrownOnBlack": "brown",
    "GrayOnBlack": "gray",
}

# Blend modes combining the original photo with the line drawing.
_BLEND_OPS = {
    "Multiply": ImageChops.multiply,
    "Screen": ImageChops.screen,
    "Overlay": ImageChops.overlay,
    "Add": ImageChops.add,
    "Subtract": ImageChops.subtract,
    "Difference": ImageChops.difference,
    "Darker": ImageChops.darker,
    "Lighter": ImageChops.lighter,
    "SoftLight": ImageChops.soft_light,
    "HardLight": ImageChops.hard_light,
}


def apply_filter(line_img, filter_name, original_img):
    """Apply the filter called ``filter_name`` to a line drawing.

    Args:
        line_img: PIL image holding the line drawing.
        filter_name: One of the keys of ``FILTERS`` (without the emoji prefix).
        original_img: PIL image of the source photo; only used by the blend
            modes (Multiply, Screen, ...).

    Returns:
        A PIL image. ``line_img`` itself is returned for "Standard" and for
        any name that is not recognised.
    """
    if filter_name == "Standard":
        return line_img

    # Most operations work on the grayscale version of the drawing.
    gray = line_img.convert('L')

    if filter_name in _SOURCE_KERNEL_FILTERS:
        return line_img.filter(_SOURCE_KERNEL_FILTERS[filter_name])
    if filter_name in _GRAY_KERNEL_FILTERS:
        return gray.filter(_GRAY_KERNEL_FILTERS[filter_name])
    if filter_name in _GRAY_TONAL_OPS:
        return _GRAY_TONAL_OPS[filter_name](gray)
    if filter_name in _POSTERIZE_BITS:
        return ImageOps.posterize(gray, _POSTERIZE_BITS[filter_name])
    if filter_name == "Binary":
        return gray.convert('1')

    if filter_name in _COLORS_ON_WHITE:
        return ImageOps.colorize(
            gray, black=_COLORS_ON_WHITE[filter_name], white="white"
        )
    if filter_name in _COLORS_ON_BLACK:
        return ImageOps.colorize(
            gray, black=_COLORS_ON_BLACK[filter_name], white="black"
        )

    if filter_name in _BLEND_OPS:
        # ImageChops requires both operands to share mode and size, so align
        # them defensively instead of failing with "images do not match".
        base = original_img if original_img.mode == 'RGB' else original_img.convert('RGB')
        overlay = line_img.convert('RGB')
        if overlay.size != base.size:
            overlay = overlay.resize(base.size, Image.Resampling.BICUBIC)
        return _BLEND_OPS[filter_name](base, overlay)

    if filter_name == "Noise":
        pixels = np.array(gray)
        # randint's upper bound is exclusive: offsets are drawn from [-20, 19].
        noise = np.random.randint(-20, 20, pixels.shape, dtype='int16')
        # Add in int16 so the sum cannot wrap around before clipping to 0..255.
        noisy = np.clip(pixels.astype('int16') + noise, 0, 255).astype('uint8')
        return Image.fromarray(noisy)

    return line_img  # Default fallback


def _choice_label(filter_name):
    """Return the dropdown label ("<emoji> <name>") for a ``FILTERS`` key."""
    return f"{FILTERS[filter_name]} {filter_name}"


def _filter_name_from_choice(filter_choice):
    """Extract the filter name from a dropdown label.

    Filter names never contain spaces, so the name is whatever follows the
    *last* space. Splitting from the right keeps this correct even when the
    emoji prefix itself contains a space, and a bare name (no prefix) is
    accepted as-is. A missing choice falls back to "Standard".
    """
    if not filter_choice:
        return "Standard"
    return filter_choice.rsplit(" ", 1)[-1]


# Preprocessing for the generator: resize (target size 256, bicubic), convert
# to a tensor, and normalise each channel with mean 0.5 and std 0.5.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256, transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def predict(input_img_path, line_style, filter_choice):
    """Convert an image file to filtered line art.

    Args:
        input_img_path: Path of the uploaded image file.
        line_style: 'Simple Lines' selects ``model2``; any other value
            (i.e. 'Complex Lines') selects ``model1``.
        filter_choice: Dropdown label of the form "<emoji> <name>".

    Returns:
        The filtered line drawing as a PIL image at the original image size.

    Raises:
        gr.Error: If the models failed to load or no image was provided.
    """
    if model1 is None or model2 is None:
        raise gr.Error("Models are not loaded. Please check for 'model.pth' and 'model2.pth'.")
    if not input_img_path:
        raise gr.Error("Please upload an image first.")

    filter_name = _filter_name_from_choice(filter_choice)

    # convert() forces the pixel data to load, so the file can be closed
    # as soon as the RGB copy exists.
    with Image.open(input_img_path) as opened:
        original_img = opened.convert('RGB')
    original_size = original_img.size

    input_tensor = _PREPROCESS(original_img).unsqueeze(0)

    generator = model2 if line_style == 'Simple Lines' else model1
    with torch.no_grad():
        output = generator(input_tensor)

    # Drop only the batch dimension; a blanket squeeze() would also remove a
    # spatial dimension of size 1.
    line_drawing_low_res = transforms.ToPILImage()(
        output.squeeze(0).cpu().clamp(0, 1)
    )

    # Resize the line drawing back to the original image size *before*
    # applying filters, so blend modes see matching dimensions.
    line_drawing_full_res = line_drawing_low_res.resize(
        original_size, Image.Resampling.BICUBIC
    )

    return apply_filter(line_drawing_full_res, filter_name, original_img)


# Set up and launch the Gradio interface.
title = "πŸ–ŒοΈ Image to Line Art with Creative Filters"
description = "Upload an image, choose a line style (Complex or Simple), and select a filter from the dropdown to transform your picture into unique line art."

# Dropdown choices with emoji prefixes.
filter_choices = [_choice_label(name) for name in FILTERS]

# Build examples from the first image (in sorted order, for determinism)
# found in the current directory.
examples = []
image_dir = '.'
image_files = sorted(
    f for f in os.listdir(image_dir)
    if f.lower().endswith(('.png', '.jpg', '.jpeg'))
)
if image_files:
    example_image = image_files[0]
    # Labels are derived from FILTERS so they always match the dropdown.
    examples.append([example_image, 'Simple Lines', _choice_label('Contour')])
    examples.append([example_image, 'Complex Lines', _choice_label('BlueOnBlack')])
    examples.append([example_image, 'Simple Lines', _choice_label('Multiply')])

iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type='filepath', label="Upload Image"),
        gr.Radio(['Complex Lines', 'Simple Lines'], label='Line Style', value='Simple Lines'),
        gr.Dropdown(filter_choices, label="Filter", value=filter_choices[0]),
    ],
    outputs=gr.Image(type="pil", label="Filtered Line Art"),
    title=title,
    description=description,
    examples=examples,
    allow_flagging='never',
)

if __name__ == "__main__":
    iface.launch()