Spaces:
Runtime error
Runtime error
| # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # NVIDIA CORPORATION and its licensors retain all intellectual property | |
| # and proprietary rights in and to this software, related documentation | |
| # and any modifications thereto. Any use, reproduction, disclosure or | |
| # distribution of this software and related documentation without an express | |
| # license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| from socket import has_dualstack_ipv6 | |
| import sys | |
| import copy | |
| import traceback | |
| import math | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFont | |
| import torch | |
| import torch.fft | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import matplotlib.cm | |
| import dnnlib | |
| from torch_utils.ops import upfirdn2d | |
| import legacy # pylint: disable=import-error | |
| # ---------------------------------------------------------------------------- | |
class CapturedException(Exception):
    """Exception whose message is a snapshot of an error as plain text.

    With an explicit ``msg`` the text is used verbatim. Without one, the
    constructor must be called while an exception is being handled: the
    message becomes the formatted traceback of that exception, or, when the
    active exception is itself a ``CapturedException``, its existing message
    (so re-capturing does not nest tracebacks).
    """

    def __init__(self, msg=None):
        if msg is None:
            active = sys.exc_info()[1]
            assert active is not None
            already_captured = isinstance(active, CapturedException)
            msg = str(active) if already_captured else traceback.format_exc()
        assert isinstance(msg, str)
        super().__init__(msg)
| # ---------------------------------------------------------------------------- | |
class CaptureSuccess(Exception):
    """Exception used to carry a result value out of a call stack.

    The payload is stored on ``out``; the exception itself has no message.
    """

    def __init__(self, out):
        Exception.__init__(self)
        self.out = out
| # ---------------------------------------------------------------------------- | |
def add_watermark_np(input_image_array, watermark_text="AI Generated"):
    """Return a copy of an image array with a text watermark in the corner.

    The text is drawn in semi-transparent white, 10 pixels in from the
    bottom-right corner, with a font size proportional to the image width
    (25 px at a width of 512).

    Args:
        input_image_array: Image data accepted by ``PIL.Image.fromarray``
            after conversion to ``uint8``.
        watermark_text: Text to draw.

    Returns:
        The watermarked image as an RGBA ``numpy`` array.
    """
    image = Image.fromarray(np.uint8(input_image_array)).convert("RGBA")

    # Fully transparent overlay that receives the text.
    txt = Image.new('RGBA', image.size, (255, 255, 255, 0))

    # Clamp to 1 so very small images do not request a zero-size font.
    font_size = max(1, round(25 / 512 * image.size[0]))
    try:
        font = ImageFont.truetype('arial.ttf', font_size)
    except OSError:
        # Arial is not installed on many hosts; a watermark in the default
        # font is better than failing the whole render.
        font = ImageFont.load_default()

    d = ImageDraw.Draw(txt)
    if hasattr(font, 'getbbox'):
        # ``getsize`` was removed in Pillow 10. The right/bottom edges of the
        # bounding box are the extent of the text measured from the drawing
        # origin, which is what is needed to keep the margin.
        _left, _top, text_width, text_height = font.getbbox(watermark_text)
    else:
        text_width, text_height = font.getsize(watermark_text)
    text_position = (image.size[0] - text_width - 10,
                     image.size[1] - text_height - 10)

    # White with the alpha channel set to semi-transparent.
    text_color = (255, 255, 255, 128)
    d.text(text_position, watermark_text, font=font, fill=text_color)

    # Blend the overlay onto the image.
    watermarked = Image.alpha_composite(image, txt)
    return np.array(watermarked)
| # ---------------------------------------------------------------------------- | |
class Renderer:
    """Stateful generator front end with optional point-drag optimisation.

    The renderer loads a generator from a network pickle, keeps an
    optimisable latent ``w`` together with its initial value ``w0``, and on
    every call synthesises an image. When dragging is enabled it also tracks
    handle points by feature matching and takes one optimiser step that moves
    the features around each handle point towards its target.
    """

    def __init__(self, disable_timing=False):
        """Select the compute device and create the empty caches.

        Args:
            disable_timing: When true, no CUDA timing events are used. Timing
                is also disabled automatically when the selected device is
                not CUDA, because CUDA events cannot be recorded there.
        """
        mps_backend = getattr(torch.backends, 'mps', None)
        if torch.cuda.is_available():
            device_name = 'cuda'
        elif mps_backend is not None and mps_backend.is_available():
            device_name = 'mps'
        else:
            device_name = 'cpu'
        self._device = torch.device(device_name)
        self._dtype = torch.float32 if self._device.type == 'mps' else torch.float64
        self._pkl_data = dict()     # {pkl: dict | CapturedException, ...}
        self._networks = dict()     # {cache_key: torch.nn.Module, ...}
        self._pinned_bufs = dict()  # {(shape, dtype): torch.Tensor, ...}
        self._cmaps = dict()        # {name: torch.Tensor, ...}
        self._is_timing = False
        self._disable_timing = disable_timing or self._device.type != 'cuda'
        if not self._disable_timing:
            self._start_event = torch.cuda.Event(enable_timing=True)
            self._end_event = torch.cuda.Event(enable_timing=True)
        self._net_layers = dict()   # {cache_key: [dnnlib.EasyDict, ...], ...}

    def render(self, **args):
        """Render one frame, (re)initialising the network when inputs changed.

        The network and latent are re-initialised when no generator is loaded
        yet, when ``pkl``, ``w_load``, ``w0_seed`` or ``w_plus`` differ from
        the values used last time, or when ``reset_w`` is true.

        Args:
            **args: Forwarded to ``init_network`` and ``_render_drag_impl``.

        Returns:
            A ``dnnlib.EasyDict``. ``image`` and ``stats`` (when present) are
            converted to ``numpy`` arrays, ``image`` is watermarked, ``error``
            (when present) is a string, and ``render_time`` is set in seconds
            when timing was active for the whole call.
        """
        if self._disable_timing:
            self._is_timing = False
        else:
            self._start_event.record(torch.cuda.current_stream(self._device))
            self._is_timing = True
        res = dnnlib.EasyDict()
        try:
            init_net = not hasattr(self, 'G')
            if hasattr(self, 'pkl') and self.pkl != args['pkl']:
                init_net = True
            # Identity comparison on purpose: ``w_load`` may be a tensor, for
            # which ``!=`` would be element-wise.
            if hasattr(self, 'w_load') and self.w_load is not args['w_load']:
                init_net = True
            if hasattr(self, 'w0_seed') and self.w0_seed != args['w0_seed']:
                init_net = True
            if hasattr(self, 'w_plus') and self.w_plus != args['w_plus']:
                init_net = True
            if args['reset_w']:
                init_net = True
            res.init_net = init_net
            if init_net:
                self.init_network(res, **args)
            self._render_drag_impl(res, **args)
        except Exception:
            # Report the failure in the result instead of propagating it;
            # KeyboardInterrupt / SystemExit are deliberately not swallowed.
            res.error = CapturedException()
        if not self._disable_timing:
            self._end_event.record(torch.cuda.current_stream(self._device))
        if 'image' in res:
            if isinstance(res.image, torch.Tensor):
                image = self.to_cpu(res.image).detach().numpy()
            else:
                # ``to_pil=True`` produces a PIL image, which has no tensor API.
                image = np.asarray(res.image)
            res.image = add_watermark_np(image, 'AI Generated')
        if 'stats' in res:
            res.stats = self.to_cpu(res.stats).detach().numpy()
        if 'error' in res:
            res.error = str(res.error)
        if self._is_timing and not self._disable_timing:
            self._end_event.synchronize()
            res.render_time = self._start_event.elapsed_time(
                self._end_event) * 1e-3
            self._is_timing = False
        return res

    def get_network(self, pkl, key, **tweak_kwargs):
        """Load (with caching) the network stored under ``key`` in ``pkl``.

        Both the unpickled data and the rebuilt network are cached, including
        failures, so a broken pickle is not re-read on every call.

        Args:
            pkl: Path or URL of the network pickle. Its text must contain
                ``stylegan2``, ``stylegan3`` or ``stylegan_human`` so the
                matching ``Generator`` class can be chosen.
            key: Entry of the pickle to instantiate, e.g. ``'G_ema'``.
            **tweak_kwargs: Only used as part of the cache key.

        Returns:
            The network, moved to the renderer's device.

        Raises:
            CapturedException: If loading or rebuilding failed, now or on an
                earlier call with the same arguments.
        """
        data = self._pkl_data.get(pkl, None)
        if data is None:
            print(f'Loading "{pkl}"... ', end='', flush=True)
            try:
                with dnnlib.util.open_url(pkl, verbose=False) as f:
                    data = legacy.load_network_pkl(f)
                print('Done.')
            except Exception:
                data = CapturedException()
                print('Failed!')
            self._pkl_data[pkl] = data
            self._ignore_timing()
        if isinstance(data, CapturedException):
            raise data

        orig_net = data[key]
        cache_key = (orig_net, self._device, tuple(
            sorted(tweak_kwargs.items())))
        net = self._networks.get(cache_key, None)
        if net is None:
            try:
                if 'stylegan2' in pkl:
                    from training.networks_stylegan2 import Generator
                elif 'stylegan3' in pkl:
                    from training.networks_stylegan3 import Generator
                elif 'stylegan_human' in pkl:
                    from stylegan_human.training_scripts.sg2.training.networks import Generator
                else:
                    raise NameError('Cannot infer model type from pkl name!')
                print(data[key].init_args)
                print(data[key].init_kwargs)
                if 'stylegan_human' in pkl:
                    net = Generator(
                        *data[key].init_args, **data[key].init_kwargs,
                        square=False, padding=True)
                else:
                    net = Generator(*data[key].init_args,
                                    **data[key].init_kwargs)
                net.load_state_dict(data[key].state_dict())
                net.to(self._device)
            except Exception:
                net = CapturedException()
            self._networks[cache_key] = net
            self._ignore_timing()
        if isinstance(net, CapturedException):
            raise net
        return net

    def _get_pinned_buf(self, ref):
        """Return a cached CPU staging buffer matching ``ref``'s shape/dtype.

        The buffer is page-locked only when the renderer runs on CUDA;
        pinning is unavailable on hosts without it.
        """
        key = (tuple(ref.shape), ref.dtype)
        buf = self._pinned_bufs.get(key, None)
        if buf is None:
            buf = torch.empty(ref.shape, dtype=ref.dtype)
            if self._device.type == 'cuda':
                buf = buf.pin_memory()
            self._pinned_bufs[key] = buf
        return buf

    def to_device(self, buf):
        """Copy ``buf`` to the renderer's device and return the new tensor."""
        if self._device.type != 'cuda':
            # Without CUDA there is nothing to gain from staging, and on a
            # CPU device ``.to()`` would hand back the shared staging buffer
            # itself, which later transfers overwrite.
            return buf.to(self._device, copy=True)
        return self._get_pinned_buf(buf).copy_(buf).to(self._device)

    def to_cpu(self, buf):
        """Copy ``buf`` to a fresh CPU tensor via the staging buffer."""
        return self._get_pinned_buf(buf).copy_(buf).clone()

    def _ignore_timing(self):
        """Mark the current call as not timeable (e.g. it loaded a network)."""
        self._is_timing = False

    def _apply_cmap(self, x, name='viridis'):
        """Map values in ``x`` through a 1024-entry colormap lookup table.

        Args:
            x: Tensor of values; they are scaled by the table size, rounded
                and clamped to the valid index range.
            name: Matplotlib colormap name.

        Returns:
            Tensor with one trailing axis of three colour bytes appended.
        """
        cmap = self._cmaps.get(name, None)
        if cmap is None:
            # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; use
            # the registry when this Matplotlib version provides it.
            registry = getattr(matplotlib, 'colormaps', None)
            if registry is not None:
                cmap = registry[name]
            else:
                cmap = matplotlib.cm.get_cmap(name)
            cmap = cmap(np.linspace(0, 1, num=1024), bytes=True)[:, :3]
            cmap = self.to_device(torch.from_numpy(cmap))
            self._cmaps[name] = cmap
        hi = cmap.shape[0] - 1
        x = (x * hi + 0.5).clamp(0, hi).to(torch.int64)
        x = torch.nn.functional.embedding(x, cmap)
        return x

    def _install_latent(self, w, lr):
        """Store ``w`` as the reference latent and reset optimiser/drag state.

        ``self.w_plus`` must already be set: when false only the first layer
        of ``w`` is kept as the optimised latent.
        """
        self.w0 = w.detach().clone()
        if self.w_plus:
            self.w = w.detach()
        else:
            self.w = w[:, 0, :].detach()
        self.w.requires_grad = True
        self.w_optim = torch.optim.Adam([self.w], lr=lr)
        self.feat_refs = None
        self.points0_pt = None

    def init_network(self, res,
                     pkl=None,
                     w0_seed=0,
                     w_load=None,
                     w_plus=True,
                     noise_mode='const',
                     trunc_psi=0.7,
                     trunc_cutoff=None,
                     input_transform=None,
                     lr=0.001,
                     **kwargs
                     ):
        """Load the generator and create the initial latent and optimiser.

        Args:
            res: Result dict; receives ``img_resolution``, ``num_ws``,
                ``has_noise``, ``has_input_transform`` and ``stop``, plus
                ``error`` if ``input_transform`` is singular.
            pkl: Network pickle to load ``G_ema`` from.
            w0_seed: Seed for the random latent (ignored if ``w_load`` given).
            w_load: Optional latent tensor to start from instead of sampling.
            w_plus: Optimise the per-layer latent if true, else only the
                first layer's latent.
            noise_mode: Unused here; accepted because ``render`` forwards all
                of its arguments.
            trunc_psi: Truncation strength passed to the mapping network.
            trunc_cutoff: Truncation cutoff passed to the mapping network.
            input_transform: Optional 3x3 matrix; its inverse is written to
                the synthesis input transform when the generator has one.
            lr: Learning rate of the Adam optimiser on the latent.
            **kwargs: Ignored.
        """
        # Dig up network details.
        self.pkl = pkl
        G = self.get_network(pkl, 'G_ema')
        self.G = G
        res.img_resolution = G.img_resolution
        res.num_ws = G.num_ws
        res.has_noise = any('noise_const' in name
                            for name, _buf in G.synthesis.named_buffers())
        res.has_input_transform = (
            hasattr(G.synthesis, 'input')
            and hasattr(G.synthesis.input, 'transform'))
        res.stop = False
        self.lr = lr

        # Set input transform.
        if res.has_input_transform:
            m = np.eye(3)
            try:
                if input_transform is not None:
                    m = np.linalg.inv(np.asarray(input_transform))
            except np.linalg.LinAlgError:
                # Singular matrix: report it and fall back to the identity.
                res.error = CapturedException()
            G.synthesis.input.transform.copy_(torch.from_numpy(m))

        self.w0_seed = w0_seed
        self.w_load = w_load
        if self.w_load is None:
            # Generate random latents, sized from the generator when it
            # exposes its latent width (512 otherwise).
            z_dim = getattr(G, 'z_dim', 512)
            z = torch.from_numpy(np.random.RandomState(w0_seed).randn(
                1, z_dim)).to(self._device, dtype=self._dtype)
            # Run mapping network.
            label = torch.zeros([1, G.c_dim], device=self._device)
            w = G.mapping(z, label, truncation_psi=trunc_psi,
                          truncation_cutoff=trunc_cutoff)
        else:
            w = self.w_load.clone().to(self._device)

        self.w_plus = w_plus
        self._install_latent(w, lr)

    def set_latent(self, w, trunc_psi, trunc_cutoff):
        """Replace the current latent with ``w`` and reset the drag state.

        Args:
            w: New latent tensor.
            trunc_psi: Unused.
            trunc_cutoff: Unused.
        """
        self._install_latent(w, self.lr)

    def update_lr(self, lr):
        """Rebuild the optimiser with a new learning rate.

        The drag reference features and initial points are kept. The rate is
        also remembered so a later ``set_latent`` does not revert to the
        previous one.
        """
        del self.w_optim
        self.lr = lr
        self.w_optim = torch.optim.Adam([self.w], lr=lr)
        print(f'Rebuild optimizer with lr: {lr}')
        print('    Remain feat_refs and points0_pt')

    def _render_drag_impl(self, res,
                          points=None,
                          targets=None,
                          mask=None,
                          lambda_mask=10,
                          reg=0,
                          feature_idx=5,
                          r1=3,
                          r2=12,
                          random_seed=0,
                          noise_mode='const',
                          trunc_psi=0.7,
                          force_fp32=False,
                          layer_name=None,
                          sel_channels=3,
                          base_channel=0,
                          img_scale_db=0,
                          img_normalize=False,
                          untransform=False,
                          is_drag=False,
                          reset=False,
                          to_pil=False,
                          **kwargs
                          ):
        """Synthesise an image and, if dragging, take one optimisation step.

        Args:
            res: Result dict; receives ``image`` and ``w``, and when
                ``is_drag`` is true also ``points`` and ``stop``.
            points: Handle points as ``[row, col]`` pairs (they index the
                third and fourth axes of the feature map). Updated in place
                with the tracked positions.
            targets: Target position for each handle point, same layout.
            mask: Optional tensor; used only if its minimum is 0 and its
                maximum is 1. Features are kept close to the initial ones
                where the mask is non-zero.
            lambda_mask: Weight of the mask term.
            reg: Weight of the L1 penalty between the latent and ``w0``.
            feature_idx: Index of the generator feature map to use.
            r1: Motion-supervision radius, given for a 512-pixel image.
            r2: Point-tracking search radius, given for a 512-pixel image.
            noise_mode: Passed to the generator.
            trunc_psi: Passed to the generator.
            img_scale_db: Output gain in decibels.
            img_normalize: Divide each channel by its maximum magnitude.
            is_drag: Run tracking and one optimiser step.
            reset: Discard the stored reference features and initial points.
            to_pil: Store a PIL image instead of a ``uint8`` tensor.
            random_seed, force_fp32, layer_name, sel_channels, base_channel,
            untransform, **kwargs: Accepted but unused.

        Note:
            Any exception raised here is logged and then the whole process is
            replaced by a fresh copy of itself via ``os.execv``.
        """
        points = [] if points is None else points
        targets = [] if targets is None else targets
        try:
            G = self.G
            ws = self.w
            if ws.dim() == 2:
                ws = ws.unsqueeze(1).repeat(1, 6, 1)
            # Only the first six layers come from the optimised latent; the
            # remaining layers always use the initial latent.
            ws = torch.cat([ws[:, :6, :], self.w0[:, 6:, :]], dim=1)

            # A change in the number of points invalidates the references.
            if hasattr(self, 'points') and len(points) != len(self.points):
                reset = True
            if reset:
                self.feat_refs = None
                self.points0_pt = None
            self.points = points

            # Run synthesis network.
            label = torch.zeros([1, G.c_dim], device=self._device)
            img, feat = G(ws, label, truncation_psi=trunc_psi,
                          noise_mode=noise_mode, input_is_w=True,
                          return_feature=True)
            h, w = G.img_resolution, G.img_resolution

            if is_drag:
                # Integer pixel coordinates, so that the distance field below
                # lines up with the indices returned by ``torch.where``.
                rows = torch.arange(h, dtype=torch.float32, device=self._device)
                cols = torch.arange(w, dtype=torch.float32, device=self._device)
                xx, yy = torch.meshgrid(rows, cols, indexing='ij')
                feat_resize = F.interpolate(
                    feat[feature_idx], [h, w], mode='bilinear')

                if self.feat_refs is None:
                    # First drag step: remember the features at each handle.
                    self.feat0_resize = F.interpolate(
                        feat[feature_idx].detach(), [h, w], mode='bilinear')
                    self.feat_refs = [
                        self.feat0_resize[:, :, round(p[0]), round(p[1])]
                        for p in points
                    ]
                    self.points0_pt = torch.Tensor(points).unsqueeze(
                        0).to(self._device)  # 1, N, 2

                # Point tracking with feature matching: move each handle to
                # the pixel in its neighbourhood whose feature is closest to
                # the stored reference.
                with torch.no_grad():
                    r_track = round(r2 / 512 * h)
                    for j, point in enumerate(points):
                        py, px = round(point[0]), round(point[1])
                        up = max(py - r_track, 0)
                        down = min(py + r_track + 1, h)
                        left = max(px - r_track, 0)
                        right = min(px + r_track + 1, w)
                        feat_patch = feat_resize[:, :, up:down, left:right]
                        L2 = torch.linalg.norm(
                            feat_patch - self.feat_refs[j].reshape(1, -1, 1, 1),
                            dim=1)
                        _, idx = torch.min(L2.view(1, -1), -1)
                        width = right - left
                        flat = idx.item()
                        points[j] = [flat // width + up, flat % width + left]

                res.points = [[point[0], point[1]] for point in points]

                # Motion supervision.
                loss_motion = 0
                res.stop = True
                stop_radius = max(2 / 512 * h, 2)
                r_motion = round(r1 / 512 * h)
                for j, point in enumerate(points):
                    # Stored as (dx, dy) to match grid_sample's (x, y) order.
                    direction = torch.Tensor(
                        [targets[j][1] - point[1], targets[j][0] - point[0]])
                    remaining = torch.linalg.norm(direction)
                    if remaining > stop_radius:
                        res.stop = False
                    if remaining > 1:
                        distance = ((xx - point[0]) ** 2
                                    + (yy - point[1]) ** 2) ** 0.5
                        relis, reljs = torch.where(distance < r_motion)
                        direction = direction / (remaining + 1e-7)
                        # Sample positions one unit step back along the
                        # direction, in grid_sample's [-1, 1] coordinates.
                        gridh = (relis - direction[1]) / (h - 1) * 2 - 1
                        gridw = (reljs - direction[0]) / (w - 1) * 2 - 1
                        grid = torch.stack(
                            [gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)
                        target = F.grid_sample(
                            feat_resize.float(), grid,
                            align_corners=True).squeeze(2)
                        loss_motion = loss_motion + F.l1_loss(
                            feat_resize[:, :, relis, reljs], target.detach())

                loss = loss_motion
                if mask is not None:
                    if mask.min() == 0 and mask.max() == 1:
                        mask_usq = mask.to(
                            self._device).unsqueeze(0).unsqueeze(0)
                        loss_fix = F.l1_loss(
                            feat_resize * mask_usq,
                            self.feat0_resize * mask_usq)
                        loss = loss + lambda_mask * loss_fix

                # Latent code regularization.
                loss = loss + reg * F.l1_loss(ws, self.w0)
                if not res.stop:
                    self.w_optim.zero_grad()
                    loss.backward()
                    self.w_optim.step()

            # Scale and convert to uint8.
            img = img[0]
            if img_normalize:
                img = img / img.norm(float('inf'),
                                     dim=[1, 2], keepdim=True).clip(1e-8, 1e8)
            img = img * (10 ** (img_scale_db / 20))
            img = (img * 127.5 + 128).clamp(0, 255).to(
                torch.uint8).permute(1, 2, 0)
            if to_pil:
                img = Image.fromarray(img.cpu().numpy())
            res.image = img
            res.w = ws.detach().cpu().numpy()
        except Exception as e:
            import os
            print(f'Renderer error: {e}')
            traceback.print_exc()
            out_of_memory = (isinstance(e, RuntimeError)
                             and 'out of memory' in str(e).lower())
            if out_of_memory:
                print("Out of memory error occurred. Restarting the app...")
            else:
                print("Unrecoverable renderer error. Restarting the app...")
            # execv replaces the process without flushing buffered output.
            sys.stdout.flush()
            sys.stderr.flush()
            os.execv(sys.executable, ['python'] + sys.argv)
| # ---------------------------------------------------------------------------- | |