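"""Gradio app: convert an image to line art and restyle it.

Two pretrained generators ('model.pth' for complex lines, 'model2.pth' for
simple lines) turn a photo into a line drawing, and one of dozens of creative
filters is applied on top. Results are saved to the 'outputs' folder."""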
import os
import pathlib
import random

import numpy as np
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image, ImageFilter, ImageOps, ImageChops
import torchvision.transforms as transforms
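
# Dependencies: torch, torchvision, gradio, numpy, and a recent Pillow
# (Image.Resampling and ImageChops.soft_light/hard_light/overlay require
# newer Pillow releases).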

output_dir = "outputs"
os.makedirs(output_dir, exist_ok=True)

IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tiff"]

FILTERS = {
    "Standard": "📄", "Invert": "⚫⚪", "Blur": "🌫️", "Sharpen": "🔪", "Contour": "🗺️",
    "Detail": "🔍", "EdgeEnhance": "📏", "EdgeEnhanceMore": "📐", "Emboss": "🏞️",
    "FindEdges": "🕵️", "Smooth": "🌊", "SmoothMore": "💧", "Solarize": "☀️",
    "Posterize1": "🖼️1", "Posterize2": "🖼️2", "Posterize3": "🖼️3", "Posterize4": "🖼️4",
    "Equalize": "⚖️", "AutoContrast": "🔧", "Thick1": "💪1", "Thick2": "💪2", "Thick3": "💪3",
    "Thin1": "🏃1", "Thin2": "🏃2", "Thin3": "🏃3", "RedOnWhite": "🔴", "OrangeOnWhite": "🟠",
    "YellowOnWhite": "🟡", "GreenOnWhite": "🟢", "BlueOnWhite": "🔵", "PurpleOnWhite": "🟣",
    "PinkOnWhite": "🌸", "CyanOnWhite": "🩵", "MagentaOnWhite": "🟪", "BrownOnWhite": "🤎",
    "GrayOnWhite": "🩶", "WhiteOnBlack": "⚪", "RedOnBlack": "🔴⚫", "OrangeOnBlack": "🟠⚫",
    "YellowOnBlack": "🟡⚫", "GreenOnBlack": "🟢⚫", "BlueOnBlack": "🔵⚫", "PurpleOnBlack": "🟣⚫",
    "PinkOnBlack": "🌸⚫", "CyanOnBlack": "🩵⚫", "MagentaOnBlack": "🟪⚫", "BrownOnBlack": "🤎⚫",
    "GrayOnBlack": "🩶⚫", "Multiply": "✖️", "Screen": "🖥️", "Overlay": "🔲", "Add": "➕",
    "Subtract": "➖", "Difference": "≠", "Darker": "🌑", "Lighter": "🌕", "SoftLight": "💡",
    "HardLight": "🔦", "Binary": "🌓", "Noise": "❄️"
}
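
# Button labels are built as f"{emoji} {name}" and parsed back with
# split(" ", 1) in process_image, so filter names must not contain spaces.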


# InstanceNorm is used throughout the generator, as is common for
# image-to-image translation networks.
norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        conv_block = [
            nn.ReflectionPad2d(1), nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        ]
        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        # Standard residual connection.
        return x + self.conv_block(x)


class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(Generator, self).__init__()

        # Initial 7x7 convolution block.
        model0 = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7), norm_layer(64), nn.ReLU(inplace=True)]
        self.model0 = nn.Sequential(*model0)

        # Two stride-2 downsampling blocks: 64 -> 128 -> 256 channels.
        model1, in_features, out_features = [], 64, 128
        for _ in range(2):
            model1 += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), norm_layer(out_features), nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features * 2
        self.model1 = nn.Sequential(*model1)

        # Residual blocks at the bottleneck.
        model2 = [ResidualBlock(in_features) for _ in range(n_residual_blocks)]
        self.model2 = nn.Sequential(*model2)

        # Two stride-2 upsampling blocks back down to 64 channels.
        model3, out_features = [], in_features // 2
        for _ in range(2):
            model3 += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), norm_layer(out_features), nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features // 2
        self.model3 = nn.Sequential(*model3)

        # Output 7x7 convolution, optionally squashed to [0, 1].
        model4 = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            model4 += [nn.Sigmoid()]
        self.model4 = nn.Sequential(*model4)

    def forward(self, x, cond=None):
        return self.model4(self.model3(self.model2(self.model1(self.model0(x)))))
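
# Shape sanity check (a minimal sketch): two stride-2 downsamples are undone
# by two stride-2 upsamples, so spatial size is preserved, e.g.
#   Generator(3, 1, 3)(torch.rand(1, 3, 256, 256)).shape == (1, 1, 256, 256)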


# Load the two line-drawing generators on CPU. 'model.pth' produces complex
# lines and 'model2.pth' simple lines; both map an RGB image to a
# single-channel line map.
try:
    model1 = Generator(3, 1, 3)
    model1.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
    model1.eval()
    model2 = Generator(3, 1, 3)
    model2.load_state_dict(torch.load('model2.pth', map_location=torch.device('cpu')))
    model2.eval()
except FileNotFoundError:
    print("⚠️ Warning: Model files 'model.pth' or 'model2.pth' not found. The application will not run correctly.")
    model1, model2 = None, None


def apply_filter(line_img, filter_name, original_img):
    """Apply the named creative filter to a line drawing.

    `line_img` and `original_img` must be the same size; the blend filters
    combine the two, all others operate on the line drawing alone.
    """
    if filter_name == "Standard": return line_img
    line_img_l = line_img.convert('L')
    if filter_name == "Invert": return ImageOps.invert(line_img_l)
    if filter_name == "Blur": return line_img.filter(ImageFilter.GaussianBlur(radius=3))
    if filter_name == "Sharpen": return line_img.filter(ImageFilter.SHARPEN)
    if filter_name == "Contour": return line_img_l.filter(ImageFilter.CONTOUR)
    if filter_name == "Detail": return line_img.filter(ImageFilter.DETAIL)
    if filter_name == "EdgeEnhance": return line_img_l.filter(ImageFilter.EDGE_ENHANCE)
    if filter_name == "EdgeEnhanceMore": return line_img_l.filter(ImageFilter.EDGE_ENHANCE_MORE)
    if filter_name == "Emboss": return line_img_l.filter(ImageFilter.EMBOSS)
    if filter_name == "FindEdges": return line_img_l.filter(ImageFilter.FIND_EDGES)
    if filter_name == "Smooth": return line_img.filter(ImageFilter.SMOOTH)
    if filter_name == "SmoothMore": return line_img.filter(ImageFilter.SMOOTH_MORE)
    if filter_name == "Solarize": return ImageOps.solarize(line_img_l)
    if filter_name.startswith("Posterize"): return ImageOps.posterize(line_img_l, int(filter_name[-1]))
    if filter_name == "Equalize": return ImageOps.equalize(line_img_l)
    if filter_name == "AutoContrast": return ImageOps.autocontrast(line_img_l)
    if filter_name == "Binary": return line_img_l.convert('1')
    # Lines are dark on a light background, so MinFilter dilates (thickens)
    # them and MaxFilter erodes (thins) them; suffix 1/2/3 picks kernel size.
    if filter_name.startswith("Thick"): return line_img_l.filter(ImageFilter.MinFilter(3 if filter_name[-1] == '1' else (5 if filter_name[-1] == '2' else 7)))
    if filter_name.startswith("Thin"): return line_img_l.filter(ImageFilter.MaxFilter(3 if filter_name[-1] == '1' else (5 if filter_name[-1] == '2' else 7)))
    # ImageOps.colorize maps black pixels (the lines) to the named color.
    colors_on_white = {
        "RedOnWhite": "red", "OrangeOnWhite": "orange", "YellowOnWhite": "yellow",
        "GreenOnWhite": "green", "BlueOnWhite": "blue", "PurpleOnWhite": "purple",
        "PinkOnWhite": "pink", "CyanOnWhite": "cyan", "MagentaOnWhite": "magenta",
        "BrownOnWhite": "brown", "GrayOnWhite": "gray",
    }
    if filter_name in colors_on_white: return ImageOps.colorize(line_img_l, black=colors_on_white[filter_name], white="white")
    colors_on_black = {
        "WhiteOnBlack": "white", "RedOnBlack": "red", "OrangeOnBlack": "orange",
        "YellowOnBlack": "yellow", "GreenOnBlack": "green", "BlueOnBlack": "blue",
        "PurpleOnBlack": "purple", "PinkOnBlack": "pink", "CyanOnBlack": "cyan",
        "MagentaOnBlack": "magenta", "BrownOnBlack": "brown", "GrayOnBlack": "gray",
    }
    if filter_name in colors_on_black: return ImageOps.colorize(line_img_l, black=colors_on_black[filter_name], white="black")
    line_img_rgb = line_img.convert('RGB')
    blend_ops = {
        "Multiply": ImageChops.multiply, "Screen": ImageChops.screen,
        "Overlay": ImageChops.overlay, "Add": ImageChops.add,
        "Subtract": ImageChops.subtract, "Difference": ImageChops.difference,
        "Darker": ImageChops.darker, "Lighter": ImageChops.lighter,
        "SoftLight": ImageChops.soft_light, "HardLight": ImageChops.hard_light,
    }
    if filter_name in blend_ops: return blend_ops[filter_name](original_img, line_img_rgb)
    if filter_name == "Noise":
        # Add uniform noise in [-20, 20), clipping back to the valid 8-bit range.
        img_array = np.array(line_img_l)
        noise = np.random.randint(-20, 20, img_array.shape, dtype='int16')
        noisy_array = np.clip(img_array.astype('int16') + noise, 0, 255).astype('uint8')
        return Image.fromarray(noisy_array)
    return line_img
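
# Example (hypothetical file name): recolor a line drawing's dark lines.
#   line = Image.open("drawing.png")
#   red_lines = apply_filter(line, "RedOnWhite", line.convert('RGB'))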


def process_image(input_img_path, line_style, filter_choice, gallery_state):
    """Run one line-drawing generator on the selected image, apply the chosen
    filter, save the result to `output_dir`, and prepend it to the gallery."""
    if model1 is None or model2 is None:
        raise gr.Error("Models are not loaded. Please check for 'model.pth' and 'model2.pth'.")
    if not input_img_path:
        raise gr.Error("Please select an image from the file explorer first.")
    if pathlib.Path(input_img_path).suffix.lower() not in IMAGE_EXTENSIONS:
        raise gr.Error(f"Unsupported file type. Supported extensions: {', '.join(IMAGE_EXTENSIONS)}.")

    # Button labels look like "📄 Standard"; drop the emoji prefix.
    filter_name = filter_choice.split(" ", 1)[1]
    original_img = Image.open(input_img_path).convert('RGB')

    transform = transforms.Compose([
        transforms.Resize(256, transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    input_tensor = transform(original_img).unsqueeze(0)

    with torch.no_grad():
        output = model2(input_tensor) if line_style == 'Simple Lines' else model1(input_tensor)

    # The generators work at 256px on the short side; upscale the line
    # drawing back to the original resolution before filtering.
    line_drawing_low_res = transforms.ToPILImage()(output.squeeze().cpu().clamp(0, 1))
    line_drawing_full_res = line_drawing_low_res.resize(original_img.size, Image.Resampling.BICUBIC)

    final_image = apply_filter(line_drawing_full_res, filter_name, original_img)

    base_name = pathlib.Path(input_img_path).stem
    output_filename = f"{base_name}_{filter_name}.png"
    output_filepath = os.path.join(output_dir, output_filename)
    final_image.save(output_filepath)

    # Newest results first.
    gallery_state.insert(0, output_filepath)

    return final_image, gallery_state
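
# Standalone usage (bypassing the UI), with '01.jpeg' as in the examples:
#   result, saved_paths = process_image("01.jpeg", "Simple Lines", "📄 Standard", [])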


title = "🖌️ Image to Line Art with Creative Filters"
description = "1. Browse and select an image using the file explorer. 2. Choose a line style. 3. Pick a filter. Your results will be saved to the 'outputs' folder and appear in the gallery below."


def generate_examples():
    """Build one gr.Examples row per filter, each with a random sample image
    ('01.jpeg' through '10.jpeg') and a random line style."""
    example_images = [f"{i:02d}.jpeg" for i in range(1, 11)]
    valid_example_images = [img for img in example_images if os.path.exists(img)]

    if not valid_example_images:
        print("⚠️ Warning: No example images ('01.jpeg' through '10.jpeg') found. Examples will be empty.")
        return []

    examples = []
    for name, emoji in FILTERS.items():
        filter_choice = f"{emoji} {name}"
        random_image = random.choice(valid_example_images)
        line_style = random.choice(['Simple Lines', 'Complex Lines'])
        examples.append([random_image, line_style, filter_choice])

    random.shuffle(examples)
    return examples


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
    gr.Markdown(description)

    # Holds the list of saved output paths across interactions.
    gallery_state = gr.State(value=[])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select an Image")
            # FileExplorer takes a single glob pattern, which cannot express a
            # list of extensions; show all files with an extension and validate
            # the suffix against IMAGE_EXTENSIONS in process_image instead.
            input_image_path = gr.FileExplorer(
                root=".",
                glob="**/*.*",
                file_count="single",
                label="Browse Your Images",
                height=400
            )
            gr.Markdown("### 2. Choose a Line Style")
            line_style_radio = gr.Radio(
                ['Complex Lines', 'Simple Lines'],
                label="Line Style",
                value='Simple Lines'
            )

        with gr.Column(scale=3):
            gr.Markdown("### 3. Pick a Filter")
            filter_buttons = [gr.Button(value=f"{emoji} {name}") for name, emoji in FILTERS.items()]

            # Hidden radio that records the last clicked filter button and
            # doubles as the filter input for gr.Examples.
            selected_filter = gr.Radio(
                [b.value for b in filter_buttons],
                label="Selected Filter",
                visible=False,
                value=filter_buttons[0].value
            )

            gr.Markdown("### 4. Result")
            main_output_image = gr.Image(type="pil", label="Latest Result")

    with gr.Row():
        gr.Markdown("---")

    with gr.Row():
        gr.Examples(
            examples=generate_examples(),
            inputs=[input_image_path, line_style_radio, selected_filter],
            label="✨ Click an Example to Start",
            examples_per_page=10
        )

    with gr.Row():
        gr.Markdown("## 🖼️ Result Gallery (Saved in 'outputs' folder)")
    gallery_output = gr.Gallery(label="Your Generated Images", height=600, columns=5)

    def handle_filter_click(btn_value, current_path, style, state):
        new_main_img, new_state = process_image(current_path, style, btn_value, state)
        # Return the path list twice: once for the state and once to refresh
        # the visible gallery.
        return btn_value, new_main_img, new_state, new_state

    for btn in filter_buttons:
        btn.click(
            fn=handle_filter_click,
            inputs=[btn, input_image_path, line_style_radio, gallery_state],
            outputs=[selected_filter, main_output_image, gallery_state, gallery_output]
        )

if __name__ == "__main__":
    demo.launch()