|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import gradio as gr |
|
from PIL import Image, ImageFilter, ImageOps, ImageChops |
|
import torchvision.transforms as transforms |
|
import os |
|
import pathlib |
|
|
|
|
|
# Folder that receives a PNG copy of every generated result. It is created at
# import time so the first save never fails on a missing directory.
output_dir = "outputs"
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
|
|
|
|
|
# Filter catalogue: filter name -> emoji shown in front of it in the UI.
# Insertion order is the order the options appear in the Filter dropdown.
FILTERS = {
    # Pass-through and basic PIL filters
    "Standard": "📄",
    "Invert": "⚫⚪",
    "Blur": "🌫️",
    "Sharpen": "🔪",
    "Contour": "🗺️",
    "Detail": "🔍",
    "EdgeEnhance": "📏",
    "EdgeEnhanceMore": "📐",
    "Emboss": "🏞️",
    "FindEdges": "🕵️",
    "Smooth": "🌊",
    "SmoothMore": "💧",
    "Solarize": "☀️",
    # Tonal adjustments (Posterize digit = bits kept)
    "Posterize1": "🖼️1",
    "Posterize2": "🖼️2",
    "Posterize3": "🖼️3",
    "Posterize4": "🖼️4",
    "Equalize": "⚖️",
    "AutoContrast": "🔧",
    # Line-weight morphology (digit = strength)
    "Thick1": "💪1",
    "Thick2": "💪2",
    "Thick3": "💪3",
    "Thin1": "🏃1",
    "Thin2": "🏃2",
    "Thin3": "🏃3",
    # Colourise variants with a white background
    "RedOnWhite": "🔴",
    "OrangeOnWhite": "🟠",
    "YellowOnWhite": "🟡",
    "GreenOnWhite": "🟢",
    "BlueOnWhite": "🔵",
    "PurpleOnWhite": "🟣",
    "PinkOnWhite": "🌸",
    "CyanOnWhite": "🩵",
    "MagentaOnWhite": "🟪",
    "BrownOnWhite": "🤎",
    "GrayOnWhite": "🩶",
    # Colourise variants with a black background
    "WhiteOnBlack": "⚪",
    "RedOnBlack": "🔴⚫",
    "OrangeOnBlack": "🟠⚫",
    "YellowOnBlack": "🟡⚫",
    "GreenOnBlack": "🟢⚫",
    "BlueOnBlack": "🔵⚫",
    "PurpleOnBlack": "🟣⚫",
    "PinkOnBlack": "🌸⚫",
    "CyanOnBlack": "🩵⚫",
    "MagentaOnBlack": "🟪⚫",
    "BrownOnBlack": "🤎⚫",
    "GrayOnBlack": "🩶⚫",
    # Blend modes combining the line art with the original photo
    "Multiply": "✖️",
    "Screen": "🖥️",
    "Overlay": "🔲",
    "Add": "➕",
    "Subtract": "➖",
    "Difference": "≠",
    "Darker": "🌑",
    "Lighter": "🌕",
    "SoftLight": "💡",
    "HardLight": "🔦",
    # Miscellaneous
    "Binary": "🌓",
    "Noise": "❄️",
}
|
|
|
|
|
# Normalisation layer factory shared by ResidualBlock and Generator.
norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions added back onto the input.

    The layer order inside ``conv_block`` fixes the checkpoint keys
    (``conv_block.1.*`` and ``conv_block.5.*``), so it must not be reordered.
    """

    def __init__(self, in_features):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        )

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual


class Generator(nn.Module):
    """Image-to-image network that produces the line drawing.

    Layout: 7x7 stem conv (64 ch) -> two stride-2 convs (64->128->256) ->
    ``n_residual_blocks`` ResidualBlocks at 256 ch -> two stride-2 transposed
    convs (256->128->64) -> 7x7 conv to ``output_nc`` channels, optionally
    followed by a sigmoid. For inputs whose height and width are multiples of
    4, the output has the same spatial size as the input.

    The attribute names ``model0``..``model4`` and the layer order inside each
    Sequential define the ``state_dict`` keys that saved checkpoints are loaded
    against, so they must stay exactly as they are.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super().__init__()
        width = 64

        # Stem: full-resolution 7x7 convolution.
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, width, 7),
            norm_layer(width),
            nn.ReLU(inplace=True),
        )

        # Encoder: two stride-2 stages, each doubling the channel count.
        down_layers = []
        for _ in range(2):
            down_layers.extend([
                nn.Conv2d(width, width * 2, 3, stride=2, padding=1),
                norm_layer(width * 2),
                nn.ReLU(inplace=True),
            ])
            width *= 2
        self.model1 = nn.Sequential(*down_layers)

        # Bottleneck: residual blocks at the lowest resolution.
        self.model2 = nn.Sequential(*(ResidualBlock(width) for _ in range(n_residual_blocks)))

        # Decoder: two stride-2 transposed convs, each halving the channel count.
        up_layers = []
        for _ in range(2):
            up_layers.extend([
                nn.ConvTranspose2d(width, width // 2, 3, stride=2, padding=1, output_padding=1),
                norm_layer(width // 2),
                nn.ReLU(inplace=True),
            ])
            width //= 2
        self.model3 = nn.Sequential(*up_layers)

        # Head: project back to output_nc channels (width is 64 again here).
        head_layers = [nn.ReflectionPad2d(3), nn.Conv2d(width, output_nc, 7)]
        if sigmoid:
            head_layers.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head_layers)

    def forward(self, x, cond=None):
        # ``cond`` is accepted but unused.
        for stage in (self.model0, self.model1, self.model2, self.model3, self.model4):
            x = stage(x)
        return x
|
|
|
|
|
def _load_generator(weights_path):
    """Build a Generator(3, 1, 3), load CPU weights from *weights_path*, and switch it to eval mode."""
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    net = Generator(3, 1, 3)
    state = torch.load(weights_path, map_location=torch.device('cpu'))
    net.load_state_dict(state)
    net.eval()
    return net


# model1 serves "Complex Lines" and model2 serves "Simple Lines" (see predict).
# A missing checkpoint leaves both as None; predict then reports the error.
try:
    model1 = _load_generator('model.pth')
    model2 = _load_generator('model2.pth')
except FileNotFoundError:
    print("⚠️ Warning: Model files 'model.pth' or 'model2.pth' not found. The application will not run correctly.")
    model1, model2 = None, None
|
|
|
|
|
# ---- Lookup tables for apply_filter (built once at import, not per call) ----

# PIL filters applied to the line drawing in the mode it arrives in.
_FILTERS_ON_INPUT = {
    "Blur": ImageFilter.GaussianBlur(radius=3),
    "Sharpen": ImageFilter.SHARPEN,
    "Detail": ImageFilter.DETAIL,
    "Smooth": ImageFilter.SMOOTH,
    "SmoothMore": ImageFilter.SMOOTH_MORE,
}

# PIL filters applied to the greyscale ('L') copy of the line drawing.
_FILTERS_ON_GRAY = {
    "Contour": ImageFilter.CONTOUR,
    "EdgeEnhance": ImageFilter.EDGE_ENHANCE,
    "EdgeEnhanceMore": ImageFilter.EDGE_ENHANCE_MORE,
    "Emboss": ImageFilter.EMBOSS,
    "FindEdges": ImageFilter.FIND_EDGES,
}

# Whole-image ImageOps functions applied to the greyscale copy.
_OPS_ON_GRAY = {
    "Invert": ImageOps.invert,
    "Solarize": ImageOps.solarize,
    "Equalize": ImageOps.equalize,
    "AutoContrast": ImageOps.autocontrast,
}

# Kernel size for Thick*/Thin*, keyed by the trailing level digit; any other
# suffix falls back to 7.
_MORPH_KERNEL_SIZES = {"1": 3, "2": 5}

_COLOR_NAMES = ("Red", "Orange", "Yellow", "Green", "Blue", "Purple",
                "Pink", "Cyan", "Magenta", "Brown", "Gray")
# "<Color>OnWhite": black pixels become <color>, white pixels stay white.
_COLORS_ON_WHITE = {f"{c}OnWhite": c.lower() for c in _COLOR_NAMES}
# "<Color>OnBlack": black pixels become <color>, white pixels become black.
_COLORS_ON_BLACK = {f"{c}OnBlack": c.lower() for c in ("White",) + _COLOR_NAMES}

# ImageChops blend modes, called as op(original_img, line drawing as RGB).
_BLEND_OPS = {
    "Multiply": ImageChops.multiply,
    "Screen": ImageChops.screen,
    "Overlay": ImageChops.overlay,
    "Add": ImageChops.add,
    "Subtract": ImageChops.subtract,
    "Difference": ImageChops.difference,
    "Darker": ImageChops.darker,
    "Lighter": ImageChops.lighter,
    "SoftLight": ImageChops.soft_light,
    "HardLight": ImageChops.hard_light,
}

# The "Noise" filter shifts each grey level by a uniform integer in
# [-_NOISE_AMPLITUDE, +_NOISE_AMPLITUDE].
_NOISE_AMPLITUDE = 20


def apply_filter(line_img, filter_name, original_img):
    """Apply the named creative filter to a generated line drawing.

    Args:
        line_img: PIL image of the line drawing. ``predict`` passes the model
            output converted with ``ToPILImage`` and resized to the photo size.
        filter_name: A key of ``FILTERS`` (e.g. ``"Emboss"``). ``Posterize*``,
            ``Thick*`` and ``Thin*`` are matched by prefix. Names matching no
            rule return ``line_img`` unchanged.
        original_img: The RGB source photo. It must be the same size as
            ``line_img`` and is only used by the blend modes (``Multiply``, ...).

    Returns:
        A PIL image whose mode depends on the filter ('1' for Binary, 'RGB'
        for colourise/blend filters), or ``line_img`` itself for
        ``"Standard"`` and unknown names.
    """
    if filter_name == "Standard":
        return line_img

    if filter_name in _FILTERS_ON_INPUT:
        return line_img.filter(_FILTERS_ON_INPUT[filter_name])

    line_img_l = line_img.convert('L')

    if filter_name in _FILTERS_ON_GRAY:
        return line_img_l.filter(_FILTERS_ON_GRAY[filter_name])

    if filter_name in _OPS_ON_GRAY:
        return _OPS_ON_GRAY[filter_name](line_img_l)

    if filter_name.startswith("Posterize"):
        # The trailing digit is the number of bits kept per channel.
        return ImageOps.posterize(line_img_l, int(filter_name[-1]))

    if filter_name == "Binary":
        # convert('1') uses Pillow's default Floyd-Steinberg dithering.
        return line_img_l.convert('1')

    if filter_name.startswith(("Thick", "Thin")):
        size = _MORPH_KERNEL_SIZES.get(filter_name[-1], 7)
        # MinFilter grows dark regions (thicker dark lines); MaxFilter grows light ones.
        morph = ImageFilter.MinFilter if filter_name.startswith("Thick") else ImageFilter.MaxFilter
        return line_img_l.filter(morph(size))

    if filter_name in _COLORS_ON_WHITE:
        return ImageOps.colorize(line_img_l, black=_COLORS_ON_WHITE[filter_name], white="white")

    if filter_name in _COLORS_ON_BLACK:
        return ImageOps.colorize(line_img_l, black=_COLORS_ON_BLACK[filter_name], white="black")

    if filter_name in _BLEND_OPS:
        return _BLEND_OPS[filter_name](original_img, line_img.convert('RGB'))

    if filter_name == "Noise":
        img_array = np.array(line_img_l)
        # randint's upper bound is exclusive; +1 makes the noise symmetric
        # (the original [-20, 20) range biased the image slightly darker).
        noise = np.random.randint(-_NOISE_AMPLITUDE, _NOISE_AMPLITUDE + 1, img_array.shape, dtype='int16')
        noisy_array = np.clip(img_array.astype('int16') + noise, 0, 255).astype('uint8')
        return Image.fromarray(noisy_array)

    return line_img
|
|
|
|
|
def predict(input_img_path, line_style, filter_choice):
    """Turn an uploaded photo into filtered line art and save a PNG copy.

    Args:
        input_img_path: Path of the uploaded image (the Gradio input uses
            ``type='filepath'``). It is ``None`` when the user submits without
            uploading.
        line_style: ``'Simple Lines'`` selects ``model2``; any other value
            selects ``model1``.
        filter_choice: Dropdown label of the form ``"<emoji> <FilterName>"``.

    Returns:
        The filtered PIL image at the original photo's resolution. It is also
        written to ``output_dir`` as ``<upload stem>_<FilterName>.png``.

    Raises:
        gr.Error: If the models failed to load, or no image or filter was given.
    """
    if model1 is None or model2 is None:
        raise gr.Error("Models are not loaded. Please check for 'model.pth' and 'model2.pth'.")
    if input_img_path is None:
        raise gr.Error("Please upload an image first.")
    if not filter_choice:
        raise gr.Error("Please choose a filter.")

    # Labels are "<emoji> <name>" and names contain no spaces, so the name is the
    # last token. This also accepts a bare name without an emoji prefix.
    filter_name = filter_choice.rsplit(" ", 1)[-1]

    # The context manager releases the file handle (Pillow may otherwise keep it
    # open, e.g. for multi-frame formats). exif_transpose applies the camera
    # orientation tag so phone photos are not processed sideways; it is a
    # no-op when there is no such tag.
    with Image.open(input_img_path) as img:
        original_img = ImageOps.exif_transpose(img).convert('RGB')

    # Shorter side -> 256 px, then scale pixel values to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize(256, transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    input_tensor = transform(original_img).unsqueeze(0)

    model = model2 if line_style == 'Simple Lines' else model1
    with torch.no_grad():
        output = model(input_tensor)

    # The output is (1, 1, H, W). Drop only the batch dim: squeeze() would also
    # drop a spatial dim of size 1. ToPILImage turns (1, H, W) into an 'L' image.
    line_drawing_low_res = transforms.ToPILImage()(output[0].cpu().clamp(0, 1))
    line_drawing_full_res = line_drawing_low_res.resize(original_img.size, Image.Resampling.BICUBIC)

    final_image = apply_filter(line_drawing_full_res, filter_name, original_img)

    base_name = pathlib.Path(input_img_path).stem
    output_filename = f"{base_name}_{filter_name}.png"
    output_filepath = os.path.join(output_dir, output_filename)
    final_image.save(output_filepath)

    return final_image
|
|
|
|
|
# Static UI text for the Gradio interface.
title = "🖌️ Image to Line Art with Creative Filters"
description = (
    "Upload an image, choose a line style (Complex or Simple), and select a filter "
    "from the dropdown to transform your picture. Results are saved in the "
    "'outputs' folder."
)

# Dropdown labels "<emoji> <FilterName>", in FILTERS insertion order.
filter_choices = [" ".join((emoji, name)) for name, emoji in FILTERS.items()]
|
|
|
|
|
# Clickable examples: bundled sample photos '01.jpeg' .. '10.jpeg' that exist
# in the working directory, each paired with 'Simple Lines' and a demo filter.
examples = []
example_images = [f"{i:02d}.jpeg" for i in range(1, 11)]

# Demo filters cycled across the examples. The labels are built from FILTERS
# in the same "<emoji> <name>" format as the dropdown choices, instead of being
# retyped by hand. A retyped emoji (variation selectors are easy to miss) would
# produce an example value that matches no dropdown option. A renamed filter
# now fails loudly here with a KeyError.
demo_filters = [
    f"{FILTERS[name]} {name}"
    for name in ("Contour", "BlueOnBlack", "Multiply", "Emboss", "Sharpen", "Noise")
]

for i, img_file in enumerate(example_images):
    if os.path.exists(img_file):
        # The index counts missing files too, so the filter pairing is stable per filename.
        chosen_filter = demo_filters[i % len(demo_filters)]
        examples.append([img_file, 'Simple Lines', chosen_filter])

if not examples:
    print("⚠️ Warning: No example images ('01.jpeg' to '10.jpeg') found. Examples will be empty.")
|
|
|
|
|
# Gradio UI: image + line style + filter in, filtered line art out.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Upload Image", type="filepath"),
        gr.Radio(choices=["Complex Lines", "Simple Lines"], value="Simple Lines", label="Line Style"),
        gr.Dropdown(choices=filter_choices, value=filter_choices[0], label="Filter"),
    ],
    outputs=gr.Image(label="Filtered Line Art", type="pil"),
    title=title,
    description=description,
    examples=examples,
    # NOTE(review): newer Gradio releases rename this option to `flagging_mode`;
    # confirm against the pinned Gradio version.
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch()