init
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- Dockerfile +30 -0
- app.py +267 -0
- models/__init__.py +63 -0
- models/__pycache__/__init__.cpython-37.pyc +0 -0
- models/__pycache__/__init__.cpython-39.pyc +0 -0
- models/__pycache__/bev3d_generator.cpython-37.pyc +0 -0
- models/__pycache__/bev3d_generator.cpython-39.pyc +0 -0
- models/__pycache__/eg3d_discriminator.cpython-37.pyc +0 -0
- models/__pycache__/eg3d_discriminator.cpython-39.pyc +0 -0
- models/__pycache__/eg3d_generator.cpython-37.pyc +0 -0
- models/__pycache__/eg3d_generator.cpython-39.pyc +0 -0
- models/__pycache__/eg3d_generator_fv.cpython-37.pyc +0 -0
- models/__pycache__/eg3d_generator_fv.cpython-39.pyc +0 -0
- models/__pycache__/ghfeat_encoder.cpython-37.pyc +0 -0
- models/__pycache__/ghfeat_encoder.cpython-39.pyc +0 -0
- models/__pycache__/inception_model.cpython-37.pyc +0 -0
- models/__pycache__/inception_model.cpython-39.pyc +0 -0
- models/__pycache__/perceptual_model.cpython-37.pyc +0 -0
- models/__pycache__/perceptual_model.cpython-39.pyc +0 -0
- models/__pycache__/pggan_discriminator.cpython-37.pyc +0 -0
- models/__pycache__/pggan_discriminator.cpython-39.pyc +0 -0
- models/__pycache__/pggan_generator.cpython-37.pyc +0 -0
- models/__pycache__/pggan_generator.cpython-39.pyc +0 -0
- models/__pycache__/pigan_discriminator.cpython-37.pyc +0 -0
- models/__pycache__/pigan_discriminator.cpython-39.pyc +0 -0
- models/__pycache__/pigan_generator.cpython-37.pyc +0 -0
- models/__pycache__/pigan_generator.cpython-39.pyc +0 -0
- models/__pycache__/sgbev3d_generator.cpython-37.pyc +0 -0
- models/__pycache__/sgbev3d_generator.cpython-39.pyc +0 -0
- models/__pycache__/stylegan2_discriminator.cpython-37.pyc +0 -0
- models/__pycache__/stylegan2_discriminator.cpython-39.pyc +0 -0
- models/__pycache__/stylegan2_generator.cpython-37.pyc +0 -0
- models/__pycache__/stylegan2_generator.cpython-39.pyc +0 -0
- models/__pycache__/stylegan3_generator.cpython-37.pyc +0 -0
- models/__pycache__/stylegan3_generator.cpython-39.pyc +0 -0
- models/__pycache__/stylegan_discriminator.cpython-37.pyc +0 -0
- models/__pycache__/stylegan_discriminator.cpython-39.pyc +0 -0
- models/__pycache__/stylegan_generator.cpython-37.pyc +0 -0
- models/__pycache__/stylegan_generator.cpython-39.pyc +0 -0
- models/__pycache__/volumegan_discriminator.cpython-37.pyc +0 -0
- models/__pycache__/volumegan_discriminator.cpython-39.pyc +0 -0
- models/__pycache__/volumegan_generator.cpython-37.pyc +0 -0
- models/__pycache__/volumegan_generator.cpython-39.pyc +0 -0
- models/bev3d_generator.py +301 -0
- models/eg3d_discriminator.py +243 -0
- models/eg3d_generator.py +315 -0
- models/eg3d_generator_fv.py +320 -0
- models/ghfeat_encoder.py +563 -0
- models/inception_model.py +562 -0
- models/perceptual_model.py +519 -0
Dockerfile
ADDED
@@ -0,0 +1,30 @@
+FROM nvidia/cuda:11.1.0-devel-ubuntu22.04
+
+ENV CUDA_HOME=/usr/local/cuda
+ENV PATH=${CUDA_HOME}/bin:/home/${USER_NAME}/.local/bin:${PATH}
+ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
+ENV LIBRARY_PATH=${CUDA_HOME}/lib64/stubs:${LIBRARY_PATH}
+
+# apt install by root user
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    build-essential \
+    curl \
+    git \
+    python-is-python3 \
+    python3.7-dev \
+    python3-pip \
+    wget \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
+
+
+WORKDIR /code
+
+COPY ./requirements.txt /code/requirements.txt
+
+RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+COPY . .
+
+CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app.py
ADDED
@@ -0,0 +1,267 @@
+import gradio as gr
+from models import build_model
+from PIL import Image
+import numpy as np
+import torchvision
+import ninja
+import torch
+from tqdm import trange
+import imageio
+
+checkpoint = '/mnt/petrelfs/zhangqihang/data/berfscene_clevr.pth'
+state = torch.load(checkpoint, map_location='cpu')
+G = build_model(**state['model_kwargs_init']['generator_smooth'])
+o0, o1 = G.load_state_dict(state['models']['generator_smooth'], strict=False)
+G.eval().cuda()
+G.backbone.synthesis.input.x_offset = 0
+G.backbone.synthesis.input.y_offset = 0
+G_kwargs = dict(noise_mode='const',
+                fused_modulate=False,
+                impl='cuda',
+                fp16_res=None)
+
+def trans(x, y, z, length):
+    w = h = length
+    x = 0.5 * w - 128 + 256 - (x/9 + .5) * 256
+    y = 0.5 * h - 128 + (y/9 + .5) * 256
+    z = z / 9 * 256
+    return x, y, z
+def get_bev_from_objs(objs, length=256, scale=6):
+    h, w = length, length * scale
+    nc = 14
+    canvas = np.zeros([h, w, nc])
+    xx = np.ones([h, w]).cumsum(0)
+    yy = np.ones([h, w]).cumsum(1)
+
+    for x, y, z, shape, color, material, rot in objs:
+        y, x, z = trans(x, y, z, length)
+
+        feat = [0] * nc
+        feat[0] = 1
+        feat[COLOR_NAME_LIST.index(color) + 1] = 1
+        feat[SHAPE_NAME_LIST.index(shape) + 1 + len(COLOR_NAME_LIST)] = 1
+        feat[MATERIAL_NAME_LIST.index(material) + 1 + len(COLOR_NAME_LIST) + len(SHAPE_NAME_LIST)] = 1
+        feat = np.array(feat)
+        rot_sin = np.sin(rot / 180 * np.pi)
+        rot_cos = np.cos(rot / 180 * np.pi)
+
+        if shape == 'cube':
+            mask = (np.abs(+rot_cos * (xx-x) + rot_sin * (yy-y)) <= z) * \
+                   (np.abs(-rot_sin * (xx-x) + rot_cos * (yy-y)) <= z)
+        else:
+            mask = ((xx-x)**2 + (y-yy)**2) ** 0.5 <= z
+        canvas[mask] = feat
+    canvas = np.transpose(canvas, [2, 0, 1]).astype(np.float32)
+    rotate_angle = 0
+    canvas = torchvision.transforms.functional.rotate(torch.tensor(canvas), rotate_angle).numpy()
+    return canvas
+
+# COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'brown', 'blue']
+COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'purple', 'blue']
+SHAPE_NAME_LIST = ['cube', 'sphere', 'cylinder']
+MATERIAL_NAME_LIST = ['rubber', 'metal']
+
+xy_lib = dict()
+xy_lib['B'] = [
+    [-2, -1],
+    [-1, -1],
+    [-2, 0],
+    [-2, 1],
+    [-1, .5],
+    [0, 1],
+    [0, 0],
+    [0, -1],
+    [0, 2],
+    [-1, 2],
+    [-2, 2]
+]
+xy_lib['B'] = [
+    [-2.5, 1.25],
+    [-2, 2],
+    [-2, 0.5],
+    [-2, -0.75],
+    [-1, -1],
+    [-1, 2],
+    [-1, 0],
+    [-1, 2],
+    [0, 1],
+    [0, 0],
+    [0, -1],
+    [0, 2],
+    # [-1, 2],
+
+]
+xy_lib['B'] = [
+    [-2.5, 1.25],
+    [-2, 2],
+    [-2, 0.5],
+    [-2, -1],
+    [-1, -1.25],
+    [-1, 2],
+    [-1, 0],
+    [-1, 2],
+    [0, 1],
+    [0, 0],
+    [0, -1.25],
+    [0, 2],
+    # [-1, 2],
+
+]
+xy_lib['R'] = [
+    [0, -1],
+    [0, 0],
+    [0, 1],
+    [0, 2],
+    [-1, -1],
+    # [-1, 2],
+    [-2, -1],
+    [-2, 0],
+    [-2.25, 2],
+    [-1, 1]
+]
+xy_lib['C'] = [
+    [0, -1],
+    [0, 0],
+    [0, 1],
+    [0, 2],
+    [-1, -1],
+    [-1, 2],
+    [-2, -1],
+    # [-2, .5],
+    [-2, 2],
+    # [-1, .5]
+]
+xy_lib['s'] = [
+    [0, -1],
+    [0, 0],
+    [0, 2],
+    [-1, -1],
+    [-1, 2],
+    [-2, -1],
+    [-2, 1],
+    [-2, 2],
+    [-1, .5]
+]
+
+xy_lib['F'] = [
+    [0, -1],
+    [0, 0],
+    [0, 1],
+    [0, 2],
+    [-1, -1],
+    # [-1, 2],
+    [-2, -1],
+    [-2, .5],
+    # [-2, 2],
+    [-1, .5]
+]
+
+xy_lib['c'] = [
+    [0.8, 1],
+    # [-0.8, 1],
+    [0, 0.1],
+    [0, 1.9],
+]
+
+xy_lib['e'] = [
+    [0, -1],
+    [0, 0],
+    [0, 1],
+    [0, 2],
+    [-1, -1],
+    [-1, 2],
+    [-2, -1],
+    [-2, .5],
+    [-2, 2],
+    [-1, .5]
+]
+xy_lib['n'] = [
+    [0, 1],
+    [0, -1],
+    [0, 0.1],
+    [0, 1.9],
+    [-1, 0],
+    [-2, 1],
+    [-3, -1],
+    [-3, 1],
+    [-3, 0.1],
+    [-3, 1.9],
+]
+offset_x = dict(B=4, R=4, C=4, F=4, c=3, s=4, e=4, n=4.8)
+s = 'BeRFsCene'
+objs = []
+offset = 2
+for idx, c in enumerate(s):
+    xy = xy_lib[c]
+
+
+    color = np.random.choice(COLOR_NAME_LIST)
+    for i in range(len(xy)):
+        # while 1:
+        #     is_ok = 1
+        #     x, y =
+
+        #     for prev_x, prev_y in zip(xpool, ypool):
+        x, y = xy[i]
+        y *= 1.5
+        y -= 0.5
+        x -= offset
+        z = 0.35
+        # if idx < 4:
+        #     color = np.random.choice(COLOR_NAME_LIST[:-1])
+        # else:
+        #     color = 'blue'
+        shape = 'cube'
+        material = 'rubber'
+        rot = 0
+        objs.append([x, y, z, shape, color, material, rot])
+    offset += offset_x[c]
+Image.fromarray((255 * .8 - get_bev_from_objs(objs)[0] * .8 * 255).astype(np.uint8))
+
+batch_size = 1
+code = torch.randn(1, G.z_dim).cuda()
+to_pil = torchvision.transforms.ToPILImage()
+large_bevs = torch.tensor(get_bev_from_objs(objs)).cuda()[None]
+bevs = large_bevs[..., 0: 0+256]
+RT = torch.tensor([[ -1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, -0.8660,
+                     10.3923, 0.0000, -0.8660, -0.5000, 6.0000, 0.0000, 0.0000,
+                     0.0000, 1.0000, 262.5000, 0.0000, 32.0000, 0.0000, 262.5000,
+                     32.0000, 0.0000, 0.0000, 1.0000]], device='cuda')
+
+print('prepare finish', flush=True)
+
+def inference(name):
+    print('inference', name, flush=True)
+    gen = G(code, RT, bevs)
+    rgb = gen['gen_output']['image'][0] * .5 + .5
+    print('inference', name, flush=True)
+    return np.array(to_pil(rgb))
+
+    # to_pil(rgb).save('tmp.png')
+    # save_path = '/mnt/petrelfs/zhangqihang/code/3d-scene-gen/tmp.png'
+    # return [save_path]
+
+with gr.Blocks() as demo:
+    gr.HTML(
+        """
+        abc
+        """)
+
+    with gr.Group():
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    with gr.Column():
+                        with gr.Row():
+                            num_frames = gr.Dropdown(["24 - frames", "32 - frames", "40 - frames", "48 - frames", "56 - frames", "80 - recommended to run on local GPUs", "240 - recommended to run on local GPUs", "600 - recommended to run on local GPUs", "1200 - recommended to run on local GPUs", "10000 - recommended to run on local GPUs"], label="Number of Video Frames", info="For >56 frames use local workstation!", value="24 - frames")
+
+                with gr.Row():
+                    with gr.Row():
+                        btn = gr.Button("Result")
+
+    gallery = gr.Image(label='img', show_label=True, elem_id="gallery")
+
+    btn.click(fn=inference, inputs=num_frames, outputs=[gallery], postprocess=False)
+
+demo.queue()
+demo.launch(server_name='0.0.0.0', server_port=10093, debug=True, show_error=True)
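
The BEV canvas built by `get_bev_from_objs` packs a 14-dimensional per-pixel feature: channel 0 marks occupancy, the next 8 channels one-hot the color, the following 3 the shape, and the last 2 the material (1 + 8 + 3 + 2 = 14, matching `nc = 14`). Note that the committed `COLOR_NAME_LIST` repeats 'purple' in place of 'brown' (the original list is kept in the comment above it), so `.index('purple')` always resolves to the first occurrence and the slot reserved for the duplicate is never set. The snippet below is a minimal standalone sketch of that encoding and is not part of the commit:

import numpy as np

# Channel layout used by get_bev_from_objs() in app.py:
#   channel 0       -> occupancy
#   channels 1..8   -> one-hot color    (len(COLOR_NAME_LIST) == 8)
#   channels 9..11  -> one-hot shape    (len(SHAPE_NAME_LIST) == 3)
#   channels 12..13 -> one-hot material (len(MATERIAL_NAME_LIST) == 2)
COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'purple', 'blue']
SHAPE_NAME_LIST = ['cube', 'sphere', 'cylinder']
MATERIAL_NAME_LIST = ['rubber', 'metal']

def encode_object(color, shape, material):
    nc = 1 + len(COLOR_NAME_LIST) + len(SHAPE_NAME_LIST) + len(MATERIAL_NAME_LIST)
    assert nc == 14  # matches `nc = 14` in get_bev_from_objs()
    feat = [0] * nc
    feat[0] = 1
    feat[COLOR_NAME_LIST.index(color) + 1] = 1
    feat[SHAPE_NAME_LIST.index(shape) + 1 + len(COLOR_NAME_LIST)] = 1
    feat[MATERIAL_NAME_LIST.index(material) + 1 + len(COLOR_NAME_LIST) + len(SHAPE_NAME_LIST)] = 1
    return np.array(feat)

print(encode_object('red', 'cube', 'rubber'))
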
models/__init__.py
ADDED
@@ -0,0 +1,63 @@
+# python3.7
+"""Collects all models."""
+
+from .pggan_generator import PGGANGenerator
+from .pggan_discriminator import PGGANDiscriminator
+from .stylegan_generator import StyleGANGenerator
+from .stylegan_discriminator import StyleGANDiscriminator
+from .stylegan2_generator import StyleGAN2Generator
+from .stylegan2_discriminator import StyleGAN2Discriminator
+from .stylegan3_generator import StyleGAN3Generator
+from .ghfeat_encoder import GHFeatEncoder
+from .perceptual_model import PerceptualModel
+from .inception_model import InceptionModel
+from .eg3d_generator import EG3DGenerator
+from .eg3d_discriminator import DualDiscriminator
+from .pigan_generator import PiGANGenerator
+from .pigan_discriminator import PiGANDiscriminator
+from .volumegan_generator import VolumeGANGenerator
+from .volumegan_discriminator import VolumeGANDiscriminator
+from .eg3d_generator_fv import EG3DGeneratorFV
+from .bev3d_generator import BEV3DGenerator
+from .sgbev3d_generator import SGBEV3DGenerator
+
+__all__ = ['build_model']
+
+_MODELS = {
+    'PGGANGenerator': PGGANGenerator,
+    'PGGANDiscriminator': PGGANDiscriminator,
+    'StyleGANGenerator': StyleGANGenerator,
+    'StyleGANDiscriminator': StyleGANDiscriminator,
+    'StyleGAN2Generator': StyleGAN2Generator,
+    'StyleGAN2Discriminator': StyleGAN2Discriminator,
+    'StyleGAN3Generator': StyleGAN3Generator,
+    'GHFeatEncoder': GHFeatEncoder,
+    'PerceptualModel': PerceptualModel.build_model,
+    'InceptionModel': InceptionModel.build_model,
+    'EG3DGenerator': EG3DGenerator,
+    'EG3DDiscriminator': DualDiscriminator,
+    'PiGANGenerator': PiGANGenerator,
+    'PiGANDiscriminator': PiGANDiscriminator,
+    'VolumeGANGenerator': VolumeGANGenerator,
+    'VolumeGANDiscriminator': VolumeGANDiscriminator,
+    'EG3DGeneratorFV': EG3DGeneratorFV,
+    'BEV3DGenerator': BEV3DGenerator,
+    'SGBEV3DGenerator': SGBEV3DGenerator,
+}
+
+
+def build_model(model_type, **kwargs):
+    """Builds a model based on its class type.
+
+    Args:
+        model_type: Class type to which the model belongs, which is case
+            sensitive.
+        **kwargs: Additional arguments to build the model.
+
+    Raises:
+        ValueError: If the `model_type` is not supported.
+    """
+    if model_type not in _MODELS:
+        raise ValueError(f'Invalid model type: `{model_type}`!\n'
+                         f'Types allowed: {list(_MODELS)}.')
+    return _MODELS[model_type](**kwargs)
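
The `_MODELS` registry above is exactly how `app.py` in this commit restores the generator: the kwargs saved in the checkpoint are splatted into `build_model`, which dispatches on the stored `model_type`, and the EMA weights are then loaded non-strictly. A minimal sketch, with the checkpoint path shortened to a placeholder:

import torch
from models import build_model

# Placeholder path; app.py points at the berfscene_clevr.pth checkpoint.
state = torch.load('berfscene_clevr.pth', map_location='cpu')

# The stored kwargs are expected to carry 'model_type' (e.g. a generator class
# name from _MODELS), which build_model() uses to look up the class.
G = build_model(**state['model_kwargs_init']['generator_smooth'])
missing, unexpected = G.load_state_dict(state['models']['generator_smooth'], strict=False)
G.eval()
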
models/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (2.06 kB).
models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (2.08 kB).
models/__pycache__/bev3d_generator.cpython-37.pyc
ADDED
Binary file (6.16 kB).
models/__pycache__/bev3d_generator.cpython-39.pyc
ADDED
Binary file (6.07 kB).
models/__pycache__/eg3d_discriminator.cpython-37.pyc
ADDED
Binary file (8.01 kB).
models/__pycache__/eg3d_discriminator.cpython-39.pyc
ADDED
Binary file (7.73 kB).
models/__pycache__/eg3d_generator.cpython-37.pyc
ADDED
Binary file (6.21 kB).
models/__pycache__/eg3d_generator.cpython-39.pyc
ADDED
Binary file (6.3 kB).
models/__pycache__/eg3d_generator_fv.cpython-37.pyc
ADDED
Binary file (6.35 kB).
models/__pycache__/eg3d_generator_fv.cpython-39.pyc
ADDED
Binary file (6.43 kB).
models/__pycache__/ghfeat_encoder.cpython-37.pyc
ADDED
Binary file (14.3 kB).
models/__pycache__/ghfeat_encoder.cpython-39.pyc
ADDED
Binary file (14.1 kB).
models/__pycache__/inception_model.cpython-37.pyc
ADDED
Binary file (16 kB).
models/__pycache__/inception_model.cpython-39.pyc
ADDED
Binary file (15.7 kB).
models/__pycache__/perceptual_model.cpython-37.pyc
ADDED
Binary file (14.3 kB).
models/__pycache__/perceptual_model.cpython-39.pyc
ADDED
Binary file (14 kB).
models/__pycache__/pggan_discriminator.cpython-37.pyc
ADDED
Binary file (12 kB).
models/__pycache__/pggan_discriminator.cpython-39.pyc
ADDED
Binary file (11.9 kB).
models/__pycache__/pggan_generator.cpython-37.pyc
ADDED
Binary file (10.6 kB).
models/__pycache__/pggan_generator.cpython-39.pyc
ADDED
Binary file (10.6 kB).
models/__pycache__/pigan_discriminator.cpython-37.pyc
ADDED
Binary file (8.32 kB).
models/__pycache__/pigan_discriminator.cpython-39.pyc
ADDED
Binary file (8.31 kB).
models/__pycache__/pigan_generator.cpython-37.pyc
ADDED
Binary file (12.7 kB).
models/__pycache__/pigan_generator.cpython-39.pyc
ADDED
Binary file (12.8 kB).
models/__pycache__/sgbev3d_generator.cpython-37.pyc
ADDED
Binary file (7.01 kB).
models/__pycache__/sgbev3d_generator.cpython-39.pyc
ADDED
Binary file (7.04 kB).
models/__pycache__/stylegan2_discriminator.cpython-37.pyc
ADDED
Binary file (17.7 kB).
models/__pycache__/stylegan2_discriminator.cpython-39.pyc
ADDED
Binary file (17.7 kB).
models/__pycache__/stylegan2_generator.cpython-37.pyc
ADDED
Binary file (32.9 kB).
models/__pycache__/stylegan2_generator.cpython-39.pyc
ADDED
Binary file (32.9 kB).
models/__pycache__/stylegan3_generator.cpython-37.pyc
ADDED
Binary file (35.8 kB).
models/__pycache__/stylegan3_generator.cpython-39.pyc
ADDED
Binary file (35.7 kB).
models/__pycache__/stylegan_discriminator.cpython-37.pyc
ADDED
Binary file (15.9 kB).
models/__pycache__/stylegan_discriminator.cpython-39.pyc
ADDED
Binary file (15.9 kB).
models/__pycache__/stylegan_generator.cpython-37.pyc
ADDED
Binary file (24.9 kB).
models/__pycache__/stylegan_generator.cpython-39.pyc
ADDED
Binary file (24.9 kB).
models/__pycache__/volumegan_discriminator.cpython-37.pyc
ADDED
Binary file (17.8 kB).
models/__pycache__/volumegan_discriminator.cpython-39.pyc
ADDED
Binary file (17.8 kB).
models/__pycache__/volumegan_generator.cpython-37.pyc
ADDED
Binary file (18.2 kB).
models/__pycache__/volumegan_generator.cpython-39.pyc
ADDED
Binary file (18.2 kB).
models/bev3d_generator.py
ADDED
@@ -0,0 +1,301 @@
+# python3.8
+"""Contains the implementation of generator described in BEV3D."""
+
+import torch
+import torch.nn as nn
+from models.utils.official_stylegan2_model_helper import Generator as StyleGAN2Backbone
+from models.utils.official_stylegan2_model_helper import FullyConnectedLayer
+from models.utils.eg3d_superres import SuperresolutionHybrid2X
+from models.utils.eg3d_superres import SuperresolutionHybrid4X
+from models.utils.eg3d_superres import SuperresolutionHybrid4X_conststyle
+from models.utils.eg3d_superres import SuperresolutionHybrid8XDC
+from models.rendering.renderer import Renderer
+from models.rendering.feature_extractor import FeatureExtractor
+
+from models.utils.spade import SPADEGenerator
+
+class BEV3DGenerator(nn.Module):
+
+    def __init__(
+        self,
+        z_dim,
+        semantic_nc,
+        ngf,
+        bev_grid_size,
+        aspect_ratio,
+        num_upsampling_layers,
+        not_use_vae,
+        norm_G,
+        img_resolution,
+        interpolate_sr,
+        segmask=False,
+        dim_seq='16,8,4,2,1',
+        xyz_pe=False,
+        hidden_dim=64,
+        additional_layer_num=0,
+        sr_num_fp16_res=0,     # Number of fp16 layers of SR Network.
+        rendering_kwargs={},   # Arguments for rendering.
+        sr_kwargs={},          # Arguments for SuperResolution Network.
+    ):
+        super().__init__()
+
+        self.z_dim = z_dim
+        self.interpolate_sr = interpolate_sr
+        self.segmask = segmask
+
+        # Set up the overall renderer.
+        self.renderer = Renderer()
+
+        # Set up the feature extractor.
+        self.feature_extractor = FeatureExtractor(ref_mode='bev_plane_clevr', xyz_pe=xyz_pe)
+
+        # Set up the reference representation generator.
+        self.backbone = SPADEGenerator(z_dim=z_dim, semantic_nc=semantic_nc, ngf=ngf, dim_seq=dim_seq, bev_grid_size=bev_grid_size,
+                                       aspect_ratio=aspect_ratio, num_upsampling_layers=num_upsampling_layers,
+                                       not_use_vae=not_use_vae, norm_G=norm_G)
+        print('backbone SPADEGenerator set up!')
+
+        # Set up the post module in the feature extractor.
+        self.post_module = None
+
+        # Set up the post neural renderer.
+        self.post_neural_renderer = None
+        sr_kwargs_total = dict(
+            channels=32,
+            img_resolution=img_resolution,
+            sr_num_fp16_res=sr_num_fp16_res,
+            sr_antialias=rendering_kwargs['sr_antialias'],)
+        sr_kwargs_total.update(**sr_kwargs)
+        if img_resolution == 128:
+            self.post_neural_renderer = SuperresolutionHybrid2X(
+                **sr_kwargs_total)
+        elif img_resolution == 256:
+            self.post_neural_renderer = SuperresolutionHybrid4X_conststyle(
+                **sr_kwargs_total)
+        elif img_resolution == 512:
+            self.post_neural_renderer = SuperresolutionHybrid8XDC(
+                **sr_kwargs_total)
+        else:
+            raise TypeError(f'Unsupported image resolution: {img_resolution}!')
+
+        # Set up the fully-connected layer head.
+        self.fc_head = OSGDecoder(
+            128 if xyz_pe else 64, {
+                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
+                'decoder_output_dim': 32
+            },
+            hidden_dim=hidden_dim,
+            additional_layer_num=additional_layer_num
+        )
+
+        # Set up some rendering related arguments.
+        self.neural_rendering_resolution = rendering_kwargs.get(
+            'resolution', 64)
+        self.rendering_kwargs = rendering_kwargs
+
+    def synthesis(self,
+                  z,
+                  c,
+                  seg,
+                  neural_rendering_resolution=None,
+                  update_emas=False,
+                  **synthesis_kwargs):
+        cam2world_matrix = c[:, :16].view(-1, 4, 4)
+        if self.rendering_kwargs.get('random_pose', False):
+            cam2world_matrix = None
+
+        if neural_rendering_resolution is None:
+            neural_rendering_resolution = self.neural_rendering_resolution
+        else:
+            self.neural_rendering_resolution = neural_rendering_resolution
+
+        xy_planes = self.backbone(z=z, input=seg)
+        if self.segmask:
+            xy_planes = xy_planes * seg[:, 0, ...][:, None, ...]
+
+        # import pdb;pdb.set_trace()
+
+        wp = z  # in our case, we do not use wp.
+
+        rendering_result = self.renderer(
+            wp=wp,
+            feature_extractor=self.feature_extractor,
+            rendering_options=self.rendering_kwargs,
+            cam2world_matrix=cam2world_matrix,
+            position_encoder=None,
+            ref_representation=xy_planes,
+            post_module=self.post_module,
+            fc_head=self.fc_head)
+
+        feature_samples = rendering_result['composite_rgb']
+        depth_samples = rendering_result['composite_depth']
+
+        # Reshape to keep consistent with 'raw' neural-rendered image.
+        N = wp.shape[0]
+        H = W = self.neural_rendering_resolution
+        feature_image = feature_samples.permute(0, 2, 1).reshape(
+            N, feature_samples.shape[-1], H, W).contiguous()
+        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
+
+        # Run the post neural renderer to get final image.
+        # Here, the post neural renderer is a super-resolution network.
+        rgb_image = feature_image[:, :3]
+        if self.interpolate_sr:
+            sr_image = torch.nn.functional.interpolate(rgb_image, size=(256, 256), mode='bilinear', align_corners=False)
+        else:
+            sr_image = self.post_neural_renderer(
+                rgb_image,
+                feature_image,
+                # wp,
+                noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
+                **{
+                    k: synthesis_kwargs[k]
+                    for k in synthesis_kwargs.keys() if k != 'noise_mode'
+                })
+
+        return {
+            'image': sr_image,
+            'image_raw': rgb_image,
+            'image_depth': depth_image
+        }
+
+    def sample(self,
+               coordinates,
+               directions,
+               z,
+               c,
+               seg,
+               truncation_psi=1,
+               truncation_cutoff=None,
+               update_emas=False,
+               **synthesis_kwargs):
+        # Compute RGB features, density for arbitrary 3D coordinates.
+        # Mostly used for extracting shapes.
+        cam2world_matrix = c[:, :16].view(-1, 4, 4)
+        xy_planes = self.backbone(z=z, input=seg)
+        wp = z
+        result = self.renderer.get_sigma_rgb(
+            wp=wp,
+            points=coordinates,
+            feature_extractor=self.feature_extractor,
+            fc_head=self.fc_head,
+            rendering_options=self.rendering_kwargs,
+            ref_representation=xy_planes,
+            post_module=self.post_module,
+            ray_dirs=directions,
+            cam_matrix=cam2world_matrix)
+
+        return result
+
+    def sample_mixed(self,
+                     coordinates,
+                     directions,
+                     z, c, seg,
+                     truncation_psi=1,
+                     truncation_cutoff=None,
+                     update_emas=False,
+                     **synthesis_kwargs):
+        # Same as function `self.sample()`, but expects latent vectors 'wp'
+        # instead of Gaussian noise 'z'.
+        cam2world_matrix = c[:, :16].view(-1, 4, 4)
+        xy_planes = self.backbone(z=z, input=seg)
+        wp = z
+        result = self.renderer.get_sigma_rgb(
+            wp=wp,
+            points=coordinates,
+            feature_extractor=self.feature_extractor,
+            fc_head=self.fc_head,
+            rendering_options=self.rendering_kwargs,
+            ref_representation=xy_planes,
+            post_module=self.post_module,
+            ray_dirs=directions,
+            cam_matrix=cam2world_matrix)
+
+        return result
+
+    def forward(self,
+                z,
+                c,
+                seg,
+                c_swapped=None,   # `c_swapped` is swapped pose conditioning.
+                style_mixing_prob=0,
+                truncation_psi=1,
+                truncation_cutoff=None,
+                neural_rendering_resolution=None,
+                update_emas=False,
+                sample_mixed=False,
+                coordinates=None,
+                **synthesis_kwargs):
+
+        # Render a batch of generated images.
+        c_wp = c.clone()
+        if c_swapped is not None:
+            c_wp = c_swapped.clone()
+
+        if not sample_mixed:
+            gen_output = self.synthesis(
+                z,
+                c,
+                seg,
+                update_emas=update_emas,
+                neural_rendering_resolution=neural_rendering_resolution,
+                **synthesis_kwargs)
+
+            return {
+                'wp': z,
+                'gen_output': gen_output,
+            }
+
+        else:
+            # Only for density regularization in training process.
+            assert coordinates is not None
+            sample_sigma = self.sample_mixed(coordinates,
+                                             torch.randn_like(coordinates),
+                                             z, c, seg,
+                                             update_emas=False)['sigma']
+
+            return {
+                'wp': z,
+                'sample_sigma': sample_sigma
+            }
+
+
+class OSGDecoder(nn.Module):
+    """Defines fully-connected layer head in EG3D."""
+    def __init__(self, n_features, options, hidden_dim=64, additional_layer_num=0):
+        super().__init__()
+        self.hidden_dim = hidden_dim
+
+        lst = []
+        lst.append(FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']))
+        lst.append(nn.Softplus())
+        for i in range(additional_layer_num):
+            lst.append(FullyConnectedLayer(self.hidden_dim, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']))
+            lst.append(nn.Softplus())
+        lst.append(FullyConnectedLayer(self.hidden_dim, 1+options['decoder_output_dim'], lr_multiplier=options['decoder_lr_mul']))
+        self.net = nn.Sequential(*lst)
+
+        # self.net = nn.Sequential(
+        #     FullyConnectedLayer(n_features,
+        #                         self.hidden_dim,
+        #                         lr_multiplier=options['decoder_lr_mul']),
+        #     nn.Softplus(),
+        #     FullyConnectedLayer(self.hidden_dim,
+        #                         1 + options['decoder_output_dim'],
+        #                         lr_multiplier=options['decoder_lr_mul']))
+
+    def forward(self, point_features, wp=None, dirs=None):
+        # Aggregate features
+        # point_features.shape: [N, R, K, C].
+        # Average across 'X, Y, Z' planes.
+
+        N, R, K, C = point_features.shape
+        x = point_features.reshape(-1, point_features.shape[-1])
+        x = self.net(x)
+        x = x.view(N, -1, x.shape[-1])
+
+        # Uses sigmoid clamping from MipNeRF
+        rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001
+        sigma = x[..., 0:1]
+
+        return {'rgb': rgb, 'sigma': sigma}
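
`OSGDecoder` is a per-point MLP head: it flattens `[N, R, K, C]` point features into rows, maps `C -> hidden_dim -> 1 + decoder_output_dim`, then splits the output into a density channel and a sigmoid-clamped feature/RGB vector. Below is a minimal standalone sketch of that shape contract; `nn.Linear` stands in for the repository's `FullyConnectedLayer` (an assumption made only to keep the example self-contained), and the input sizes are arbitrary:

import torch
import torch.nn as nn

class TinyOSGDecoder(nn.Module):
    """Shape-contract sketch of OSGDecoder; nn.Linear replaces FullyConnectedLayer."""
    def __init__(self, n_features=64, hidden_dim=64, output_dim=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_features, hidden_dim), nn.Softplus(),
            nn.Linear(hidden_dim, 1 + output_dim))

    def forward(self, point_features):                 # [N, R, K, C]
        N = point_features.shape[0]
        x = self.net(point_features.reshape(-1, point_features.shape[-1]))
        x = x.view(N, -1, x.shape[-1])                 # [N, R*K, 1 + output_dim]
        rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001  # MipNeRF-style clamp
        sigma = x[..., 0:1]
        return {'rgb': rgb, 'sigma': sigma}

out = TinyOSGDecoder()(torch.randn(2, 4096, 48, 64))
print(out['rgb'].shape, out['sigma'].shape)  # [2, 196608, 32], [2, 196608, 1]
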
models/eg3d_discriminator.py
ADDED
@@ -0,0 +1,243 @@
+# python 3.7
+"""Contains the implementation of discriminator described in EG3D."""
+
+
+import numpy as np
+import torch
+from third_party.stylegan2_official_ops import upfirdn2d
+from models.utils.official_stylegan2_model_helper import DiscriminatorBlock
+from models.utils.official_stylegan2_model_helper import MappingNetwork
+from models.utils.official_stylegan2_model_helper import DiscriminatorEpilogue
+
+
+class SingleDiscriminator(torch.nn.Module):
+    def __init__(self,
+        c_dim,                          # Conditioning label (C) dimensionality.
+        img_resolution,                 # Input resolution.
+        img_channels,                   # Number of input color channels.
+        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+        channel_base        = 32768,    # Overall multiplier for the number of channels.
+        channel_max         = 512,      # Maximum number of channels in any layer.
+        num_fp16_res        = 4,        # Use FP16 for the N highest resolutions.
+        conv_clamp          = 256,      # Clamp the output of convolution layers to +-X, None = disable clamping.
+        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
+        sr_upsample_factor  = 1,        # Ignored for SingleDiscriminator
+        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
+        mapping_kwargs      = {},       # Arguments for MappingNetwork.
+        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
+    ):
+        super().__init__()
+        self.c_dim = c_dim
+        self.img_resolution = img_resolution
+        self.img_resolution_log2 = int(np.log2(img_resolution))
+        self.img_channels = img_channels
+        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+        if cmap_dim is None:
+            cmap_dim = channels_dict[4]
+        if c_dim == 0:
+            cmap_dim = 0
+
+        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+        cur_layer_idx = 0
+        for res in self.block_resolutions:
+            in_channels = channels_dict[res] if res < img_resolution else 0
+            tmp_channels = channels_dict[res]
+            out_channels = channels_dict[res // 2]
+            use_fp16 = (res >= fp16_resolution)
+            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+            setattr(self, f'b{res}', block)
+            cur_layer_idx += block.num_layers
+        if c_dim > 0:
+            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+
+    def forward(self, img, c, update_emas=False, **block_kwargs):
+        img = img['image']
+
+        _ = update_emas # unused
+        x = None
+        for res in self.block_resolutions:
+            block = getattr(self, f'b{res}')
+            x, img = block(x, img, **block_kwargs)
+
+        cmap = None
+        if self.c_dim > 0:
+            cmap = self.mapping(None, c)
+        x = self.b4(x, img, cmap)
+        return x
+
+    def extra_repr(self):
+        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
+
+def filtered_resizing(image_orig_tensor, size, f, filter_mode='antialiased'):
+    if filter_mode == 'antialiased':
+        ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
+    elif filter_mode == 'classic':
+        ada_filtered_64 = upfirdn2d.upsample2d(image_orig_tensor, f, up=2)
+        ada_filtered_64 = torch.nn.functional.interpolate(ada_filtered_64, size=(size * 2 + 2, size * 2 + 2), mode='bilinear', align_corners=False)
+        ada_filtered_64 = upfirdn2d.downsample2d(ada_filtered_64, f, down=2, flip_filter=True, padding=-1)
+    elif filter_mode == 'none':
+        ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
+    elif type(filter_mode) == float:
+        assert 0 < filter_mode < 1
+
+        filtered = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
+        aliased = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
+        ada_filtered_64 = (1 - filter_mode) * aliased + (filter_mode) * filtered
+
+    return ada_filtered_64
+
+#----------------------------------------------------------------------------
+
+class DualDiscriminator(torch.nn.Module):
+    def __init__(self,
+        c_dim,                          # Conditioning label (C) dimensionality.
+        img_resolution,                 # Input resolution.
+        img_channels,                   # Number of input color channels.
+        bev_channels        = 0,
+        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+        channel_base        = 32768,    # Overall multiplier for the number of channels.
+        channel_max         = 512,      # Maximum number of channels in any layer.
+        num_fp16_res        = 4,        # Use FP16 for the N highest resolutions.
+        conv_clamp          = 256,      # Clamp the output of convolution layers to +-X, None = disable clamping.
+        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
+        disc_c_noise        = 0,        # Corrupt camera parameters with X std dev of noise before disc. pose conditioning.
+        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
+        mapping_kwargs      = {},       # Arguments for MappingNetwork.
+        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
+    ):
+        super().__init__()
+        img_channels *= 2
+
+        self.c_dim = c_dim
+        self.img_resolution = img_resolution
+        self.img_resolution_log2 = int(np.log2(img_resolution))
+        self.img_channels = img_channels + bev_channels
+        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+        if cmap_dim is None:
+            cmap_dim = channels_dict[4]
+        if c_dim == 0:
+            cmap_dim = 0
+
+        common_kwargs = dict(img_channels=self.img_channels, architecture=architecture, conv_clamp=conv_clamp)
+        cur_layer_idx = 0
+        for res in self.block_resolutions:
+            in_channels = channels_dict[res] if res < img_resolution else 0
+            tmp_channels = channels_dict[res]
+            out_channels = channels_dict[res // 2]
+            use_fp16 = (res >= fp16_resolution)
+            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+            setattr(self, f'b{res}', block)
+            cur_layer_idx += block.num_layers
+        if c_dim > 0:
+            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+        self.disc_c_noise = disc_c_noise
+
+    def forward(self, img, c, bev=None, update_emas=False, **block_kwargs):
+        image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter)
+        img = torch.cat([img['image'], image_raw], 1)
+        if bev is not None:
+            img = torch.cat([img, bev], 1)
+
+        _ = update_emas # unused
+        x = None
+        for res in self.block_resolutions:
+            block = getattr(self, f'b{res}')
+            x, img = block(x, img, **block_kwargs)
+
+        cmap = None
+        if self.c_dim > 0:
+            if self.disc_c_noise > 0: c += torch.randn_like(c) * c.std(0) * self.disc_c_noise
+            cmap = self.mapping(None, c)
+        x = self.b4(x, img, cmap)
+        return x
+
+    def extra_repr(self):
+        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
+
+class DummyDualDiscriminator(torch.nn.Module):
+    def __init__(self,
+        c_dim,                          # Conditioning label (C) dimensionality.
+        img_resolution,                 # Input resolution.
+        img_channels,                   # Number of input color channels.
+        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+        channel_base        = 32768,    # Overall multiplier for the number of channels.
+        channel_max         = 512,      # Maximum number of channels in any layer.
+        num_fp16_res        = 4,        # Use FP16 for the N highest resolutions.
+        conv_clamp          = 256,      # Clamp the output of convolution layers to +-X, None = disable clamping.
+        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
+        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
+        mapping_kwargs      = {},       # Arguments for MappingNetwork.
+        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
+    ):
+        super().__init__()
+        img_channels *= 2
+
+        self.c_dim = c_dim
+        self.img_resolution = img_resolution
+        self.img_resolution_log2 = int(np.log2(img_resolution))
+        self.img_channels = img_channels
+        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+        if cmap_dim is None:
+            cmap_dim = channels_dict[4]
+        if c_dim == 0:
+            cmap_dim = 0
+
+        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+        cur_layer_idx = 0
+        for res in self.block_resolutions:
+            in_channels = channels_dict[res] if res < img_resolution else 0
+            tmp_channels = channels_dict[res]
+            out_channels = channels_dict[res // 2]
+            use_fp16 = (res >= fp16_resolution)
+            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+            setattr(self, f'b{res}', block)
+            cur_layer_idx += block.num_layers
+        if c_dim > 0:
+            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+
+        self.raw_fade = 1
+
+    def forward(self, img, c, update_emas=False, **block_kwargs):
+        self.raw_fade = max(0, self.raw_fade - 1/(500000/32))
+
+        image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter) * self.raw_fade
+        img = torch.cat([img['image'], image_raw], 1)
+
+        _ = update_emas # unused
+        x = None
+        for res in self.block_resolutions:
+            block = getattr(self, f'b{res}')
+            x, img = block(x, img, **block_kwargs)
+
+        cmap = None
+        if self.c_dim > 0:
+            cmap = self.mapping(None, c)
+        x = self.b4(x, img, cmap)
+        return x
+
+    def extra_repr(self):
+        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
+
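
What makes `DualDiscriminator` "dual" is the input assembly in `forward()`: the low-resolution neural-rendering output (`image_raw`) is resized to the super-resolved image's resolution with `filtered_resizing` and concatenated channel-wise with `image` (hence `img_channels *= 2` in `__init__`), and an optional BEV map can be stacked on top via `bev_channels`. The following is a minimal sketch of that assembly using plain bilinear interpolation, which is what the 'antialiased' branch of `filtered_resizing` amounts to; the batch size, resolutions, and the 14-channel BEV are illustrative assumptions:

import torch
import torch.nn.functional as F

img = {
    'image': torch.randn(4, 3, 256, 256),    # super-resolved generator output
    'image_raw': torch.randn(4, 3, 64, 64),  # raw neural-rendering output
}
bev = torch.randn(4, 14, 256, 256)           # optional BEV conditioning (14 channels as in app.py)

# 'antialiased' branch of filtered_resizing(): bilinear resize to the final resolution.
image_raw = F.interpolate(img['image_raw'], size=img['image'].shape[-1],
                          mode='bilinear', align_corners=False)
x = torch.cat([img['image'], image_raw], dim=1)  # 3 + 3 = 6 channels (img_channels *= 2)
x = torch.cat([x, bev], dim=1)                   # + bev_channels
print(x.shape)  # torch.Size([4, 20, 256, 256])
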
models/eg3d_generator.py
ADDED
@@ -0,0 +1,315 @@
+# python3.8
+"""Contains the implementation of generator described in EG3D."""
+
+import torch
+import torch.nn as nn
+from models.utils.official_stylegan2_model_helper import Generator as StyleGAN2Backbone
+from models.utils.official_stylegan2_model_helper import FullyConnectedLayer
+from models.utils.eg3d_superres import SuperresolutionHybrid2X
+from models.utils.eg3d_superres import SuperresolutionHybrid4X
+from models.utils.eg3d_superres import SuperresolutionHybrid8XDC
+from models.rendering.renderer import Renderer
+from models.rendering.feature_extractor import FeatureExtractor
+
+class EG3DGenerator(nn.Module):
+
+    def __init__(
+        self,
+        z_dim,                      # Input latent (Z) dimensionality.
+        c_dim,                      # Conditioning label (C) dimensionality.
+        w_dim,                      # Intermediate latent (W) dimensionality.
+        img_resolution,             # Output resolution.
+        img_channels,               # Number of output color channels.
+        sr_num_fp16_res=0,          # Number of fp16 layers of SR Network.
+        mapping_kwargs={},          # Arguments for MappingNetwork.
+        rendering_kwargs={},        # Arguments for rendering.
+        sr_kwargs={},               # Arguments for SuperResolution Network.
+        **synthesis_kwargs,         # Arguments for SynthesisNetwork.
+    ):
+        super().__init__()
+        self.z_dim = z_dim
+        self.c_dim = c_dim
+        self.w_dim = w_dim
+        self.img_resolution = img_resolution
+        self.img_channels = img_channels
+
+        # Set up the overall renderer.
+        self.renderer = Renderer()
+
+        # Set up the feature extractor.
+        self.feature_extractor = FeatureExtractor(ref_mode='tri_plane')
+
+        # Set up the reference representation generator.
+        self.backbone = StyleGAN2Backbone(z_dim,
+                                          c_dim,
+                                          w_dim,
+                                          img_resolution=256,
+                                          img_channels=32 * 3,
+                                          mapping_kwargs=mapping_kwargs,
+                                          **synthesis_kwargs)
+
+        # Set up the post module in the feature extractor.
+        self.post_module = None
+
+        # Set up the post neural renderer.
+        self.post_neural_renderer = None
+        sr_kwargs_total = dict(
+            channels=32,
+            img_resolution=img_resolution,
+            sr_num_fp16_res=sr_num_fp16_res,
+            sr_antialias=rendering_kwargs['sr_antialias'],)
+        sr_kwargs_total.update(**sr_kwargs)
+        if img_resolution == 128:
+            self.post_neural_renderer = SuperresolutionHybrid2X(
+                **sr_kwargs_total)
+        elif img_resolution == 256:
+            self.post_neural_renderer = SuperresolutionHybrid4X(
+                **sr_kwargs_total)
+        elif img_resolution == 512:
+            self.post_neural_renderer = SuperresolutionHybrid8XDC(
+                **sr_kwargs_total)
+        else:
+            raise TypeError(f'Unsupported image resolution: {img_resolution}!')
+
+        # Set up the fully-connected layer head.
+        self.fc_head = OSGDecoder(
+            32, {
+                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
+                'decoder_output_dim': 32
+            })
+
+        # Set up some rendering related arguments.
+        self.neural_rendering_resolution = rendering_kwargs.get(
+            'resolution', 64)
+        self.rendering_kwargs = rendering_kwargs
+
+    def mapping(self,
+                z,
+                c,
+                truncation_psi=1,
+                truncation_cutoff=None,
+                update_emas=False):
+        if self.rendering_kwargs['c_gen_conditioning_zero']:
+            c = torch.zeros_like(c)
+        return self.backbone.mapping(z,
+                                     c *
+                                     self.rendering_kwargs.get('c_scale', 0),
+                                     truncation_psi=truncation_psi,
+                                     truncation_cutoff=truncation_cutoff,
+                                     update_emas=update_emas)
+
+    def synthesis(self,
+                  wp,
+                  c,
+                  neural_rendering_resolution=None,
+                  update_emas=False,
+                  **synthesis_kwargs):
+        cam2world_matrix = c[:, :16].view(-1, 4, 4)
+        if self.rendering_kwargs.get('random_pose', False):
+            cam2world_matrix = None
+
+        if neural_rendering_resolution is None:
+            neural_rendering_resolution = self.neural_rendering_resolution
+        else:
+            self.neural_rendering_resolution = neural_rendering_resolution
+
+        tri_planes = self.backbone.synthesis(wp,
+                                             update_emas=update_emas,
+                                             **synthesis_kwargs)
+        tri_planes = tri_planes.view(len(tri_planes), 3, -1,
+                                     tri_planes.shape[-2],
+                                     tri_planes.shape[-1])
+
+        rendering_result = self.renderer(
+            wp=wp,
+            feature_extractor=self.feature_extractor,
+            rendering_options=self.rendering_kwargs,
+            cam2world_matrix=cam2world_matrix,
+            position_encoder=None,
+            ref_representation=tri_planes,
+            post_module=self.post_module,
+            fc_head=self.fc_head)
+
+        feature_samples = rendering_result['composite_rgb']
+        depth_samples = rendering_result['composite_depth']
+
+        # Reshape to keep consistent with 'raw' neural-rendered image.
+        N = wp.shape[0]
+        H = W = self.neural_rendering_resolution
+        feature_image = feature_samples.permute(0, 2, 1).reshape(
+            N, feature_samples.shape[-1], H, W).contiguous()
+        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
+
+        # Run the post neural renderer to get final image.
+        # Here, the post neural renderer is a super-resolution network.
+        rgb_image = feature_image[:, :3]
+        sr_image = self.post_neural_renderer(
+            rgb_image,
+            feature_image,
+            wp,
+            noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
+            **{
+                k: synthesis_kwargs[k]
+                for k in synthesis_kwargs.keys() if k != 'noise_mode'
+            })
+
+        return {
+            'image': sr_image,
+            'image_raw': rgb_image,
+            'image_depth': depth_image
+        }
+
+    def sample(self,
+               coordinates,
+               directions,
+               z,
+               c,
+               truncation_psi=1,
+               truncation_cutoff=None,
+               update_emas=False,
+               **synthesis_kwargs):
+        # Compute RGB features, density for arbitrary 3D coordinates.
+        # Mostly used for extracting shapes.
+        wp = self.mapping(z,
+                          c,
+                          truncation_psi=truncation_psi,
+                          truncation_cutoff=truncation_cutoff,
+                          update_emas=update_emas)
+        tri_planes = self.backbone.synthesis(wp,
+                                             update_emas=update_emas,
+                                             **synthesis_kwargs)
+        tri_planes = tri_planes.view(len(tri_planes), 3, -1,
+                                     tri_planes.shape[-2],
+                                     tri_planes.shape[-1])
+        result = self.renderer.get_sigma_rgb(
+            wp=wp,
+            points=coordinates,
+            feature_extractor=self.feature_extractor,
+            fc_head=self.fc_head,
+            rendering_options=self.rendering_kwargs,
+            ref_representation=tri_planes,
+            post_module=self.post_module,
+            ray_dirs=directions)
+
+        return result
+
+    def sample_mixed(self,
+                     coordinates,
+                     directions,
+                     wp,
+                     truncation_psi=1,
+                     truncation_cutoff=None,
+                     update_emas=False,
+                     **synthesis_kwargs):
+        # Same as function `self.sample()`, but expects latent vectors 'wp'
+        # instead of Gaussian noise 'z'.
+        tri_planes = self.backbone.synthesis(wp,
+                                             update_emas=update_emas,
+                                             **synthesis_kwargs)
+        tri_planes = tri_planes.view(len(tri_planes), 3, -1,
+                                     tri_planes.shape[-2],
+                                     tri_planes.shape[-1])
+
+        result = self.renderer.get_sigma_rgb(
+            wp=wp,
+            points=coordinates,
+            feature_extractor=self.feature_extractor,
+            fc_head=self.fc_head,
+            rendering_options=self.rendering_kwargs,
+            ref_representation=tri_planes,
|
| 220 |
+
post_module=self.post_module,
|
| 221 |
+
ray_dirs=directions)
|
| 222 |
+
|
| 223 |
+
return result
|
| 224 |
+
|
| 225 |
+
def forward(self,
|
| 226 |
+
z,
|
| 227 |
+
c,
|
| 228 |
+
c_swapped=None, # `c_swapped` is swapped pose conditioning.
|
| 229 |
+
style_mixing_prob=0,
|
| 230 |
+
truncation_psi=1,
|
| 231 |
+
truncation_cutoff=None,
|
| 232 |
+
neural_rendering_resolution=None,
|
| 233 |
+
update_emas=False,
|
| 234 |
+
sample_mixed=False,
|
| 235 |
+
coordinates=None,
|
| 236 |
+
**synthesis_kwargs):
|
| 237 |
+
|
| 238 |
+
# Render a batch of generated images.
|
| 239 |
+
c_wp = c.clone()
|
| 240 |
+
if c_swapped is not None:
|
| 241 |
+
c_wp = c_swapped.clone()
|
| 242 |
+
wp = self.mapping(z,
|
| 243 |
+
c_wp,
|
| 244 |
+
truncation_psi=truncation_psi,
|
| 245 |
+
truncation_cutoff=truncation_cutoff,
|
| 246 |
+
update_emas=update_emas)
|
| 247 |
+
if style_mixing_prob > 0:
|
| 248 |
+
cutoff = torch.empty([], dtype=torch.int64,
|
| 249 |
+
device=wp.device).random_(1, wp.shape[1])
|
| 250 |
+
cutoff = torch.where(
|
| 251 |
+
torch.rand([], device=wp.device) < style_mixing_prob,
|
| 252 |
+
cutoff, torch.full_like(cutoff, wp.shape[1]))
|
| 253 |
+
wp[:, cutoff:] = self.mapping(torch.randn_like(z),
|
| 254 |
+
c,
|
| 255 |
+
update_emas=update_emas)[:, cutoff:]
|
| 256 |
+
if not sample_mixed:
|
| 257 |
+
gen_output = self.synthesis(
|
| 258 |
+
wp,
|
| 259 |
+
c,
|
| 260 |
+
update_emas=update_emas,
|
| 261 |
+
neural_rendering_resolution=neural_rendering_resolution,
|
| 262 |
+
**synthesis_kwargs)
|
| 263 |
+
|
| 264 |
+
return {
|
| 265 |
+
'wp': wp,
|
| 266 |
+
'gen_output': gen_output,
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
else:
|
| 270 |
+
# Only for density regularization in training process.
|
| 271 |
+
assert coordinates is not None
|
| 272 |
+
sample_sigma = self.sample_mixed(coordinates,
|
| 273 |
+
torch.randn_like(coordinates),
|
| 274 |
+
wp,
|
| 275 |
+
update_emas=False)['sigma']
|
| 276 |
+
|
| 277 |
+
return {
|
| 278 |
+
'wp': wp,
|
| 279 |
+
'sample_sigma': sample_sigma
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class OSGDecoder(nn.Module):
|
| 284 |
+
"""Defines fully-connected layer head in EG3D."""
|
| 285 |
+
def __init__(self, n_features, options):
|
| 286 |
+
super().__init__()
|
| 287 |
+
self.hidden_dim = 64
|
| 288 |
+
|
| 289 |
+
self.net = nn.Sequential(
|
| 290 |
+
FullyConnectedLayer(n_features,
|
| 291 |
+
self.hidden_dim,
|
| 292 |
+
lr_multiplier=options['decoder_lr_mul']),
|
| 293 |
+
nn.Softplus(),
|
| 294 |
+
FullyConnectedLayer(self.hidden_dim,
|
| 295 |
+
1 + options['decoder_output_dim'],
|
| 296 |
+
lr_multiplier=options['decoder_lr_mul']))
|
| 297 |
+
|
| 298 |
+
def forward(self, point_features, wp=None, dirs=None):
|
| 299 |
+
# Aggregate features
|
| 300 |
+
# point_features.shape: [N, 3, M, C].
|
| 301 |
+
# Average across 'X, Y, Z' planes.
|
| 302 |
+
point_features = point_features.mean(1)
|
| 303 |
+
x = point_features
|
| 304 |
+
|
| 305 |
+
N, M, C = x.shape
|
| 306 |
+
x = x.view(N * M, C)
|
| 307 |
+
|
| 308 |
+
x = self.net(x)
|
| 309 |
+
x = x.view(N, M, -1)
|
| 310 |
+
|
| 311 |
+
# Uses sigmoid clamping from MipNeRF
|
| 312 |
+
rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001
|
| 313 |
+
sigma = x[..., 0:1]
|
| 314 |
+
|
| 315 |
+
return {'rgb': rgb, 'sigma': sigma}
|
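Note: the following snippet is not part of the repository. It is a minimal standalone sketch of the tensor bookkeeping in `OSGDecoder.forward()` above, with plain `nn.Linear` layers standing in for the repository's `FullyConnectedLayer` (an assumption made only so the sketch runs on its own); the shapes and the MipNeRF-style sigmoid clamping follow the code as written.

import torch
import torch.nn as nn

N, M, C = 2, 4096, 32                       # batch, sampled points, tri-plane feature dim
point_features = torch.randn(N, 3, M, C)    # one feature vector per (plane, point)

x = point_features.mean(1)                  # average across the X/Y/Z planes -> [N, M, C]
x = x.view(N * M, C)

# Stand-in for FullyConnectedLayer -> Softplus -> FullyConnectedLayer.
net = nn.Sequential(nn.Linear(C, 64), nn.Softplus(), nn.Linear(64, 1 + 32))
x = net(x).view(N, M, -1)

rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001   # clamp RGB features into (-0.001, 1.001)
sigma = x[..., 0:1]                                         # raw density, left unclamped
print(rgb.shape, sigma.shape)               # [2, 4096, 32], [2, 4096, 1]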
models/eg3d_generator_fv.py
ADDED
@@ -0,0 +1,320 @@
# python3.8
"""Contains the implementation of generator described in EG3D."""

import torch
import torch.nn as nn
import numpy as np
from models.utils.official_stylegan2_model_helper import MappingNetwork
from models.utils.official_stylegan2_model_helper import FullyConnectedLayer
from models.utils.eg3d_superres import SuperresolutionHybrid2X
from models.utils.eg3d_superres import SuperresolutionHybrid4X
from models.utils.eg3d_superres import SuperresolutionHybrid8XDC
from models.rendering.renderer import Renderer
from models.rendering.feature_extractor import FeatureExtractor
from models.volumegan_generator import FeatureVolume
from models.volumegan_generator import PositionEncoder


class EG3DGeneratorFV(nn.Module):

    def __init__(
        self,
        # Input latent (Z) dimensionality.
        z_dim,
        # Conditioning label (C) dimensionality.
        c_dim,
        # Intermediate latent (W) dimensionality.
        w_dim,
        # Final output image resolution.
        img_resolution,
        # Number of output color channels.
        img_channels,
        # Number of fp16 layers of SR Network.
        sr_num_fp16_res=0,
        # Arguments for MappingNetwork.
        mapping_kwargs={},
        # Arguments for rendering.
        rendering_kwargs={},
        # Arguments for SuperResolution Network.
        sr_kwargs={},
        # Configs for FeatureVolume.
        fv_cfg=dict(feat_res=32,
                    init_res=4,
                    base_channels=256,
                    output_channels=32,
                    w_dim=512),
        # Configs for position encoder.
        embed_cfg=dict(input_dim=3, max_freq_log2=10 - 1, N_freqs=10),
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        # Set up mapping network.
        # Here `num_ws = 2`: one for FeatureVolume Network injection and one for
        # post_neural_renderer injection.
        num_ws = 2
        self.mapping_network = MappingNetwork(z_dim=z_dim,
                                              c_dim=c_dim,
                                              w_dim=w_dim,
                                              num_ws=num_ws,
                                              **mapping_kwargs)

        # Set up the overall renderer.
        self.renderer = Renderer()

        # Set up the feature extractor.
        self.feature_extractor = FeatureExtractor(ref_mode='feature_volume')

        # Set up the reference representation generator.
        self.ref_representation_generator = FeatureVolume(**fv_cfg)

        # Set up the position encoder.
        self.position_encoder = PositionEncoder(**embed_cfg)

        # Set up the post module in the feature extractor.
        self.post_module = None

        # Set up the post neural renderer.
        self.post_neural_renderer = None
        sr_kwargs_total = dict(
            channels=32,
            img_resolution=img_resolution,
            sr_num_fp16_res=sr_num_fp16_res,
            sr_antialias=rendering_kwargs['sr_antialias'],)
        sr_kwargs_total.update(**sr_kwargs)
        if img_resolution == 128:
            self.post_neural_renderer = SuperresolutionHybrid2X(
                **sr_kwargs_total)
        elif img_resolution == 256:
            self.post_neural_renderer = SuperresolutionHybrid4X(
                **sr_kwargs_total)
        elif img_resolution == 512:
            self.post_neural_renderer = SuperresolutionHybrid8XDC(
                **sr_kwargs_total)
        else:
            raise TypeError(f'Unsupported image resolution: {img_resolution}!')

        # Set up the fully-connected layer head.
        self.fc_head = OSGDecoder(
            32, {
                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': 32
            })

        # Set up some rendering related arguments.
        self.neural_rendering_resolution = rendering_kwargs.get(
            'resolution', 64)
        self.rendering_kwargs = rendering_kwargs

    def mapping(self,
                z,
                c,
                truncation_psi=1,
                truncation_cutoff=None,
                update_emas=False):
        if self.rendering_kwargs['c_gen_conditioning_zero']:
            c = torch.zeros_like(c)
        return self.mapping_network(z,
                                    c *
                                    self.rendering_kwargs.get('c_scale', 0),
                                    truncation_psi=truncation_psi,
                                    truncation_cutoff=truncation_cutoff,
                                    update_emas=update_emas)

    def synthesis(self,
                  wp,
                  c,
                  neural_rendering_resolution=None,
                  update_emas=False,
                  **synthesis_kwargs):
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        if self.rendering_kwargs.get('random_pose', False):
            cam2world_matrix = None

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        feature_volume = self.ref_representation_generator(wp)

        rendering_result = self.renderer(
            wp=wp,
            feature_extractor=self.feature_extractor,
            rendering_options=self.rendering_kwargs,
            cam2world_matrix=cam2world_matrix,
            position_encoder=self.position_encoder,
            ref_representation=feature_volume,
            post_module=self.post_module,
            fc_head=self.fc_head)

        feature_samples = rendering_result['composite_rgb']
        depth_samples = rendering_result['composite_depth']

        # Reshape to keep consistent with 'raw' neural-rendered image.
        N = wp.shape[0]
        H = W = self.neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        # Run the post neural renderer to get final image.
        # Here, the post neural renderer is a super-resolution network.
        rgb_image = feature_image[:, :3]
        sr_image = self.post_neural_renderer(
            rgb_image,
            feature_image,
            wp,
            noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
            **{
                k: synthesis_kwargs[k]
                for k in synthesis_kwargs.keys() if k != 'noise_mode'
            })

        return {
            'image': sr_image,
            'image_raw': rgb_image,
            'image_depth': depth_image
        }

    def sample(self,
               coordinates,
               directions,
               z,
               c,
               truncation_psi=1,
               truncation_cutoff=None,
               update_emas=False):
        # Compute RGB features, density for arbitrary 3D coordinates.
        # Mostly used for extracting shapes.
        wp = self.mapping_network(z,
                                  c,
                                  truncation_psi=truncation_psi,
                                  truncation_cutoff=truncation_cutoff,
                                  update_emas=update_emas)
        feature_volume = self.ref_representation_generator(wp)
        result = self.renderer.get_sigma_rgb(
            wp=wp,
            points=coordinates,
            feature_extractor=self.feature_extractor,
            fc_head=self.fc_head,
            rendering_options=self.rendering_kwargs,
            ref_representation=feature_volume,
            position_encoder=self.position_encoder,
            post_module=self.post_module,
            ray_dirs=directions)

        return result

    def sample_mixed(self,
                     coordinates,
                     directions,
                     wp):
        # Same as function `self.sample()`, but expects latent vectors 'wp'
        # instead of Gaussian noise 'z'.
        feature_volume = self.ref_representation_generator(wp)
        result = self.renderer.get_sigma_rgb(
            wp=wp,
            points=coordinates,
            feature_extractor=self.feature_extractor,
            fc_head=self.fc_head,
            rendering_options=self.rendering_kwargs,
            ref_representation=feature_volume,
            position_encoder=self.position_encoder,
            post_module=self.post_module,
            ray_dirs=directions)

        return result

    def forward(self,
                z,
                c,
                c_swapped=None,  # `c_swapped` is swapped pose conditioning.
                style_mixing_prob=0,
                truncation_psi=1,
                truncation_cutoff=None,
                neural_rendering_resolution=None,
                update_emas=False,
                sample_mixed=False,
                coordinates=None,
                **synthesis_kwargs):

        # Render a batch of generated images.
        c_wp = c.clone()
        if c_swapped is not None:
            c_wp = c_swapped.clone()
        wp = self.mapping_network(z,
                                  c_wp,
                                  truncation_psi=truncation_psi,
                                  truncation_cutoff=truncation_cutoff,
                                  update_emas=update_emas)
        if style_mixing_prob > 0:
            cutoff = torch.empty([], dtype=torch.int64,
                                 device=wp.device).random_(1, wp.shape[1])
            cutoff = torch.where(
                torch.rand([], device=wp.device) < style_mixing_prob, cutoff,
                torch.full_like(cutoff, wp.shape[1]))
            wp[:, cutoff:] = self.mapping_network(
                torch.randn_like(z), c, update_emas=update_emas)[:, cutoff:]
        if not sample_mixed:
            gen_output = self.synthesis(
                wp,
                c,
                update_emas=update_emas,
                neural_rendering_resolution=neural_rendering_resolution,
                **synthesis_kwargs)

            return {
                'wp': wp,
                'gen_output': gen_output,
            }

        else:
            # Only for density regularization in training process.
            assert coordinates is not None
            sample_sigma = self.sample_mixed(coordinates,
                                             torch.randn_like(coordinates),
                                             wp)['sigma']

            return {
                'wp': wp,
                'sample_sigma': sample_sigma
            }


class OSGDecoder(nn.Module):
    """Defines fully-connected layer head in EG3D."""
    def __init__(self, n_features, options):
        super().__init__()
        self.hidden_dim = 64

        self.net = nn.Sequential(
            FullyConnectedLayer(n_features,
                                self.hidden_dim,
                                lr_multiplier=options['decoder_lr_mul']),
            nn.Softplus(),
            FullyConnectedLayer(self.hidden_dim,
                                1 + options['decoder_output_dim'],
                                lr_multiplier=options['decoder_lr_mul']))

    def forward(self, point_features, wp=None, dirs=None):
        # point_features.shape: [N, C, M, 1].
        point_features = point_features.squeeze(-1)
        point_features = point_features.permute(0, 2, 1)
        x = point_features

        N, M, C = x.shape
        x = x.reshape(N * M, C)

        x = self.net(x)
        x = x.reshape(N, M, -1)

        # Uses sigmoid clamping from MipNeRF
        rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001
        sigma = x[..., 0:1]

        return {'rgb': rgb, 'sigma': sigma}
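Note: the following snippet is not part of the repository; it is an illustrative, self-contained sketch of the style-mixing cutoff used in `forward()` above, applied to dummy latent codes. A second random tensor stands in for the extra mapping pass (`self.mapping_network(torch.randn_like(z), c, ...)`), which is an assumption made only to keep the sketch runnable.

import torch

style_mixing_prob = 0.9
wp = torch.randn(4, 2, 512)        # [batch, num_ws, w_dim]; num_ws = 2 in EG3DGeneratorFV
wp_new = torch.randn_like(wp)      # stand-in for a second mapping pass with fresh noise

cutoff = torch.empty([], dtype=torch.int64).random_(1, wp.shape[1])
cutoff = torch.where(torch.rand([]) < style_mixing_prob,
                     cutoff, torch.full_like(cutoff, wp.shape[1]))
wp[:, cutoff:] = wp_new[:, cutoff:]   # layers at index >= cutoff take styles from the second code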
models/ghfeat_encoder.py
ADDED
@@ -0,0 +1,563 @@
# python3.7
"""Contains the implementation of encoder used in GH-Feat (including IDInvert).

ResNet is used as the backbone.

GH-Feat paper: https://arxiv.org/pdf/2007.10379.pdf
IDInvert paper: https://arxiv.org/pdf/2004.00049.pdf

NOTE: Please use `latent_dim` and `num_latents_per_head` to control the
inversion space, such as Y-space used in GH-Feat and W-space used in IDInvert.
In addition, IDInvert sets `use_fpn` and `use_sam` as `False` by default.
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

__all__ = ['GHFeatEncoder']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# pylint: disable=missing-function-docstring

class BasicBlock(nn.Module):
    """Implementation of ResNet BasicBlock."""

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 base_width=64,
                 stride=1,
                 groups=1,
                 dilation=1,
                 norm_layer=None,
                 downsample=None):
        super().__init__()
        if base_width != 64:
            raise ValueError(f'BasicBlock of ResNet only supports '
                             f'`base_width=64`, but {base_width} received!')
        if stride not in [1, 2]:
            raise ValueError(f'BasicBlock of ResNet only supports `stride=1` '
                             f'and `stride=2`, but {stride} received!')
        if groups != 1:
            raise ValueError(f'BasicBlock of ResNet only supports `groups=1`, '
                             f'but {groups} received!')
        if dilation != 1:
            raise ValueError(f'BasicBlock of ResNet only supports '
                             f'`dilation=1`, but {dilation} received!')
        assert self.expansion == 1

        self.stride = stride
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.conv1 = nn.Conv2d(in_channels=inplanes,
                               out_channels=planes,
                               kernel_size=3,
                               stride=stride,
                               padding=1,
                               groups=1,
                               dilation=1,
                               bias=False)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=planes,
                               out_channels=planes,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               groups=1,
                               dilation=1,
                               bias=False)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = self.downsample(x) if self.downsample is not None else x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out + identity)

        return out


class Bottleneck(nn.Module):
    """Implementation of ResNet Bottleneck."""

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 base_width=64,
                 stride=1,
                 groups=1,
                 dilation=1,
                 norm_layer=None,
                 downsample=None):
        super().__init__()
        if stride not in [1, 2]:
            raise ValueError(f'Bottleneck of ResNet only supports `stride=1` '
                             f'and `stride=2`, but {stride} received!')

        width = int(planes * (base_width / 64)) * groups
        self.stride = stride
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.conv1 = nn.Conv2d(in_channels=inplanes,
                               out_channels=width,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               dilation=1,
                               groups=1,
                               bias=False)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2d(in_channels=width,
                               out_channels=width,
                               kernel_size=3,
                               stride=stride,
                               padding=dilation,
                               groups=groups,
                               dilation=dilation,
                               bias=False)
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(in_channels=width,
                               out_channels=planes * self.expansion,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               dilation=1,
                               groups=1,
                               bias=False)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = self.downsample(x) if self.downsample is not None else x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.relu(out + identity)

        return out


class GHFeatEncoder(nn.Module):
    """Defines the ResNet-based encoder network for GAN inversion.

    On top of the backbone, there are several task-heads to produce inverted
    codes. Please use `latent_dim` and `num_latents_per_head` to define the
    structure. For example, `latent_dim = [512] * 14` and
    `num_latents_per_head = [4, 4, 6]` can be used for StyleGAN inversion with
    14-layer latent codes, where 3 task heads (corresponding to 4, 4, 6 layers,
    respectively) are used.

    Settings for the encoder network:

    (1) resolution: The resolution of the output image.
    (2) latent_dim: Dimension of the latent space. A number (one code will be
        produced), or a list of numbers regarding layer-wise latent codes.
    (3) num_latents_per_head: Number of latents that are produced by each head.
    (4) image_channels: Number of channels of the output image. (default: 3)
    (5) final_res: Final resolution of the convolutional layers. (default: 4)

    ResNet-related settings:

    (1) network_depth: Depth of the network, like 18 for ResNet18. (default: 18)
    (2) inplanes: Number of channels of the first convolutional layer.
        (default: 64)
    (3) groups: Groups of the convolution, used in ResNet. (default: 1)
    (4) width_per_group: Number of channels per group, used in ResNet.
        (default: 64)
    (5) replace_stride_with_dilation: Whether to replace stride with dilation,
        used in ResNet. (default: None)
    (6) norm_layer: Normalization layer used in the encoder. If set as `None`,
        `nn.BatchNorm2d` will be used. Also, please NOTE that when using batch
        normalization, the batch size is required to be larger than one for
        training. (default: nn.BatchNorm2d)
    (7) max_channels: Maximum number of channels in each layer. (default: 512)

    Task-head related settings:

    (1) use_fpn: Whether to use Feature Pyramid Network (FPN) before outputting
        the latent code. (default: True)
    (2) fpn_channels: Number of channels used in FPN. (default: 512)
    (3) use_sam: Whether to use Spatial Alignment Module (SAM) before outputting
        the latent code. (default: True)
    (4) sam_channels: Number of channels used in SAM. (default: 512)
    """

    arch_settings = {
        18: (BasicBlock, [2, 2, 2, 2]),
        34: (BasicBlock, [3, 4, 6, 3]),
        50: (Bottleneck, [3, 4, 6, 3]),
        101: (Bottleneck, [3, 4, 23, 3]),
        152: (Bottleneck, [3, 8, 36, 3])
    }

    def __init__(self,
                 resolution,
                 latent_dim,
                 num_latents_per_head,
                 image_channels=3,
                 final_res=4,
                 network_depth=18,
                 inplanes=64,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 norm_layer=nn.BatchNorm2d,
                 max_channels=512,
                 use_fpn=True,
                 fpn_channels=512,
                 use_sam=True,
                 sam_channels=512):
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
        if network_depth not in self.arch_settings:
            raise ValueError(f'Invalid network depth: `{network_depth}`!\n'
                             f'Options allowed: '
                             f'{list(self.arch_settings.keys())}.')
        if isinstance(latent_dim, int):
            latent_dim = [latent_dim]
        assert isinstance(latent_dim, (list, tuple))
        assert isinstance(num_latents_per_head, (list, tuple))
        assert sum(num_latents_per_head) == len(latent_dim)

        self.resolution = resolution
        self.latent_dim = latent_dim
        self.num_latents_per_head = num_latents_per_head
        self.num_heads = len(self.num_latents_per_head)
        self.image_channels = image_channels
        self.final_res = final_res
        self.inplanes = inplanes
        self.network_depth = network_depth
        self.groups = groups
        self.dilation = 1
        self.base_width = width_per_group
        self.replace_stride_with_dilation = replace_stride_with_dilation
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if norm_layer == nn.BatchNorm2d and dist.is_initialized():
            norm_layer = nn.SyncBatchNorm
        self.norm_layer = norm_layer
        self.max_channels = max_channels
        self.use_fpn = use_fpn
        self.fpn_channels = fpn_channels
        self.use_sam = use_sam
        self.sam_channels = sam_channels

        block_fn, num_blocks_per_stage = self.arch_settings[network_depth]

        self.num_stages = int(np.log2(resolution // final_res)) - 1
        # Add one block for additional stages.
        for i in range(len(num_blocks_per_stage), self.num_stages):
            num_blocks_per_stage.append(1)
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False] * self.num_stages

        # Backbone.
        self.conv1 = nn.Conv2d(in_channels=self.image_channels,
                               out_channels=self.inplanes,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.stage_channels = [self.inplanes]
        self.stages = nn.ModuleList()
        for i in range(self.num_stages):
            inplanes = self.inplanes if i == 0 else planes * block_fn.expansion
            planes = min(self.max_channels, self.inplanes * (2 ** i))
            num_blocks = num_blocks_per_stage[i]
            stride = 1 if i == 0 else 2
            dilate = replace_stride_with_dilation[i]
            self.stages.append(self._make_stage(block_fn=block_fn,
                                                inplanes=inplanes,
                                                planes=planes,
                                                num_blocks=num_blocks,
                                                stride=stride,
                                                dilate=dilate))
            self.stage_channels.append(planes * block_fn.expansion)

        if self.num_heads > len(self.stage_channels):
            raise ValueError('Number of task heads is larger than number of '
                             'stages! Please reduce the number of heads.')

        # Task-head.
        if self.num_heads == 1:
            self.use_fpn = False
            self.use_sam = False

        if self.use_fpn:
            fpn_pyramid_channels = self.stage_channels[-self.num_heads:]
            self.fpn = FPN(pyramid_channels=fpn_pyramid_channels,
                           out_channels=self.fpn_channels)
        if self.use_sam:
            if self.use_fpn:
                sam_pyramid_channels = [self.fpn_channels] * self.num_heads
            else:
                sam_pyramid_channels = self.stage_channels[-self.num_heads:]
            self.sam = SAM(pyramid_channels=sam_pyramid_channels,
                           out_channels=self.sam_channels)

        self.heads = nn.ModuleList()
        for head_idx in range(self.num_heads):
            # Parse in_channels.
            if self.use_sam:
                in_channels = self.sam_channels
            elif self.use_fpn:
                in_channels = self.fpn_channels
            else:
                in_channels = self.stage_channels[head_idx - self.num_heads]
            in_channels = in_channels * final_res * final_res

            # Parse out_channels.
            start_latent_idx = sum(self.num_latents_per_head[:head_idx])
            end_latent_idx = sum(self.num_latents_per_head[:head_idx + 1])
            out_channels = sum(self.latent_dim[start_latent_idx:end_latent_idx])

            self.heads.append(CodeHead(in_channels=in_channels,
                                       out_channels=out_channels,
                                       norm_layer=self.norm_layer))

    def _make_stage(self,
                    block_fn,
                    inplanes,
                    planes,
                    num_blocks,
                    stride,
                    dilate):
        norm_layer = self.norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or inplanes != planes * block_fn.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(in_channels=inplanes,
                          out_channels=planes * block_fn.expansion,
                          kernel_size=1,
                          stride=stride,
                          padding=0,
                          dilation=1,
                          groups=1,
                          bias=False),
                norm_layer(planes * block_fn.expansion),
            )

        blocks = []
        blocks.append(block_fn(inplanes=inplanes,
                               planes=planes,
                               base_width=self.base_width,
                               stride=stride,
                               groups=self.groups,
                               dilation=previous_dilation,
                               norm_layer=norm_layer,
                               downsample=downsample))
        for _ in range(1, num_blocks):
            blocks.append(block_fn(inplanes=planes * block_fn.expansion,
                                   planes=planes,
                                   base_width=self.base_width,
                                   stride=1,
                                   groups=self.groups,
                                   dilation=self.dilation,
                                   norm_layer=norm_layer,
                                   downsample=None))

        return nn.Sequential(*blocks)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        features = [x]
        for i in range(self.num_stages):
            x = self.stages[i](x)
            features.append(x)
        features = features[-self.num_heads:]

        if self.use_fpn:
            features = self.fpn(features)
        if self.use_sam:
            features = self.sam(features)
        else:
            final_size = features[-1].shape[2:]
            for i in range(self.num_heads - 1):
                features[i] = F.adaptive_avg_pool2d(features[i], final_size)

        outputs = []
        for head_idx in range(self.num_heads):
            codes = self.heads[head_idx](features[head_idx])
            start_latent_idx = sum(self.num_latents_per_head[:head_idx])
            end_latent_idx = sum(self.num_latents_per_head[:head_idx + 1])
            split_size = self.latent_dim[start_latent_idx:end_latent_idx]
            outputs.extend(torch.split(codes, split_size, dim=1))
        max_dim = max(self.latent_dim)
        for i, dim in enumerate(self.latent_dim):
            if dim < max_dim:
                outputs[i] = F.pad(outputs[i], (0, max_dim - dim))
            outputs[i] = outputs[i].unsqueeze(1)

        return torch.cat(outputs, dim=1)


class FPN(nn.Module):
    """Implementation of Feature Pyramid Network (FPN).

    The input of this module is a pyramid of features with reducing resolutions.
    Then, this module fuses these multi-level features from `top_level` to
    `bottom_level`. In particular, starting from the `top_level`, each feature
    is convoluted, upsampled, and fused into its previous feature (which is also
    convoluted).

    Args:
        pyramid_channels: A list of integers, each of which indicates the number
            of channels of the feature from a particular level.
        out_channels: Number of channels for each output.

    Returns:
        A list of feature maps, each of which has `out_channels` channels.
    """

    def __init__(self, pyramid_channels, out_channels):
        super().__init__()
        assert isinstance(pyramid_channels, (list, tuple))
        self.num_levels = len(pyramid_channels)

        self.lateral_layers = nn.ModuleList()
        self.feature_layers = nn.ModuleList()
        for i in range(self.num_levels):
            in_channels = pyramid_channels[i]
            self.lateral_layers.append(nn.Conv2d(in_channels=in_channels,
                                                 out_channels=out_channels,
                                                 kernel_size=3,
                                                 padding=1,
                                                 bias=True))
            self.feature_layers.append(nn.Conv2d(in_channels=out_channels,
                                                 out_channels=out_channels,
                                                 kernel_size=3,
                                                 padding=1,
                                                 bias=True))

    def forward(self, inputs):
        if len(inputs) != self.num_levels:
            raise ValueError('Number of inputs and `num_levels` mismatch!')

        # Project all related features to `out_channels`.
        laterals = []
        for i in range(self.num_levels):
            laterals.append(self.lateral_layers[i](inputs[i]))

        # Fusion, starting from `top_level`.
        for i in range(self.num_levels - 1, 0, -1):
            scale_factor = laterals[i - 1].shape[2] // laterals[i].shape[2]
            laterals[i - 1] = (laterals[i - 1] +
                               F.interpolate(laterals[i],
                                             mode='nearest',
                                             scale_factor=scale_factor))

        # Get outputs.
        outputs = []
        for i, lateral in enumerate(laterals):
            outputs.append(self.feature_layers[i](lateral))

        return outputs


class SAM(nn.Module):
    """Implementation of Spatial Alignment Module (SAM).

    The input of this module is a pyramid of features with reducing resolutions.
    Then this module downsamples all levels of feature to the minimum resolution
    and fuses it with the smallest feature map.

    Args:
        pyramid_channels: A list of integers, each of which indicates the number
            of channels of the feature from a particular level.
        out_channels: Number of channels for each output.

    Returns:
        A list of feature maps, each of which has `out_channels` channels.
    """

    def __init__(self, pyramid_channels, out_channels):
        super().__init__()
        assert isinstance(pyramid_channels, (list, tuple))
        self.num_levels = len(pyramid_channels)

        self.fusion_layers = nn.ModuleList()
        for i in range(self.num_levels):
            in_channels = pyramid_channels[i]
            self.fusion_layers.append(nn.Conv2d(in_channels=in_channels,
                                                out_channels=out_channels,
                                                kernel_size=3,
                                                padding=1,
                                                bias=True))

    def forward(self, inputs):
        if len(inputs) != self.num_levels:
            raise ValueError('Number of inputs and `num_levels` mismatch!')

        output_res = inputs[-1].shape[2:]
        for i in range(self.num_levels - 1, -1, -1):
            if i != self.num_levels - 1:
                inputs[i] = F.adaptive_avg_pool2d(inputs[i], output_res)
            inputs[i] = self.fusion_layers[i](inputs[i])
            if i != self.num_levels - 1:
                inputs[i] = inputs[i] + inputs[-1]

        return inputs


class CodeHead(nn.Module):
    """Implementation of the task-head to produce inverted codes."""

    def __init__(self, in_channels, out_channels, norm_layer):
        super().__init__()
        self.fc = nn.Linear(in_channels, out_channels, bias=True)
        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(out_channels)

    def forward(self, x):
        if x.ndim > 2:
            x = x.flatten(start_dim=1)
        latent = self.fc(x)
        latent = latent.unsqueeze(2).unsqueeze(3)
        latent = self.norm(latent)

        return latent.flatten(start_dim=1)

# pylint: enable=missing-function-docstring
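Note: the following snippet is not part of the repository; it is a small self-contained sketch of how `GHFeatEncoder.forward()` above splits each head's output according to `latent_dim` / `num_latents_per_head` and pads the layer-wise codes to a common width. Random tensors stand in for the real `CodeHead` outputs (an assumption for illustration only).

import torch
import torch.nn.functional as F

latent_dim = [512] * 14              # e.g. 14-layer StyleGAN latent codes
num_latents_per_head = [4, 4, 6]     # 3 task heads
batch = 2

outputs = []
for head_idx, num in enumerate(num_latents_per_head):
    start = sum(num_latents_per_head[:head_idx])
    split_size = latent_dim[start:start + num]
    codes = torch.randn(batch, sum(split_size))   # stand-in for self.heads[head_idx](...)
    outputs.extend(torch.split(codes, split_size, dim=1))

max_dim = max(latent_dim)
for i, dim in enumerate(latent_dim):
    if dim < max_dim:
        outputs[i] = F.pad(outputs[i], (0, max_dim - dim))
    outputs[i] = outputs[i].unsqueeze(1)

print(torch.cat(outputs, dim=1).shape)   # torch.Size([2, 14, 512])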
models/inception_model.py
ADDED
@@ -0,0 +1,562 @@
| 1 |
+
# python3.7
|
| 2 |
+
"""Contains the Inception V3 model, which is used for inference ONLY.
|
| 3 |
+
|
| 4 |
+
This file is mostly borrowed from `torchvision/models/inception.py`.
|
| 5 |
+
|
| 6 |
+
Inception model is widely used to compute FID or IS metric for evaluating
|
| 7 |
+
generative models. However, the pre-trained models from torchvision is slightly
|
| 8 |
+
different from the TensorFlow version
|
| 9 |
+
|
| 10 |
+
http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
| 11 |
+
|
| 12 |
+
which is used by the official FID implementation
|
| 13 |
+
|
| 14 |
+
https://github.com/bioinf-jku/TTUR
|
| 15 |
+
|
| 16 |
+
In particular:
|
| 17 |
+
|
| 18 |
+
(1) The number of classes in TensorFlow model is 1008 instead of 1000.
|
| 19 |
+
(2) The avg_pool() layers in TensorFlow model does not include the padded zero.
|
| 20 |
+
(3) The last Inception E Block in TensorFlow model use max_pool() instead of
|
| 21 |
+
avg_pool().
|
| 22 |
+
|
| 23 |
+
Hence, to align the evaluation results with those from TensorFlow
|
| 24 |
+
implementation, we modified the inception model to support both versions. Please
|
| 25 |
+
use `align_tf` argument to control the version.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import warnings
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn as nn
|
| 32 |
+
import torch.nn.functional as F
|
| 33 |
+
import torch.distributed as dist
|
| 34 |
+
|
| 35 |
+
from utils.misc import download_url
|
| 36 |
+
|
| 37 |
+
__all__ = ['InceptionModel']
|
| 38 |
+
|
| 39 |
+
# pylint: disable=line-too-long
|
| 40 |
+
|
| 41 |
+
_MODEL_URL_SHA256 = {
|
| 42 |
+
# This model is provided by `torchvision`, which is ported from TensorFlow.
|
| 43 |
+
'torchvision_official': (
|
| 44 |
+
'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
|
| 45 |
+
'1a9a5a14f40645a370184bd54f4e8e631351e71399112b43ad0294a79da290c8' # hash sha256
|
| 46 |
+
),
|
| 47 |
+
|
| 48 |
+
# This model is provided by https://github.com/mseitzer/pytorch-fid
|
| 49 |
+
'tf_inception_v3': (
|
| 50 |
+
'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth',
|
| 51 |
+
'6726825d0af5f729cebd5821db510b11b1cfad8faad88a03f1befd49fb9129b2' # hash sha256
|
| 52 |
+
)
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class InceptionModel(object):
|
| 57 |
+
"""Defines the Inception (V3) model.
|
| 58 |
+
|
| 59 |
+
This is a static class, which is used to avoid this model to be built
|
| 60 |
+
repeatedly. Consequently, this model is particularly used for inference,
|
| 61 |
+
like computing FID. If training is required, please use the model from
|
| 62 |
+
`torchvision.models` or implement by yourself.
|
| 63 |
+
|
| 64 |
+
NOTE: The pre-trained model assumes the inputs to be with `RGB` channel
|
| 65 |
+
order and pixel range [-1, 1], and will also resize the images to shape
|
| 66 |
+
[299, 299] automatically. If your input is normalized by subtracting
|
| 67 |
+
(0.485, 0.456, 0.406) and dividing (0.229, 0.224, 0.225), please use
|
| 68 |
+
`transform_input` in the `forward()` function to un-normalize it.
|
| 69 |
+
"""
|
| 70 |
+
models = dict()
|
| 71 |
+
|
| 72 |
+
@staticmethod
|
| 73 |
+
def build_model(align_tf=True):
|
| 74 |
+
"""Builds the model and load pre-trained weights.
|
| 75 |
+
|
| 76 |
+
If `align_tf` is set as True, the model will predict 1008 classes, and
|
| 77 |
+
the pre-trained weight from `https://github.com/mseitzer/pytorch-fid`
|
| 78 |
+
will be loaded. Otherwise, the model will predict 1000 classes, and will
|
| 79 |
+
load the model from `torchvision`.
|
| 80 |
+
|
| 81 |
+
The built model supports following arguments when forwarding:
|
| 82 |
+
|
| 83 |
+
- transform_input: Whether to transform the input back to pixel range
|
| 84 |
+
(-1, 1). Please disable this argument if your input is already with
|
| 85 |
+
pixel range (-1, 1). (default: False)
|
| 86 |
+
- output_logits: Whether to output the categorical logits instead of
|
| 87 |
+
features. (default: False)
|
| 88 |
+
- remove_logits_bias: Whether to remove the bias when computing the
|
| 89 |
+
logits. The official implementation removes the bias by default.
|
| 90 |
+
Please refer to
|
| 91 |
+
`https://github.com/openai/improved-gan/blob/master/inception_score/model.py`.
|
| 92 |
+
(default: False)
|
| 93 |
+
- output_predictions: Whether to output the final predictions, i.e.,
|
| 94 |
+
`softmax(logits)`. (default: False)
|
| 95 |
+
"""
|
| 96 |
+
if align_tf:
|
| 97 |
+
num_classes = 1008
|
| 98 |
+
model_source = 'tf_inception_v3'
|
| 99 |
+
else:
|
| 100 |
+
num_classes = 1000
|
| 101 |
+
model_source = 'torchvision_official'
|
| 102 |
+
|
| 103 |
+
fingerprint = model_source
|
| 104 |
+
|
| 105 |
+
if fingerprint not in InceptionModel.models:
|
| 106 |
+
# Build model.
|
| 107 |
+
model = Inception3(num_classes=num_classes,
|
| 108 |
+
aux_logits=False,
|
| 109 |
+
init_weights=False,
|
| 110 |
+
align_tf=align_tf)
|
| 111 |
+
|
| 112 |
+
# Download pre-trained weights.
|
| 113 |
+
if dist.is_initialized() and dist.get_rank() != 0:
|
| 114 |
+
dist.barrier() # Download by chief.
|
| 115 |
+
|
| 116 |
+
url, sha256 = _MODEL_URL_SHA256[model_source]
|
| 117 |
+
filename = f'inception_model_{model_source}_{sha256}.pth'
|
| 118 |
+
model_path, hash_check = download_url(url,
|
| 119 |
+
filename=filename,
|
| 120 |
+
sha256=sha256)
|
| 121 |
+
state_dict = torch.load(model_path, map_location='cpu')
|
| 122 |
+
if hash_check is False:
|
| 123 |
+
warnings.warn(f'Hash check failed! The remote file from URL '
|
| 124 |
+
f'`{url}` may be changed, or the downloading is '
|
| 125 |
+
f'interrupted. The loaded inception model may '
|
| 126 |
+
f'have unexpected behavior.')
|
| 127 |
+
|
| 128 |
+
if dist.is_initialized() and dist.get_rank() == 0:
|
| 129 |
+
dist.barrier() # Wait for other replicas.
|
| 130 |
+
|
| 131 |
+
# Load weights.
|
| 132 |
+
model.load_state_dict(state_dict, strict=False)
|
| 133 |
+
del state_dict
|
| 134 |
+
|
| 135 |
+
# For inference only.
|
| 136 |
+
model.eval().requires_grad_(False).cuda()
|
| 137 |
+
InceptionModel.models[fingerprint] = model
|
| 138 |
+
|
| 139 |
+
return InceptionModel.models[fingerprint]
|
| 140 |
+
|
| 141 |
+
# pylint: disable=missing-function-docstring
|
| 142 |
+
# pylint: disable=missing-class-docstring
|
| 143 |
+
# pylint: disable=super-with-arguments
|
| 144 |
+
# pylint: disable=consider-merging-isinstance
|
| 145 |
+
# pylint: disable=import-outside-toplevel
|
| 146 |
+
# pylint: disable=no-else-return
|
| 147 |
+
|
| 148 |
+
class Inception3(nn.Module):
|
| 149 |
+
|
| 150 |
+
def __init__(self, num_classes=1000, aux_logits=True, inception_blocks=None,
|
| 151 |
+
init_weights=True, align_tf=True):
|
| 152 |
+
super(Inception3, self).__init__()
|
| 153 |
+
if inception_blocks is None:
|
| 154 |
+
inception_blocks = [
|
| 155 |
+
BasicConv2d, InceptionA, InceptionB, InceptionC,
|
| 156 |
+
InceptionD, InceptionE, InceptionAux
|
| 157 |
+
]
|
| 158 |
+
assert len(inception_blocks) == 7
|
| 159 |
+
conv_block = inception_blocks[0]
|
| 160 |
+
inception_a = inception_blocks[1]
|
| 161 |
+
inception_b = inception_blocks[2]
|
| 162 |
+
inception_c = inception_blocks[3]
|
| 163 |
+
inception_d = inception_blocks[4]
|
| 164 |
+
inception_e = inception_blocks[5]
|
| 165 |
+
inception_aux = inception_blocks[6]
|
| 166 |
+
|
| 167 |
+
self.aux_logits = aux_logits
|
| 168 |
+
self.align_tf = align_tf
|
| 169 |
+
self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
|
| 170 |
+
self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
|
| 171 |
+
self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
|
| 172 |
+
self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
|
| 173 |
+
self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
|
| 174 |
+
self.Mixed_5b = inception_a(192, pool_features=32, align_tf=self.align_tf)
|
| 175 |
+
self.Mixed_5c = inception_a(256, pool_features=64, align_tf=self.align_tf)
|
| 176 |
+
self.Mixed_5d = inception_a(288, pool_features=64, align_tf=self.align_tf)
|
| 177 |
+
self.Mixed_6a = inception_b(288)
|
| 178 |
+
self.Mixed_6b = inception_c(768, channels_7x7=128, align_tf=self.align_tf)
|
| 179 |
+
self.Mixed_6c = inception_c(768, channels_7x7=160, align_tf=self.align_tf)
|
| 180 |
+
self.Mixed_6d = inception_c(768, channels_7x7=160, align_tf=self.align_tf)
|
| 181 |
+
self.Mixed_6e = inception_c(768, channels_7x7=192, align_tf=self.align_tf)
|
| 182 |
+
if aux_logits:
|
| 183 |
+
self.AuxLogits = inception_aux(768, num_classes)
|
| 184 |
+
self.Mixed_7a = inception_d(768)
|
| 185 |
+
self.Mixed_7b = inception_e(1280, align_tf=self.align_tf)
|
| 186 |
+
self.Mixed_7c = inception_e(2048, use_max_pool=self.align_tf)
|
| 187 |
+
self.fc = nn.Linear(2048, num_classes)
|
| 188 |
+
if init_weights:
|
| 189 |
+
for m in self.modules():
|
| 190 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 191 |
+
import scipy.stats as stats
|
| 192 |
+
stddev = m.stddev if hasattr(m, 'stddev') else 0.1
|
| 193 |
+
X = stats.truncnorm(-2, 2, scale=stddev)
|
| 194 |
+
values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
|
| 195 |
+
values = values.view(m.weight.size())
|
| 196 |
+
with torch.no_grad():
|
| 197 |
+
m.weight.copy_(values)
|
| 198 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 199 |
+
nn.init.constant_(m.weight, 1)
|
| 200 |
+
nn.init.constant_(m.bias, 0)
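# The weight initialization above mirrors the TF-Slim Inception implementation,
# which draws conv/linear weights from a truncated normal distribution (clipped
# at 2 standard deviations) instead of PyTorch's default Kaiming scheme. It
# only takes effect when `init_weights=True`; `build_model()` constructs the
# network with `init_weights=False` and then loads pre-trained weights.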
|
| 201 |
+
|
| 202 |
+
@staticmethod
|
| 203 |
+
def _transform_input(x, transform_input=False):
|
| 204 |
+
if transform_input:
|
| 205 |
+
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
|
| 206 |
+
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
|
| 207 |
+
x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
|
| 208 |
+
x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
|
| 209 |
+
return x
|
| 210 |
+
|
| 211 |
+
def _forward(self,
|
| 212 |
+
x,
|
| 213 |
+
output_logits=False,
|
| 214 |
+
remove_logits_bias=False,
|
| 215 |
+
output_predictions=False):
|
| 216 |
+
# Upsample if necessary.
|
| 217 |
+
if x.shape[2] != 299 or x.shape[3] != 299:
|
| 218 |
+
if self.align_tf:
|
| 219 |
+
theta = torch.eye(2, 3).to(x)
|
| 220 |
+
theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / 299
|
| 221 |
+
theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / 299
|
| 222 |
+
theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1)
|
| 223 |
+
grid = F.affine_grid(theta,
|
| 224 |
+
size=(x.shape[0], x.shape[1], 299, 299),
|
| 225 |
+
align_corners=False)
|
| 226 |
+
x = F.grid_sample(x, grid,
|
| 227 |
+
mode='bilinear',
|
| 228 |
+
padding_mode='border',
|
| 229 |
+
align_corners=False)
|
| 230 |
+
else:
|
| 231 |
+
x = F.interpolate(
|
| 232 |
+
x, size=(299, 299), mode='bilinear', align_corners=False)
|
| 233 |
+
if x.shape[1] == 1:
|
| 234 |
+
x = x.repeat((1, 3, 1, 1))
|
| 235 |
+
|
| 236 |
+
if self.align_tf:
|
| 237 |
+
x = (x * 127.5 + 127.5 - 128) / 128
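# The rescaling above appears to reproduce the uint8-based preprocessing of
# the TF Inception graph: a value in (-1, 1) is mapped back to [0, 255] via
# x * 127.5 + 127.5 and then normalized as (value - 128) / 128.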
|
| 238 |
+
|
| 239 |
+
# N x 3 x 299 x 299
|
| 240 |
+
x = self.Conv2d_1a_3x3(x)
|
| 241 |
+
# N x 32 x 149 x 149
|
| 242 |
+
x = self.Conv2d_2a_3x3(x)
|
| 243 |
+
# N x 32 x 147 x 147
|
| 244 |
+
x = self.Conv2d_2b_3x3(x)
|
| 245 |
+
# N x 64 x 147 x 147
|
| 246 |
+
x = F.max_pool2d(x, kernel_size=3, stride=2)
|
| 247 |
+
# N x 64 x 73 x 73
|
| 248 |
+
x = self.Conv2d_3b_1x1(x)
|
| 249 |
+
# N x 80 x 73 x 73
|
| 250 |
+
x = self.Conv2d_4a_3x3(x)
|
| 251 |
+
# N x 192 x 71 x 71
|
| 252 |
+
x = F.max_pool2d(x, kernel_size=3, stride=2)
|
| 253 |
+
# N x 192 x 35 x 35
|
| 254 |
+
x = self.Mixed_5b(x)
|
| 255 |
+
# N x 256 x 35 x 35
|
| 256 |
+
x = self.Mixed_5c(x)
|
| 257 |
+
# N x 288 x 35 x 35
|
| 258 |
+
x = self.Mixed_5d(x)
|
| 259 |
+
# N x 288 x 35 x 35
|
| 260 |
+
x = self.Mixed_6a(x)
|
| 261 |
+
# N x 768 x 17 x 17
|
| 262 |
+
x = self.Mixed_6b(x)
|
| 263 |
+
# N x 768 x 17 x 17
|
| 264 |
+
x = self.Mixed_6c(x)
|
| 265 |
+
# N x 768 x 17 x 17
|
| 266 |
+
x = self.Mixed_6d(x)
|
| 267 |
+
# N x 768 x 17 x 17
|
| 268 |
+
x = self.Mixed_6e(x)
|
| 269 |
+
# N x 768 x 17 x 17
|
| 270 |
+
if self.training and self.aux_logits:
|
| 271 |
+
aux = self.AuxLogits(x)
|
| 272 |
+
else:
|
| 273 |
+
aux = None
|
| 274 |
+
# N x 768 x 17 x 17
|
| 275 |
+
x = self.Mixed_7a(x)
|
| 276 |
+
# N x 1280 x 8 x 8
|
| 277 |
+
x = self.Mixed_7b(x)
|
| 278 |
+
# N x 2048 x 8 x 8
|
| 279 |
+
x = self.Mixed_7c(x)
|
| 280 |
+
# N x 2048 x 8 x 8
|
| 281 |
+
# Adaptive average pooling
|
| 282 |
+
x = F.adaptive_avg_pool2d(x, (1, 1))
|
| 283 |
+
# N x 2048 x 1 x 1
|
| 284 |
+
x = F.dropout(x, training=self.training)
|
| 285 |
+
# N x 2048 x 1 x 1
|
| 286 |
+
x = torch.flatten(x, 1)
|
| 287 |
+
# N x 2048
|
| 288 |
+
if output_logits or output_predictions:
|
| 289 |
+
x = self.fc(x)
|
| 290 |
+
# N x 1000 (num_classes)
|
| 291 |
+
if remove_logits_bias:
|
| 292 |
+
x = x - self.fc.bias.view(1, -1)
|
| 293 |
+
if output_predictions:
|
| 294 |
+
x = F.softmax(x, dim=1)
|
| 295 |
+
return x, aux
|
| 296 |
+
|
| 297 |
+
def forward(self,
|
| 298 |
+
x,
|
| 299 |
+
transform_input=False,
|
| 300 |
+
output_logits=False,
|
| 301 |
+
remove_logits_bias=False,
|
| 302 |
+
output_predictions=False):
|
| 303 |
+
x = self._transform_input(x, transform_input)
|
| 304 |
+
x, aux = self._forward(
|
| 305 |
+
x, output_logits, remove_logits_bias, output_predictions)
|
| 306 |
+
if self.training and self.aux_logits:
|
| 307 |
+
return x, aux
|
| 308 |
+
else:
|
| 309 |
+
return x
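# Illustrative sketch of how the `output_predictions` and `remove_logits_bias`
# options are typically combined to compute the Inception Score,
# IS = exp(E_x[KL(p(y|x) || p(y))]). `probs` is assumed to be an
# [N, num_classes] tensor of softmax outputs from the model above; the helper
# name `compute_inception_score` is hypothetical.
def compute_inception_score(probs, eps=1e-12):
    """Computes the Inception Score from softmax predictions (sketch)."""
    marginal = probs.mean(dim=0, keepdim=True)  # p(y)
    kl = probs * (torch.log(probs + eps) - torch.log(marginal + eps))
    return torch.exp(kl.sum(dim=1).mean()).item()  # exp(E[KL])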
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class InceptionA(nn.Module):
|
| 313 |
+
|
| 314 |
+
def __init__(self, in_channels, pool_features, conv_block=None, align_tf=False):
|
| 315 |
+
super(InceptionA, self).__init__()
|
| 316 |
+
if conv_block is None:
|
| 317 |
+
conv_block = BasicConv2d
|
| 318 |
+
self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
|
| 319 |
+
|
| 320 |
+
self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
|
| 321 |
+
self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
|
| 322 |
+
|
| 323 |
+
self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
|
| 324 |
+
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
|
| 325 |
+
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
|
| 326 |
+
|
| 327 |
+
self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
|
| 328 |
+
self.pool_include_padding = not align_tf
|
| 329 |
+
|
| 330 |
+
def _forward(self, x):
|
| 331 |
+
branch1x1 = self.branch1x1(x)
|
| 332 |
+
|
| 333 |
+
branch5x5 = self.branch5x5_1(x)
|
| 334 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
| 335 |
+
|
| 336 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
| 337 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
| 338 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
| 339 |
+
|
| 340 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
| 341 |
+
count_include_pad=self.pool_include_padding)
|
| 342 |
+
branch_pool = self.branch_pool(branch_pool)
|
| 343 |
+
|
| 344 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
| 345 |
+
return outputs
|
| 346 |
+
|
| 347 |
+
def forward(self, x):
|
| 348 |
+
outputs = self._forward(x)
|
| 349 |
+
return torch.cat(outputs, 1)
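# Concatenating the four branches gives 64 + 64 + 96 + pool_features output
# channels, i.e. 256 for Mixed_5b (pool_features=32) and 288 for
# Mixed_5c / Mixed_5d (pool_features=64), matching the input channel counts
# used further down the network.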
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class InceptionB(nn.Module):
|
| 353 |
+
|
| 354 |
+
def __init__(self, in_channels, conv_block=None):
|
| 355 |
+
super(InceptionB, self).__init__()
|
| 356 |
+
if conv_block is None:
|
| 357 |
+
conv_block = BasicConv2d
|
| 358 |
+
self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
|
| 359 |
+
|
| 360 |
+
self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
|
| 361 |
+
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
|
| 362 |
+
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
|
| 363 |
+
|
| 364 |
+
def _forward(self, x):
|
| 365 |
+
branch3x3 = self.branch3x3(x)
|
| 366 |
+
|
| 367 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
| 368 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
| 369 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
| 370 |
+
|
| 371 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
|
| 372 |
+
|
| 373 |
+
outputs = [branch3x3, branch3x3dbl, branch_pool]
|
| 374 |
+
return outputs
|
| 375 |
+
|
| 376 |
+
def forward(self, x):
|
| 377 |
+
outputs = self._forward(x)
|
| 378 |
+
return torch.cat(outputs, 1)
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
class InceptionC(nn.Module):
|
| 382 |
+
|
| 383 |
+
def __init__(self, in_channels, channels_7x7, conv_block=None, align_tf=False):
|
| 384 |
+
super(InceptionC, self).__init__()
|
| 385 |
+
if conv_block is None:
|
| 386 |
+
conv_block = BasicConv2d
|
| 387 |
+
self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
|
| 388 |
+
|
| 389 |
+
c7 = channels_7x7
|
| 390 |
+
self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
|
| 391 |
+
self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
|
| 392 |
+
self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
|
| 393 |
+
|
| 394 |
+
self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
|
| 395 |
+
self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
|
| 396 |
+
self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
|
| 397 |
+
self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
|
| 398 |
+
self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
|
| 399 |
+
|
| 400 |
+
self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
|
| 401 |
+
self.pool_include_padding = not align_tf
|
| 402 |
+
|
| 403 |
+
def _forward(self, x):
|
| 404 |
+
branch1x1 = self.branch1x1(x)
|
| 405 |
+
|
| 406 |
+
branch7x7 = self.branch7x7_1(x)
|
| 407 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
| 408 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
| 409 |
+
|
| 410 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
| 411 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
| 412 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
| 413 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
| 414 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
| 415 |
+
|
| 416 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
| 417 |
+
count_include_pad=self.pool_include_padding)
|
| 418 |
+
branch_pool = self.branch_pool(branch_pool)
|
| 419 |
+
|
| 420 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
| 421 |
+
return outputs
|
| 422 |
+
|
| 423 |
+
def forward(self, x):
|
| 424 |
+
outputs = self._forward(x)
|
| 425 |
+
return torch.cat(outputs, 1)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
class InceptionD(nn.Module):
|
| 429 |
+
|
| 430 |
+
def __init__(self, in_channels, conv_block=None):
|
| 431 |
+
super(InceptionD, self).__init__()
|
| 432 |
+
if conv_block is None:
|
| 433 |
+
conv_block = BasicConv2d
|
| 434 |
+
self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
|
| 435 |
+
self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
|
| 436 |
+
|
| 437 |
+
self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
|
| 438 |
+
self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
|
| 439 |
+
self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
|
| 440 |
+
self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
|
| 441 |
+
|
| 442 |
+
def _forward(self, x):
|
| 443 |
+
branch3x3 = self.branch3x3_1(x)
|
| 444 |
+
branch3x3 = self.branch3x3_2(branch3x3)
|
| 445 |
+
|
| 446 |
+
branch7x7x3 = self.branch7x7x3_1(x)
|
| 447 |
+
branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
|
| 448 |
+
branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
|
| 449 |
+
branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
|
| 450 |
+
|
| 451 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
|
| 452 |
+
outputs = [branch3x3, branch7x7x3, branch_pool]
|
| 453 |
+
return outputs
|
| 454 |
+
|
| 455 |
+
def forward(self, x):
|
| 456 |
+
outputs = self._forward(x)
|
| 457 |
+
return torch.cat(outputs, 1)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
class InceptionE(nn.Module):
|
| 461 |
+
|
| 462 |
+
def __init__(self, in_channels, conv_block=None, align_tf=False, use_max_pool=False):
|
| 463 |
+
super(InceptionE, self).__init__()
|
| 464 |
+
if conv_block is None:
|
| 465 |
+
conv_block = BasicConv2d
|
| 466 |
+
self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
|
| 467 |
+
|
| 468 |
+
self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
|
| 469 |
+
self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
|
| 470 |
+
self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
|
| 471 |
+
|
| 472 |
+
self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
|
| 473 |
+
self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
|
| 474 |
+
self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
|
| 475 |
+
self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
|
| 476 |
+
|
| 477 |
+
self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
|
| 478 |
+
self.pool_include_padding = not align_tf
|
| 479 |
+
self.use_max_pool = use_max_pool
|
| 480 |
+
|
| 481 |
+
def _forward(self, x):
|
| 482 |
+
branch1x1 = self.branch1x1(x)
|
| 483 |
+
|
| 484 |
+
branch3x3 = self.branch3x3_1(x)
|
| 485 |
+
branch3x3 = [
|
| 486 |
+
self.branch3x3_2a(branch3x3),
|
| 487 |
+
self.branch3x3_2b(branch3x3),
|
| 488 |
+
]
|
| 489 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
| 490 |
+
|
| 491 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
| 492 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
| 493 |
+
branch3x3dbl = [
|
| 494 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
| 495 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
| 496 |
+
]
|
| 497 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
| 498 |
+
|
| 499 |
+
if self.use_max_pool:
|
| 500 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
| 501 |
+
else:
|
| 502 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
| 503 |
+
count_include_pad=self.pool_include_padding)
|
| 504 |
+
branch_pool = self.branch_pool(branch_pool)
|
| 505 |
+
|
| 506 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
| 507 |
+
return outputs
|
| 508 |
+
|
| 509 |
+
def forward(self, x):
|
| 510 |
+
outputs = self._forward(x)
|
| 511 |
+
return torch.cat(outputs, 1)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
class InceptionAux(nn.Module):
|
| 515 |
+
|
| 516 |
+
def __init__(self, in_channels, num_classes, conv_block=None):
|
| 517 |
+
super(InceptionAux, self).__init__()
|
| 518 |
+
if conv_block is None:
|
| 519 |
+
conv_block = BasicConv2d
|
| 520 |
+
self.conv0 = conv_block(in_channels, 128, kernel_size=1)
|
| 521 |
+
self.conv1 = conv_block(128, 768, kernel_size=5)
|
| 522 |
+
self.conv1.stddev = 0.01
|
| 523 |
+
self.fc = nn.Linear(768, num_classes)
|
| 524 |
+
self.fc.stddev = 0.001
|
| 525 |
+
|
| 526 |
+
def forward(self, x):
|
| 527 |
+
# N x 768 x 17 x 17
|
| 528 |
+
x = F.avg_pool2d(x, kernel_size=5, stride=3)
|
| 529 |
+
# N x 768 x 5 x 5
|
| 530 |
+
x = self.conv0(x)
|
| 531 |
+
# N x 128 x 5 x 5
|
| 532 |
+
x = self.conv1(x)
|
| 533 |
+
# N x 768 x 1 x 1
|
| 534 |
+
# Adaptive average pooling
|
| 535 |
+
x = F.adaptive_avg_pool2d(x, (1, 1))
|
| 536 |
+
# N x 768 x 1 x 1
|
| 537 |
+
x = torch.flatten(x, 1)
|
| 538 |
+
# N x 768
|
| 539 |
+
x = self.fc(x)
|
| 540 |
+
# N x 1000
|
| 541 |
+
return x
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
class BasicConv2d(nn.Module):
|
| 545 |
+
|
| 546 |
+
def __init__(self, in_channels, out_channels, **kwargs):
|
| 547 |
+
super(BasicConv2d, self).__init__()
|
| 548 |
+
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
|
| 549 |
+
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
|
| 550 |
+
|
| 551 |
+
def forward(self, x):
|
| 552 |
+
x = self.conv(x)
|
| 553 |
+
x = self.bn(x)
|
| 554 |
+
return F.relu(x, inplace=True)
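# `BasicConv2d` is the conv + batch-norm + ReLU unit shared by all Inception
# blocks above. The convolution is created with `bias=False` because the
# batch-norm layer that follows already provides a learnable per-channel shift.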
|
| 555 |
+
|
| 556 |
+
# pylint: enable=line-too-long
|
| 557 |
+
# pylint: enable=missing-function-docstring
|
| 558 |
+
# pylint: enable=missing-class-docstring
|
| 559 |
+
# pylint: enable=super-with-arguments
|
| 560 |
+
# pylint: enable=consider-merging-isinstance
|
| 561 |
+
# pylint: enable=import-outside-toplevel
|
| 562 |
+
# pylint: enable=no-else-return
|
models/perceptual_model.py
ADDED
|
@@ -0,0 +1,519 @@
# python3.7
"""Contains the VGG16 model, which is used for inference ONLY.

VGG16 is commonly used for perceptual feature extraction. The model implemented
in this file can be used for evaluation (like computing LPIPS, perceptual path
length, etc.), OR be used in training for loss computation (like perceptual
loss, etc.).

The pre-trained model is officially shared by

https://www.robots.ox.ac.uk/~vgg/research/very_deep/

and ported by

https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt

Compared to the official VGG16 model, this ported model also supports
evaluating LPIPS, which is introduced in

https://github.com/richzhang/PerceptualSimilarity
"""
|
| 22 |
+
|
| 23 |
+
import warnings
|
| 24 |
+
import numpy as np
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
import torch.distributed as dist
|
| 30 |
+
|
| 31 |
+
from utils.misc import download_url
|
| 32 |
+
|
| 33 |
+
__all__ = ['PerceptualModel']
|
| 34 |
+
|
| 35 |
+
# pylint: disable=line-too-long
|
| 36 |
+
_MODEL_URL_SHA256 = {
|
| 37 |
+
# This model is provided by `torchvision`, which is ported from TensorFlow.
|
| 38 |
+
'torchvision_official': (
|
| 39 |
+
'https://download.pytorch.org/models/vgg16-397923af.pth',
|
| 40 |
+
'397923af8e79cdbb6a7127f12361acd7a2f83e06b05044ddf496e83de57a5bf0' # hash sha256
|
| 41 |
+
),
|
| 42 |
+
|
| 43 |
+
# This model is provided by https://github.com/NVlabs/stylegan2-ada-pytorch
|
| 44 |
+
'vgg_perceptual_lpips': (
|
| 45 |
+
'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt',
|
| 46 |
+
'b437eb095feaeb0b83eb3fa11200ebca4548ee39a07fb944a417ddc516cc07c3' # hash sha256
|
| 47 |
+
)
|
| 48 |
+
}
|
| 49 |
+
# pylint: enable=line-too-long
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class PerceptualModel(object):
    """Defines the perceptual model, which is based on the VGG16 structure.

    This is a static class, used to avoid building the model repeatedly.
    Consequently, this model is particularly used for inference, like
    computing LPIPS, or for loss computation, like perceptual loss. If
    training is required, please use the model from `torchvision.models` or
    implement it by yourself.

    NOTE: The pre-trained model assumes the inputs to be with `RGB` channel
    order and pixel range [-1, 1], and will NOT resize the input automatically
    if only the perceptual feature is needed.
    """
|
| 65 |
+
models = dict()
|
| 66 |
+
|
| 67 |
+
@staticmethod
def build_model(use_torchvision=False, no_top=True, enable_lpips=True):
    """Builds the model and loads pre-trained weights.

    1. If `use_torchvision` is set to `True`, the model released by
       `torchvision` will be loaded; otherwise, the model released by
       https://www.robots.ox.ac.uk/~vgg/research/very_deep/ will be used.
       (default: False)

    2. To save computing resources, there is an option to only load the
       backbone (i.e., without the last three fully-connected layers). This
       is commonly used for perceptual loss or LPIPS loss computation.
       Please use the argument `no_top` to control this. (default: True)

    3. For LPIPS loss computation, some additional weights (which are used
       for balancing the features from different resolutions) are employed
       on top of the original VGG16 backbone. Details can be found at
       https://github.com/richzhang/PerceptualSimilarity. Please use
       `enable_lpips` to enable this feature. (default: True)

    The built model supports the following arguments when forwarding:

    - resize_input: Whether to resize the input image to size [224, 224]
        before forwarding. For feature-based computation (i.e., only
        convolutional layers are used), image resizing is not essential.
        (default: False)
    - return_tensor: This field determines the model behavior. The following
        options are supported:
        `feature1`: Before the first max pooling layer.
        `pool1`: After the first max pooling layer.
        `feature2`: Before the second max pooling layer.
        `pool2`: After the second max pooling layer.
        `feature3`: Before the third max pooling layer.
        `pool3`: After the third max pooling layer.
        `feature4`: Before the fourth max pooling layer.
        `pool4`: After the fourth max pooling layer.
        `feature5`: Before the fifth max pooling layer.
        `pool5`: After the fifth max pooling layer.
        `flatten`: The flattened feature, after `adaptive_avgpool`.
        `feature`: The 4096d feature for logits computation. (default)
        `logits`: The 1000d categorical logits.
        `prediction`: The 1000d predicted probability.
        `lpips`: The LPIPS score between two input images.
    """
|
| 111 |
+
if use_torchvision:
|
| 112 |
+
model_source = 'torchvision_official'
|
| 113 |
+
align_tf_resize = False
|
| 114 |
+
is_torch_script = False
|
| 115 |
+
else:
|
| 116 |
+
model_source = 'vgg_perceptual_lpips'
|
| 117 |
+
align_tf_resize = True
|
| 118 |
+
is_torch_script = True
|
| 119 |
+
|
| 120 |
+
if enable_lpips and model_source != 'vgg_perceptual_lpips':
|
| 121 |
+
warnings.warn('The pre-trained model officially released by '
|
| 122 |
+
'`torchvision` does not support LPIPS computation! '
|
| 123 |
+
'Equal weights will be used for each resolution.')
|
| 124 |
+
|
| 125 |
+
fingerprint = (model_source, no_top, enable_lpips)
|
| 126 |
+
|
| 127 |
+
if fingerprint not in PerceptualModel.models:
|
| 128 |
+
# Build model.
|
| 129 |
+
model = VGG16(align_tf_resize=align_tf_resize,
|
| 130 |
+
no_top=no_top,
|
| 131 |
+
enable_lpips=enable_lpips)
|
| 132 |
+
|
| 133 |
+
# Download pre-trained weights.
|
| 134 |
+
if dist.is_initialized() and dist.get_rank() != 0:
|
| 135 |
+
dist.barrier() # Download by chief.
|
| 136 |
+
|
| 137 |
+
url, sha256 = _MODEL_URL_SHA256[model_source]
|
| 138 |
+
filename = f'perceptual_model_{model_source}_{sha256}.pth'
|
| 139 |
+
model_path, hash_check = download_url(url,
|
| 140 |
+
filename=filename,
|
| 141 |
+
sha256=sha256)
|
| 142 |
+
if is_torch_script:
|
| 143 |
+
src_state_dict = torch.jit.load(model_path, map_location='cpu')
|
| 144 |
+
else:
|
| 145 |
+
src_state_dict = torch.load(model_path, map_location='cpu')
|
| 146 |
+
if hash_check is False:
|
| 147 |
+
warnings.warn(f'Hash check failed! The remote file from URL '
|
| 148 |
+
f'`{url}` may be changed, or the downloading is '
|
| 149 |
+
f'interrupted. The loaded perceptual model may '
|
| 150 |
+
f'have unexpected behavior.')
|
| 151 |
+
|
| 152 |
+
if dist.is_initialized() and dist.get_rank() == 0:
|
| 153 |
+
dist.barrier() # Wait for other replicas.
|
| 154 |
+
|
| 155 |
+
# Load weights.
|
| 156 |
+
dst_state_dict = _convert_weights(src_state_dict, model_source)
|
| 157 |
+
model.load_state_dict(dst_state_dict, strict=False)
|
| 158 |
+
del src_state_dict, dst_state_dict
|
| 159 |
+
|
| 160 |
+
# For inference only.
|
| 161 |
+
model.eval().requires_grad_(False).cuda()
|
| 162 |
+
PerceptualModel.models[fingerprint] = model
|
| 163 |
+
|
| 164 |
+
return PerceptualModel.models[fingerprint]
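# Illustrative usage sketch (assumptions: `img_a` and `img_b` are `NCHW` float
# tensors on GPU with `RGB` channel order and pixel range [-1, 1]; the helper
# name `perceptual_distances` is hypothetical).
def perceptual_distances(img_a, img_b):
    """Returns per-sample LPIPS scores and a simple feature-space distance."""
    vgg = PerceptualModel.build_model(no_top=True, enable_lpips=True)
    with torch.no_grad():
        # Per-sample LPIPS scores, shape [N].
        lpips = vgg(img_a, img_b, resize_input=False, return_tensor='lpips')
        # Feature-space L2 distance, e.g. for a perceptual loss.
        feat_a = vgg(img_a, resize_input=False, return_tensor='feature3')
        feat_b = vgg(img_b, resize_input=False, return_tensor='feature3')
        perceptual = (feat_a - feat_b).square().mean(dim=(1, 2, 3))
    return lpips, perceptual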
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _convert_weights(src_state_dict, model_source):
|
| 168 |
+
if model_source not in _MODEL_URL_SHA256:
|
| 169 |
+
raise ValueError(f'Invalid model source `{model_source}`!\n'
|
| 170 |
+
f'Sources allowed: {list(_MODEL_URL_SHA256.keys())}.')
|
| 171 |
+
if model_source == 'torchvision_official':
|
| 172 |
+
dst_to_src_var_mapping = {
|
| 173 |
+
'conv11.weight': 'features.0.weight',
|
| 174 |
+
'conv11.bias': 'features.0.bias',
|
| 175 |
+
'conv12.weight': 'features.2.weight',
|
| 176 |
+
'conv12.bias': 'features.2.bias',
|
| 177 |
+
'conv21.weight': 'features.5.weight',
|
| 178 |
+
'conv21.bias': 'features.5.bias',
|
| 179 |
+
'conv22.weight': 'features.7.weight',
|
| 180 |
+
'conv22.bias': 'features.7.bias',
|
| 181 |
+
'conv31.weight': 'features.10.weight',
|
| 182 |
+
'conv31.bias': 'features.10.bias',
|
| 183 |
+
'conv32.weight': 'features.12.weight',
|
| 184 |
+
'conv32.bias': 'features.12.bias',
|
| 185 |
+
'conv33.weight': 'features.14.weight',
|
| 186 |
+
'conv33.bias': 'features.14.bias',
|
| 187 |
+
'conv41.weight': 'features.17.weight',
|
| 188 |
+
'conv41.bias': 'features.17.bias',
|
| 189 |
+
'conv42.weight': 'features.19.weight',
|
| 190 |
+
'conv42.bias': 'features.19.bias',
|
| 191 |
+
'conv43.weight': 'features.21.weight',
|
| 192 |
+
'conv43.bias': 'features.21.bias',
|
| 193 |
+
'conv51.weight': 'features.24.weight',
|
| 194 |
+
'conv51.bias': 'features.24.bias',
|
| 195 |
+
'conv52.weight': 'features.26.weight',
|
| 196 |
+
'conv52.bias': 'features.26.bias',
|
| 197 |
+
'conv53.weight': 'features.28.weight',
|
| 198 |
+
'conv53.bias': 'features.28.bias',
|
| 199 |
+
'fc1.weight': 'classifier.0.weight',
|
| 200 |
+
'fc1.bias': 'classifier.0.bias',
|
| 201 |
+
'fc2.weight': 'classifier.3.weight',
|
| 202 |
+
'fc2.bias': 'classifier.3.bias',
|
| 203 |
+
'fc3.weight': 'classifier.6.weight',
|
| 204 |
+
'fc3.bias': 'classifier.6.bias',
|
| 205 |
+
}
|
| 206 |
+
elif model_source == 'vgg_perceptual_lpips':
|
| 207 |
+
src_state_dict = src_state_dict.state_dict()
|
| 208 |
+
dst_to_src_var_mapping = {
|
| 209 |
+
'conv11.weight': 'layers.conv1.weight',
|
| 210 |
+
'conv11.bias': 'layers.conv1.bias',
|
| 211 |
+
'conv12.weight': 'layers.conv2.weight',
|
| 212 |
+
'conv12.bias': 'layers.conv2.bias',
|
| 213 |
+
'conv21.weight': 'layers.conv3.weight',
|
| 214 |
+
'conv21.bias': 'layers.conv3.bias',
|
| 215 |
+
'conv22.weight': 'layers.conv4.weight',
|
| 216 |
+
'conv22.bias': 'layers.conv4.bias',
|
| 217 |
+
'conv31.weight': 'layers.conv5.weight',
|
| 218 |
+
'conv31.bias': 'layers.conv5.bias',
|
| 219 |
+
'conv32.weight': 'layers.conv6.weight',
|
| 220 |
+
'conv32.bias': 'layers.conv6.bias',
|
| 221 |
+
'conv33.weight': 'layers.conv7.weight',
|
| 222 |
+
'conv33.bias': 'layers.conv7.bias',
|
| 223 |
+
'conv41.weight': 'layers.conv8.weight',
|
| 224 |
+
'conv41.bias': 'layers.conv8.bias',
|
| 225 |
+
'conv42.weight': 'layers.conv9.weight',
|
| 226 |
+
'conv42.bias': 'layers.conv9.bias',
|
| 227 |
+
'conv43.weight': 'layers.conv10.weight',
|
| 228 |
+
'conv43.bias': 'layers.conv10.bias',
|
| 229 |
+
'conv51.weight': 'layers.conv11.weight',
|
| 230 |
+
'conv51.bias': 'layers.conv11.bias',
|
| 231 |
+
'conv52.weight': 'layers.conv12.weight',
|
| 232 |
+
'conv52.bias': 'layers.conv12.bias',
|
| 233 |
+
'conv53.weight': 'layers.conv13.weight',
|
| 234 |
+
'conv53.bias': 'layers.conv13.bias',
|
| 235 |
+
'fc1.weight': 'layers.fc1.weight',
|
| 236 |
+
'fc1.bias': 'layers.fc1.bias',
|
| 237 |
+
'fc2.weight': 'layers.fc2.weight',
|
| 238 |
+
'fc2.bias': 'layers.fc2.bias',
|
| 239 |
+
'fc3.weight': 'layers.fc3.weight',
|
| 240 |
+
'fc3.bias': 'layers.fc3.bias',
|
| 241 |
+
'lpips.0.weight': 'lpips0',
|
| 242 |
+
'lpips.1.weight': 'lpips1',
|
| 243 |
+
'lpips.2.weight': 'lpips2',
|
| 244 |
+
'lpips.3.weight': 'lpips3',
|
| 245 |
+
'lpips.4.weight': 'lpips4',
|
| 246 |
+
}
|
| 247 |
+
else:
|
| 248 |
+
raise NotImplementedError(f'Not implemented model source '
|
| 249 |
+
f'`{model_source}`!')
|
| 250 |
+
|
| 251 |
+
dst_state_dict = {}
|
| 252 |
+
for dst_name, src_name in dst_to_src_var_mapping.items():
|
| 253 |
+
if dst_name.startswith('lpips'):
|
| 254 |
+
dst_state_dict[dst_name] = src_state_dict[src_name].unsqueeze(0)
|
| 255 |
+
else:
|
| 256 |
+
dst_state_dict[dst_name] = src_state_dict[src_name].clone()
|
| 257 |
+
return dst_state_dict
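# For the `vgg_perceptual_lpips` source, the weights stored as `lpips0` ...
# `lpips4` are per-channel weight tensors; the `unsqueeze(0)` above turns each
# of them into a [1, C, 1, 1] kernel so it can be loaded into the
# corresponding 1x1 convolution of `VGG16.lpips`.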
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
_IMG_MEAN = (0.485, 0.456, 0.406)
|
| 261 |
+
_IMG_STD = (0.229, 0.224, 0.225)
|
| 262 |
+
_ALLOWED_RETURN = [
|
| 263 |
+
'feature1', 'pool1', 'feature2', 'pool2', 'feature3', 'pool3', 'feature4',
|
| 264 |
+
'pool4', 'feature5', 'pool5', 'flatten', 'feature', 'logits', 'prediction',
|
| 265 |
+
'lpips'
|
| 266 |
+
]
|
| 267 |
+
|
| 268 |
+
# pylint: disable=missing-function-docstring
|
| 269 |
+
|
| 270 |
+
class VGG16(nn.Module):
    """Defines the VGG16 structure.

    This model takes `RGB` images with data format `NCHW` as the raw inputs.
    The pixel range is assumed to be [-1, 1].
    """
|
| 276 |
+
|
| 277 |
+
def __init__(self, align_tf_resize=False, no_top=True, enable_lpips=True):
|
| 278 |
+
"""Defines the network structure."""
|
| 279 |
+
super().__init__()
|
| 280 |
+
|
| 281 |
+
self.align_tf_resize = align_tf_resize
|
| 282 |
+
self.no_top = no_top
|
| 283 |
+
self.enable_lpips = enable_lpips
|
| 284 |
+
|
| 285 |
+
self.conv11 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
|
| 286 |
+
self.relu11 = nn.ReLU(inplace=True)
|
| 287 |
+
self.conv12 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
|
| 288 |
+
self.relu12 = nn.ReLU(inplace=True)
|
| 289 |
+
# output `feature1`, with shape [N, 64, 224, 224]
|
| 290 |
+
|
| 291 |
+
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
|
| 292 |
+
# output `pool1`, with shape [N, 64, 112, 112]
|
| 293 |
+
|
| 294 |
+
self.conv21 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
|
| 295 |
+
self.relu21 = nn.ReLU(inplace=True)
|
| 296 |
+
self.conv22 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
|
| 297 |
+
self.relu22 = nn.ReLU(inplace=True)
|
| 298 |
+
# output `feature2`, with shape [N, 128, 112, 112]
|
| 299 |
+
|
| 300 |
+
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
|
| 301 |
+
# output `pool2`, with shape [N, 128, 56, 56]
|
| 302 |
+
|
| 303 |
+
self.conv31 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
|
| 304 |
+
self.relu31 = nn.ReLU(inplace=True)
|
| 305 |
+
self.conv32 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
| 306 |
+
self.relu32 = nn.ReLU(inplace=True)
|
| 307 |
+
self.conv33 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
| 308 |
+
self.relu33 = nn.ReLU(inplace=True)
|
| 309 |
+
# output `feature3`, with shape [N, 256, 56, 56]
|
| 310 |
+
|
| 311 |
+
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
|
| 312 |
+
# output `pool3`, with shape [N,256, 28, 28]
|
| 313 |
+
|
| 314 |
+
self.conv41 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
|
| 315 |
+
self.relu41 = nn.ReLU(inplace=True)
|
| 316 |
+
self.conv42 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
| 317 |
+
self.relu42 = nn.ReLU(inplace=True)
|
| 318 |
+
self.conv43 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
| 319 |
+
self.relu43 = nn.ReLU(inplace=True)
|
| 320 |
+
# output `feature4`, with shape [N, 512, 28, 28]
|
| 321 |
+
|
| 322 |
+
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
|
| 323 |
+
# output `pool4`, with shape [N, 512, 14, 14]
|
| 324 |
+
|
| 325 |
+
self.conv51 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
| 326 |
+
self.relu51 = nn.ReLU(inplace=True)
|
| 327 |
+
self.conv52 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
| 328 |
+
self.relu52 = nn.ReLU(inplace=True)
|
| 329 |
+
self.conv53 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
| 330 |
+
self.relu53 = nn.ReLU(inplace=True)
|
| 331 |
+
# output `feature5`, with shape [N, 512, 14, 14]
|
| 332 |
+
|
| 333 |
+
self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
|
| 334 |
+
# output `pool5`, with shape [N, 512, 7, 7]
|
| 335 |
+
|
| 336 |
+
if self.enable_lpips:
|
| 337 |
+
self.lpips = nn.ModuleList()
|
| 338 |
+
for idx, ch in enumerate([64, 128, 256, 512, 512]):
|
| 339 |
+
self.lpips.append(nn.Conv2d(ch, 1, kernel_size=1, bias=False))
|
| 340 |
+
self.lpips[idx].weight.data.copy_(torch.ones(1, ch, 1, 1))
|
| 341 |
+
|
| 342 |
+
if not self.no_top:
|
| 343 |
+
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
|
| 344 |
+
self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
|
| 345 |
+
# output `flatten`, with shape [N, 25088]
|
| 346 |
+
|
| 347 |
+
self.fc1 = nn.Linear(512 * 7 * 7, 4096)
|
| 348 |
+
self.fc1_relu = nn.ReLU(inplace=True)
|
| 349 |
+
self.fc1_dropout = nn.Dropout(0.5, inplace=False)
|
| 350 |
+
self.fc2 = nn.Linear(4096, 4096)
|
| 351 |
+
self.fc2_relu = nn.ReLU(inplace=True)
|
| 352 |
+
self.fc2_dropout = nn.Dropout(0.5, inplace=False)
|
| 353 |
+
# output `feature`, with shape [N, 4096]
|
| 354 |
+
|
| 355 |
+
self.fc3 = nn.Linear(4096, 1000)
|
| 356 |
+
# output `logits`, with shape [N, 1000]
|
| 357 |
+
|
| 358 |
+
self.out = nn.Softmax(dim=1)
|
| 359 |
+
# output `softmax`, with shape [N, 1000]
|
| 360 |
+
|
| 361 |
+
img_mean = np.array(_IMG_MEAN).reshape((1, 3, 1, 1)).astype(np.float32)
|
| 362 |
+
img_std = np.array(_IMG_STD).reshape((1, 3, 1, 1)).astype(np.float32)
|
| 363 |
+
self.register_buffer('img_mean', torch.from_numpy(img_mean))
|
| 364 |
+
self.register_buffer('img_std', torch.from_numpy(img_std))
|
| 365 |
+
|
| 366 |
+
def forward(self,
|
| 367 |
+
x,
|
| 368 |
+
y=None,
|
| 369 |
+
*,
|
| 370 |
+
resize_input=False,
|
| 371 |
+
return_tensor='feature'):
|
| 372 |
+
return_tensor = return_tensor.lower()
|
| 373 |
+
if return_tensor not in _ALLOWED_RETURN:
|
| 374 |
+
raise ValueError(f'Invalid output tensor name `{return_tensor}` '
|
| 375 |
+
f'for perceptual model (VGG16)!\n'
|
| 376 |
+
f'Names allowed: {_ALLOWED_RETURN}.')
|
| 377 |
+
|
| 378 |
+
if return_tensor == 'lpips' and y is None:
|
| 379 |
+
raise ValueError('Two images are required for LPIPS computation, '
|
| 380 |
+
'but only one is received!')
|
| 381 |
+
|
| 382 |
+
if return_tensor == 'lpips':
|
| 383 |
+
assert x.shape == y.shape
|
| 384 |
+
x = torch.cat([x, y], dim=0)
|
| 385 |
+
features = []
|
| 386 |
+
|
| 387 |
+
if resize_input:
|
| 388 |
+
if self.align_tf_resize:
|
| 389 |
+
theta = torch.eye(2, 3).to(x)
|
| 390 |
+
theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / 224
|
| 391 |
+
theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / 224
|
| 392 |
+
theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1)
|
| 393 |
+
grid = F.affine_grid(theta,
|
| 394 |
+
size=(x.shape[0], x.shape[1], 224, 224),
|
| 395 |
+
align_corners=False)
|
| 396 |
+
x = F.grid_sample(x, grid,
|
| 397 |
+
mode='bilinear',
|
| 398 |
+
padding_mode='border',
|
| 399 |
+
align_corners=False)
|
| 400 |
+
else:
|
| 401 |
+
x = F.interpolate(x,
|
| 402 |
+
size=(224, 224),
|
| 403 |
+
mode='bilinear',
|
| 404 |
+
align_corners=False)
|
| 405 |
+
if x.shape[1] == 1:
|
| 406 |
+
x = x.repeat((1, 3, 1, 1))
|
| 407 |
+
|
| 408 |
+
x = (x + 1) / 2
|
| 409 |
+
x = (x - self.img_mean) / self.img_std
|
| 410 |
+
|
| 411 |
+
x = self.conv11(x)
|
| 412 |
+
x = self.relu11(x)
|
| 413 |
+
x = self.conv12(x)
|
| 414 |
+
x = self.relu12(x)
|
| 415 |
+
if return_tensor == 'feature1':
|
| 416 |
+
return x
|
| 417 |
+
if return_tensor == 'lpips':
|
| 418 |
+
features.append(x)
|
| 419 |
+
|
| 420 |
+
x = self.pool1(x)
|
| 421 |
+
if return_tensor == 'pool1':
|
| 422 |
+
return x
|
| 423 |
+
|
| 424 |
+
x = self.conv21(x)
|
| 425 |
+
x = self.relu21(x)
|
| 426 |
+
x = self.conv22(x)
|
| 427 |
+
x = self.relu22(x)
|
| 428 |
+
if return_tensor == 'feature2':
|
| 429 |
+
return x
|
| 430 |
+
if return_tensor == 'lpips':
|
| 431 |
+
features.append(x)
|
| 432 |
+
|
| 433 |
+
x = self.pool2(x)
|
| 434 |
+
if return_tensor == 'pool2':
|
| 435 |
+
return x
|
| 436 |
+
|
| 437 |
+
x = self.conv31(x)
|
| 438 |
+
x = self.relu31(x)
|
| 439 |
+
x = self.conv32(x)
|
| 440 |
+
x = self.relu32(x)
|
| 441 |
+
x = self.conv33(x)
|
| 442 |
+
x = self.relu33(x)
|
| 443 |
+
if return_tensor == 'feature3':
|
| 444 |
+
return x
|
| 445 |
+
if return_tensor == 'lpips':
|
| 446 |
+
features.append(x)
|
| 447 |
+
|
| 448 |
+
x = self.pool3(x)
|
| 449 |
+
if return_tensor == 'pool3':
|
| 450 |
+
return x
|
| 451 |
+
|
| 452 |
+
x = self.conv41(x)
|
| 453 |
+
x = self.relu41(x)
|
| 454 |
+
x = self.conv42(x)
|
| 455 |
+
x = self.relu42(x)
|
| 456 |
+
x = self.conv43(x)
|
| 457 |
+
x = self.relu43(x)
|
| 458 |
+
if return_tensor == 'feature4':
|
| 459 |
+
return x
|
| 460 |
+
if return_tensor == 'lpips':
|
| 461 |
+
features.append(x)
|
| 462 |
+
|
| 463 |
+
x = self.pool4(x)
|
| 464 |
+
if return_tensor == 'pool4':
|
| 465 |
+
return x
|
| 466 |
+
|
| 467 |
+
x = self.conv51(x)
|
| 468 |
+
x = self.relu51(x)
|
| 469 |
+
x = self.conv52(x)
|
| 470 |
+
x = self.relu52(x)
|
| 471 |
+
x = self.conv53(x)
|
| 472 |
+
x = self.relu53(x)
|
| 473 |
+
if return_tensor == 'feature5':
|
| 474 |
+
return x
|
| 475 |
+
if return_tensor == 'lpips':
|
| 476 |
+
features.append(x)
|
| 477 |
+
|
| 478 |
+
x = self.pool5(x)
|
| 479 |
+
if return_tensor == 'pool5':
|
| 480 |
+
return x
|
| 481 |
+
|
| 482 |
+
if return_tensor == 'lpips':
|
| 483 |
+
score = 0
|
| 484 |
+
assert len(features) == 5
|
| 485 |
+
for idx in range(5):
|
| 486 |
+
feature = features[idx]
|
| 487 |
+
norm = feature.norm(dim=1, keepdim=True)
|
| 488 |
+
feature = feature / (norm + 1e-10)
|
| 489 |
+
feature_x, feature_y = feature.chunk(2, dim=0)
|
| 490 |
+
diff = (feature_x - feature_y).square()
|
| 491 |
+
score += self.lpips[idx](diff).mean(dim=(2, 3), keepdim=False)
|
| 492 |
+
return score.sum(dim=1, keepdim=False)
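# The block above implements the LPIPS distance: for each of the five
# `featureN` stages, activations are unit-normalized along the channel
# dimension, their squared difference is weighted per channel by the 1x1
# convolutions in `self.lpips`, averaged over spatial locations, and the five
# stage scores are summed into one scalar per sample.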
|
| 493 |
+
|
| 494 |
+
x = self.avgpool(x)
|
| 495 |
+
x = self.flatten(x)
|
| 496 |
+
if return_tensor == 'flatten':
|
| 497 |
+
return x
|
| 498 |
+
|
| 499 |
+
x = self.fc1(x)
|
| 500 |
+
x = self.fc1_relu(x)
|
| 501 |
+
x = self.fc1_dropout(x)
|
| 502 |
+
x = self.fc2(x)
|
| 503 |
+
x = self.fc2_relu(x)
|
| 504 |
+
x = self.fc2_dropout(x)
|
| 505 |
+
if return_tensor == 'feature':
|
| 506 |
+
return x
|
| 507 |
+
|
| 508 |
+
x = self.fc3(x)
|
| 509 |
+
if return_tensor == 'logits':
|
| 510 |
+
return x
|
| 511 |
+
|
| 512 |
+
x = self.out(x)
|
| 513 |
+
if return_tensor == 'prediction':
|
| 514 |
+
return x
|
| 515 |
+
|
| 516 |
+
raise NotImplementedError(f'Output tensor name `{return_tensor}` is '
|
| 517 |
+
f'not implemented!')
|
| 518 |
+
|
| 519 |
+
# pylint: enable=missing-function-docstring
|