import numpy as np
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image, ImageFilter, ImageOps, ImageChops
import torchvision.transforms as transforms
import os
import pathlib
# --- ⚙️ Configuration ---
output_dir = "outputs"
os.makedirs(output_dir, exist_ok=True)
# --- 🎨 Filters ---
FILTERS = {
"Standard": "📄", "Invert": "⚫⚪", "Blur": "🌫️", "Sharpen": "🔪", "Contour": "🗺️",
"Detail": "🔍", "EdgeEnhance": "📏", "EdgeEnhanceMore": "📐", "Emboss": "🏞️",
"FindEdges": "🕵️", "Smooth": "🌊", "SmoothMore": "💧", "Solarize": "☀️",
"Posterize1": "🖼️1", "Posterize2": "🖼️2", "Posterize3": "🖼️3", "Posterize4": "🖼️4",
"Equalize": "⚖️", "AutoContrast": "🔧", "Thick1": "💪1", "Thick2": "💪2", "Thick3": "💪3",
"Thin1": "🏃1", "Thin2": "🏃2", "Thin3": "🏃3", "RedOnWhite": "🔴", "OrangeOnWhite": "🟠",
"YellowOnWhite": "🟡", "GreenOnWhite": "🟢", "BlueOnWhite": "🔵", "PurpleOnWhite": "🟣",
"PinkOnWhite": "🌸", "CyanOnWhite": "🩵", "MagentaOnWhite": "🟪", "BrownOnWhite": "🤎",
"GrayOnWhite": "🩶", "WhiteOnBlack": "⚪", "RedOnBlack": "🔴⚫", "OrangeOnBlack": "🟠⚫",
"YellowOnBlack": "🟡⚫", "GreenOnBlack": "🟢⚫", "BlueOnBlack": "🔵⚫", "PurpleOnBlack": "🟣⚫",
"PinkOnBlack": "🌸⚫", "CyanOnBlack": "🩵⚫", "MagentaOnBlack": "🟪⚫", "BrownOnBlack": "🤎⚫",
"GrayOnBlack": "🩶⚫", "Multiply": "✖️", "Screen": "🖥️", "Overlay": "🔲", "Add": "➕",
"Subtract": "➖", "Difference": "≠", "Darker": "🌑", "Lighter": "🌕", "SoftLight": "💡",
"HardLight": "🔦", "Binary": "🌓", "Noise": "❄️"
}
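# Each key above has a matching branch in apply_filter() below; the emoji is
# purely cosmetic — it is prepended to the dropdown label (e.g. "🌫️ Blur")
# and stripped back off in predict() via split(" ", 1).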
# --- 🧠 Neural Network Model (Unchanged) ---
norm_layer = nn.InstanceNorm2d
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        conv_block = [
            nn.ReflectionPad2d(1), nn.Conv2d(in_features, in_features, 3), norm_layer(in_features), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(in_features, in_features, 3), norm_layer(in_features)
        ]
        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)
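# Each block learns a residual: the input is added back onto the conv output,
# so the stack only has to model the change to its features.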
class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(Generator, self).__init__()
        # Initial 7x7 convolution block
        model0 = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7), norm_layer(64), nn.ReLU(inplace=True)]
        self.model0 = nn.Sequential(*model0)
        # Two stride-2 downsampling blocks: 64 -> 128 -> 256 channels
        model1, in_features, out_features = [], 64, 128
        for _ in range(2):
            model1 += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), norm_layer(out_features), nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features * 2
        self.model1 = nn.Sequential(*model1)
        # Residual bottleneck
        model2 = [ResidualBlock(in_features) for _ in range(n_residual_blocks)]
        self.model2 = nn.Sequential(*model2)
        # Two stride-2 upsampling blocks: 256 -> 128 -> 64 channels
        model3, out_features = [], in_features // 2
        for _ in range(2):
            model3 += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), norm_layer(out_features), nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features // 2
        self.model3 = nn.Sequential(*model3)
        # Final 7x7 convolution down to output_nc channels
        model4 = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            model4 += [nn.Sigmoid()]
        self.model4 = nn.Sequential(*model4)

    def forward(self, x, cond=None):
        return self.model4(self.model3(self.model2(self.model1(self.model0(x)))))
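# Shape sketch for a 3x256x256 input (illustrative; follows from the strides
# above): model0 keeps 256x256 at 64 channels, model1 downsamples twice to
# 256x64x64, model2's residual blocks preserve that shape, model3 upsamples
# back to 64x256x256, and model4 projects to output_nc channels at 256x256.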
# --- 🔧 Model Loading ---
try:
    model1 = Generator(3, 1, 3)
    model1.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
    model1.eval()
    model2 = Generator(3, 1, 3)
    model2.load_state_dict(torch.load('model2.pth', map_location=torch.device('cpu')))
    model2.eval()
except FileNotFoundError:
    print("⚠️ Warning: Model files 'model.pth' or 'model2.pth' not found. The application will not run correctly.")
    model1, model2 = None, None
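# Both checkpoints share the Generator architecture above; predict() below
# routes 'Complex Lines' to model1 (model.pth) and 'Simple Lines' to model2
# (model2.pth).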
# --- ✨ Filter Application Logic (Unchanged) ---
def apply_filter(line_img, filter_name, original_img):
    if filter_name == "Standard": return line_img
    line_img_l = line_img.convert('L')
    # Stock PIL filters
    if filter_name == "Invert": return ImageOps.invert(line_img_l)
    if filter_name == "Blur": return line_img.filter(ImageFilter.GaussianBlur(radius=3))
    if filter_name == "Sharpen": return line_img.filter(ImageFilter.SHARPEN)
    if filter_name == "Contour": return line_img_l.filter(ImageFilter.CONTOUR)
    if filter_name == "Detail": return line_img.filter(ImageFilter.DETAIL)
    if filter_name == "EdgeEnhance": return line_img_l.filter(ImageFilter.EDGE_ENHANCE)
    if filter_name == "EdgeEnhanceMore": return line_img_l.filter(ImageFilter.EDGE_ENHANCE_MORE)
    if filter_name == "Emboss": return line_img_l.filter(ImageFilter.EMBOSS)
    if filter_name == "FindEdges": return line_img_l.filter(ImageFilter.FIND_EDGES)
    if filter_name == "Smooth": return line_img.filter(ImageFilter.SMOOTH)
    if filter_name == "SmoothMore": return line_img.filter(ImageFilter.SMOOTH_MORE)
    # Tonal adjustments
    if filter_name == "Solarize": return ImageOps.solarize(line_img_l)
    if filter_name.startswith("Posterize"): return ImageOps.posterize(line_img_l, int(filter_name[-1]))
    if filter_name == "Equalize": return ImageOps.equalize(line_img_l)
    if filter_name == "AutoContrast": return ImageOps.autocontrast(line_img_l)
    if filter_name == "Binary": return line_img_l.convert('1')
    # Line weight: MinFilter spreads dark pixels (thicker lines), MaxFilter shrinks them (thinner)
    if filter_name.startswith("Thick"): return line_img_l.filter(ImageFilter.MinFilter({'1': 3, '2': 5, '3': 7}[filter_name[-1]]))
    if filter_name.startswith("Thin"): return line_img_l.filter(ImageFilter.MaxFilter({'1': 3, '2': 5, '3': 7}[filter_name[-1]]))
    # Colorized lines on a white or black background
    colors_on_white = {"RedOnWhite": "red", "OrangeOnWhite": "orange", "YellowOnWhite": "yellow", "GreenOnWhite": "green", "BlueOnWhite": "blue", "PurpleOnWhite": "purple", "PinkOnWhite": "pink", "CyanOnWhite": "cyan", "MagentaOnWhite": "magenta", "BrownOnWhite": "brown", "GrayOnWhite": "gray"}
    if filter_name in colors_on_white: return ImageOps.colorize(line_img_l, black=colors_on_white[filter_name], white="white")
    colors_on_black = {"WhiteOnBlack": "white", "RedOnBlack": "red", "OrangeOnBlack": "orange", "YellowOnBlack": "yellow", "GreenOnBlack": "green", "BlueOnBlack": "blue", "PurpleOnBlack": "purple", "PinkOnBlack": "pink", "CyanOnBlack": "cyan", "MagentaOnBlack": "magenta", "BrownOnBlack": "brown", "GrayOnBlack": "gray"}
    if filter_name in colors_on_black: return ImageOps.colorize(line_img_l, black=colors_on_black[filter_name], white="black")
    # Blend modes combine the line art with the original photo
    line_img_rgb = line_img.convert('RGB')
    blend_ops = {"Multiply": ImageChops.multiply, "Screen": ImageChops.screen, "Overlay": ImageChops.overlay, "Add": ImageChops.add, "Subtract": ImageChops.subtract, "Difference": ImageChops.difference, "Darker": ImageChops.darker, "Lighter": ImageChops.lighter, "SoftLight": ImageChops.soft_light, "HardLight": ImageChops.hard_light}
    if filter_name in blend_ops: return blend_ops[filter_name](original_img, line_img_rgb)
    # Additive uniform noise in [-20, 20)
    if filter_name == "Noise":
        img_array = np.array(line_img_l)
        noise = np.random.randint(-20, 20, img_array.shape, dtype='int16')
        noisy_array = np.clip(img_array.astype('int16') + noise, 0, 255).astype('uint8')
        return Image.fromarray(noisy_array)
    return line_img
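# Usage sketch (hypothetical image paths):
#   photo = Image.open("photo.jpeg").convert('RGB')
#   tinted = apply_filter(line_drawing, "RedOnWhite", photo)  # red lines on white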
# --- 🖼️ Main Processing Function ---
def predict(input_img_path, line_style, filter_choice):
    if not model1 or not model2:
        raise gr.Error("Models are not loaded. Please check for 'model.pth' and 'model2.pth'.")
    # Dropdown labels look like "🌫️ Blur"; strip the emoji prefix back off
    filter_name = filter_choice.split(" ", 1)[1]
    original_img = Image.open(input_img_path).convert('RGB')
    # Shorter edge to 256 px, then map pixel values from [0, 1] to [-1, 1]
    transform = transforms.Compose([
        transforms.Resize(256, transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    input_tensor = transform(original_img).unsqueeze(0)
    with torch.no_grad():
        output = model2(input_tensor) if line_style == 'Simple Lines' else model1(input_tensor)
    line_drawing_low_res = transforms.ToPILImage()(output.squeeze().cpu().clamp(0, 1))
    # Upscale the 256-px drawing back to the original resolution before filtering
    line_drawing_full_res = line_drawing_low_res.resize(original_img.size, Image.Resampling.BICUBIC)
    final_image = apply_filter(line_drawing_full_res, filter_name, original_img)
    # --- 💾 Save the output image ---
    base_name = pathlib.Path(input_img_path).stem
    output_filename = f"{base_name}_{filter_name}.png"
    output_filepath = os.path.join(output_dir, output_filename)
    final_image.save(output_filepath)
    return final_image
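# Outputs are named <input stem>_<filter>.png, e.g. outputs/01_Contour.png for
# 01.jpeg with the Contour filter; re-running the same image/filter pair
# overwrites the previous file.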
# --- 🚀 Gradio UI Setup ---
title = "🖌️ Image to Line Art with Creative Filters"
description = "Upload an image, choose a line style (Complex or Simple), and select a filter from the dropdown to transform your picture. Results are saved in the 'outputs' folder."
filter_choices = [f"{emoji} {name}" for name, emoji in FILTERS.items()]
# --- ✅ New Curated Examples Section ---
examples = []
example_images = [f"{i:02d}.jpeg" for i in range(1, 11)]
# A selection of 6 interesting filters to demonstrate
demo_filters = ["🗺️ Contour", "🔵⚫ BlueOnBlack", "✖️ Multiply", "🏞️ Emboss", "🔪 Sharpen", "❄️ Noise"]
# Create one example for each of the 10 image files, cycling through the demo filters
for i, img_file in enumerate(example_images):
    if os.path.exists(img_file):
        # Use modulo to cycle through the 6 demo filters for the 10 images
        chosen_filter = demo_filters[i % len(demo_filters)]
        examples.append([img_file, 'Simple Lines', chosen_filter])
if not examples:
    print("⚠️ Warning: No example images ('01.jpeg' to '10.jpeg') found. Examples will be empty.")
# Reverted to the simpler and more stable gr.Interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type='filepath', label="Upload Image"),
        gr.Radio(['Complex Lines', 'Simple Lines'], label='Line Style', value='Simple Lines'),
        gr.Dropdown(filter_choices, label="Filter", value=filter_choices[0])
    ],
    outputs=gr.Image(type="pil", label="Filtered Line Art"),
    title=title,
    description=description,
    examples=examples,
    allow_flagging='never'
)
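# Note: allow_flagging='never' is the older Gradio spelling; newer Gradio
# releases rename this parameter to flagging_mode.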
if __name__ == "__main__":
    iface.launch()