Commit 9a973f2
Parent(s): 8f7f7c3
Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +2 -0
- README.md +70 -0
- checkpoints/renderer_checkpoint.pt +3 -0
- checkpoints/styletalk_checkpoint.pth +3 -0
- configs/__pycache__/default.cpython-37.pyc +0 -0
- configs/default.py +65 -0
- configs/renderer_conf.yaml +17 -0
- core/__pycache__/utils.cpython-37.pyc +0 -0
- core/networks/__init__.py +9 -0
- core/networks/__pycache__/__init__.cpython-37.pyc +0 -0
- core/networks/__pycache__/disentangle_decoder.cpython-37.pyc +0 -0
- core/networks/__pycache__/dynamic_conv.cpython-37.pyc +0 -0
- core/networks/__pycache__/dynamic_fc_decoder.cpython-37.pyc +0 -0
- core/networks/__pycache__/dynamic_linear.cpython-37.pyc +0 -0
- core/networks/__pycache__/generator.cpython-37.pyc +0 -0
- core/networks/__pycache__/mish.cpython-37.pyc +0 -0
- core/networks/__pycache__/self_attention_pooling.cpython-37.pyc +0 -0
- core/networks/__pycache__/styletalk.cpython-37.pyc +0 -0
- core/networks/__pycache__/transformer.cpython-37.pyc +0 -0
- core/networks/building_blocks.py +112 -0
- core/networks/disentangle_decoder.py +184 -0
- core/networks/dynamic_conv.py +149 -0
- core/networks/dynamic_fc_decoder.py +140 -0
- core/networks/dynamic_linear.py +42 -0
- core/networks/generator.py +213 -0
- core/networks/mish.py +51 -0
- core/networks/self_attention_pooling.py +43 -0
- core/networks/styletalk.py +24 -0
- core/networks/transformer.py +300 -0
- core/utils.py +228 -0
- demo.mp4 +0 -0
- demo.npy +3 -0
- demo_download.mp4 +0 -0
- demo_download.npy +3 -0
- env.yaml +0 -0
- environment.yml +91 -0
- generators/__pycache__/base_function.cpython-37.pyc +0 -0
- generators/__pycache__/face_model.cpython-37.pyc +0 -0
- generators/__pycache__/flow_util.cpython-37.pyc +0 -0
- generators/base_function.py +368 -0
- generators/face_model.py +127 -0
- generators/flow_util.py +56 -0
- inference_for_demo.py +187 -0
- media/first_page.png +3 -0
- phindex.json +1 -0
- requirements.txt +11 -0
- samples/source_video/3DMM/KristiNoem.mat +0 -0
- samples/source_video/3DMM/Obama_clip1.mat +0 -0
- samples/source_video/3DMM/Obama_clip2.mat +0 -0
- samples/source_video/3DMM/Obama_clip3.mat +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+media/first_page.png filter=lfs diff=lfs merge=lfs -text
+samples/source_video/wav/intro.wav filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,70 @@
# StyleTalk

The official repository of the AAAI 2023 paper [StyleTalk: One-shot Talking Head Generation with Controllable Speaking Styles](https://arxiv.org/abs/2301.01081).

<p align='center'>
<b>
<a href="https://arxiv.org/abs/2301.01081">Paper</a>

<a href="https://drive.google.com/file/d/19WRhBHYVWRIH8_zo332l00fLXfUE96-k/view?usp=share_link">Supp. Materials</a>

<a href="https://youtu.be/mO2Tjcwr4u8">Video</a>
</b>
</p>

<p align='center'>
<img src='media/first_page.png' width='700'/>
</p>

The proposed **StyleTalk** can generate talking head videos with speaking styles specified by arbitrary style reference videos.

# News
* April 14th, 2023. The code is available.

# Get Started

## Installation

Clone this repo, install conda, and run:

```bash
conda create -n styletalk python=3.7.0
conda activate styletalk
pip install -r requirements.txt
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge
conda update ffmpeg
```

The code has been tested with CUDA 11.1 on an RTX 3090 GPU.

## Data Preprocessing

Our method takes 3DMM parameters (\*.mat) and phoneme labels (\*_seq.json) as input. Follow [PIRenderer](https://github.com/RenYurui/PIRender) to extract the 3DMM parameters and [AVCT](https://github.com/FuxiVirtualHuman/AAAI22-one-shot-talking-face) to extract the phoneme labels. Some preprocessed data can be found in the `samples` folder.
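As an illustrative sketch (not part of this repository's documented interface), the two preprocessed inputs could be inspected like this before inference; the sample file paths come from `samples`, while the structure of the `.mat` keys and the JSON contents are assumptions:

```python
import json
from scipy.io import loadmat  # assumes scipy is available in the environment

# Hypothetical quick check of the preprocessed inputs (key names are assumptions).
pose_mat = loadmat("samples/source_video/3DMM/reagan_clip1.mat")        # per-frame 3DMM parameters
with open("samples/source_video/phoneme/reagan_clip1_seq.json") as f:
    phoneme_seq = json.load(f)                                          # per-frame phoneme labels

print(sorted(pose_mat.keys()))   # inspect which 3DMM fields the .mat file provides
print(len(phoneme_seq))          # number of phoneme entries in the clip
```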
## Inference

Download the checkpoints for [StyleTalk](https://drive.google.com/file/d/1z54FymEiyPQ0mPGrVePt8GMtDe-E2RmN/view?usp=share_link) and the [Renderer](https://drive.google.com/file/d/1wFAtFQjybKI3hwRWvtcBDl4tpZzlDkja/view?usp=share_link) and put them into `./checkpoints`.

Run the demo:

```bash
python inference_for_demo.py \
    --audio_path samples/source_video/phoneme/reagan_clip1_seq.json \
    --style_clip_path samples/style_clips/3DMM/happyenglish_clip1.mat \
    --pose_path samples/source_video/3DMM/reagan_clip1.mat \
    --src_img_path samples/source_video/image/andrew_clip_1.png \
    --wav_path samples/source_video/wav/reagan_clip1.wav \
    --output_path demo.mp4
```

Change `audio_path`, `style_clip_path`, `pose_path`, `src_img_path`, `wav_path`, and `output_path` to generate more results.

# Acknowledgement

Some code is borrowed from the following projects:
* [AVCT](https://github.com/FuxiVirtualHuman/AAAI22-one-shot-talking-face)
* [PIRenderer](https://github.com/RenYurui/PIRender)
* [Deep3DFaceRecon_pytorch](https://github.com/sicxu/Deep3DFaceRecon_pytorch)
* [Speech Drives Templates](https://github.com/ShenhanQian/SpeechDrivesTemplates)
* [FOMM video preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing)

Thanks for their contributions!
checkpoints/renderer_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a67014839d42d592255c9fc3b3ceecbcd62c27ce0c0a89ed6628292447404242
size 335281551
checkpoints/styletalk_checkpoint.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c3bd52d0080d52440e25de44930b539d21c7f1431102c2811fafb30838e9812e
size 275485145
configs/__pycache__/default.cpython-37.pyc
ADDED
Binary file (1.92 kB)
configs/default.py
ADDED
@@ -0,0 +1,65 @@
from yacs.config import CfgNode as CN


_C = CN()
_C.TAG = "style_id_emotion"
_C.DECODER_TYPE = "DisentangleDecoder"
_C.CONTENT_ENCODER_TYPE = "ContentEncoder"
_C.STYLE_ENCODER_TYPE = "StyleEncoder"
_C.DISCRIMINATOR_TYPE = "Discriminator"


_C.WIN_SIZE = 5
_C.D_MODEL = 256

_C.DATASET = CN()
_C.DATASET.FACE3D_DIM = 64

_C.CONTENT_ENCODER = CN()
_C.CONTENT_ENCODER.d_model = _C.D_MODEL
_C.CONTENT_ENCODER.nhead = 8
_C.CONTENT_ENCODER.num_encoder_layers = 3
_C.CONTENT_ENCODER.dim_feedforward = 4 * _C.D_MODEL
_C.CONTENT_ENCODER.dropout = 0.1
_C.CONTENT_ENCODER.activation = "relu"
_C.CONTENT_ENCODER.normalize_before = False
_C.CONTENT_ENCODER.pos_embed_len = 2 * _C.WIN_SIZE + 1
_C.CONTENT_ENCODER.ph_embed_dim = 128

_C.STYLE_ENCODER = CN()
_C.STYLE_ENCODER.d_model = _C.D_MODEL
_C.STYLE_ENCODER.nhead = 8
_C.STYLE_ENCODER.num_encoder_layers = 3
_C.STYLE_ENCODER.dim_feedforward = 4 * _C.D_MODEL
_C.STYLE_ENCODER.dropout = 0.1
_C.STYLE_ENCODER.activation = "relu"
_C.STYLE_ENCODER.normalize_before = False
_C.STYLE_ENCODER.pos_embed_len = 256
_C.STYLE_ENCODER.aggregate_method = "self_attention_pooling"  # average | self_attention_pooling
# _C.STYLE_ENCODER.input_dim = _C.DATASET.FACE3D_DIM

_C.DECODER = CN()
_C.DECODER.d_model = _C.D_MODEL
_C.DECODER.nhead = 8
_C.DECODER.num_decoder_layers = 3
_C.DECODER.dim_feedforward = 4 * _C.D_MODEL
_C.DECODER.dropout = 0.1
_C.DECODER.activation = "relu"
_C.DECODER.normalize_before = False
_C.DECODER.return_intermediate_dec = False
_C.DECODER.pos_embed_len = 2 * _C.WIN_SIZE + 1
_C.DECODER.network_type = "DynamicFCDecoder"
_C.DECODER.dynamic_K = 8
_C.DECODER.dynamic_ratio = 4
# fmt: off
_C.DECODER.upper_face3d_indices = [6, 8, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
# fmt: on
_C.DECODER.lower_face3d_indices = [0, 1, 2, 3, 4, 5, 7, 9, 10, 11, 12, 13, 14]

_C.INFERENCE = CN()
_C.INFERENCE.CHECKPOINT = ""


def get_cfg_defaults():
    """Get a yacs CfgNode object with default values for my_project."""
    return _C.clone()
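A small usage sketch (not part of the commit): this is the standard yacs pattern for taking the defaults above and overriding individual values, for example the inference checkpoint path; the override values shown here are illustrative only.

```python
from configs.default import get_cfg_defaults

cfg = get_cfg_defaults()                       # clone of the _C tree defined above
cfg.merge_from_list(["INFERENCE.CHECKPOINT",   # override a default from code
                     "checkpoints/styletalk_checkpoint.pth"])
cfg.freeze()                                   # make the config immutable before use

print(cfg.DECODER.network_type)                # "DynamicFCDecoder"
print(cfg.WIN_SIZE, cfg.D_MODEL)               # 5 256
```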
configs/renderer_conf.yaml
ADDED
@@ -0,0 +1,17 @@
common:
  descriptor_nc: 256
  image_nc: 3
  max_nc: 256
  use_spect: false
editing_net:
  base_nc: 64
  layer: 3
  num_res_blocks: 2
mapping_net:
  coeff_nc: 73
  descriptor_nc: 256
  layer: 3
warpping_net:
  base_nc: 32
  decoder_layer: 3
  encoder_layer: 5
core/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (6.26 kB)
core/networks/__init__.py
ADDED
@@ -0,0 +1,9 @@
from core.networks.generator import ContentEncoder, StyleEncoder, Decoder
from core.networks.disentangle_decoder import DisentangleDecoder


def get_network(name: str):
    obj = globals().get(name)
    if obj is None:
        raise KeyError("Unknown Network: %s" % name)
    else:
        return obj
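For illustration (an assumption about intended usage, not code in this commit), `get_network` simply resolves a class name from this module's globals, which is how the type strings in `configs/default.py` are turned into network classes:

```python
from core.networks import get_network

decoder_cls = get_network("DisentangleDecoder")   # looks the name up in core.networks' globals
decoder = decoder_cls(d_model=256, nhead=8, network_type="DynamicFCDecoder",
                      dynamic_K=8, dynamic_ratio=4)

# get_network("NoSuchNetwork")  # would raise KeyError("Unknown Network: NoSuchNetwork")
```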
core/networks/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (533 Bytes)
core/networks/__pycache__/disentangle_decoder.cpython-37.pyc
ADDED
Binary file (3.17 kB)
core/networks/__pycache__/dynamic_conv.cpython-37.pyc
ADDED
Binary file (3.78 kB)
core/networks/__pycache__/dynamic_fc_decoder.cpython-37.pyc
ADDED
Binary file (3.16 kB)
core/networks/__pycache__/dynamic_linear.cpython-37.pyc
ADDED
Binary file (1.3 kB)
core/networks/__pycache__/generator.cpython-37.pyc
ADDED
Binary file (4.85 kB)
core/networks/__pycache__/mish.cpython-37.pyc
ADDED
Binary file (1.7 kB)
core/networks/__pycache__/self_attention_pooling.cpython-37.pyc
ADDED
Binary file (1.61 kB)
core/networks/__pycache__/styletalk.cpython-37.pyc
ADDED
Binary file (998 Bytes)
core/networks/__pycache__/transformer.cpython-37.pyc
ADDED
Binary file (9.3 kB)
core/networks/building_blocks.py
ADDED
@@ -0,0 +1,112 @@
from torch import nn


class ADAIN(nn.Module):
    def __init__(self, content_nc, condition_nc, hidden_nc):
        super().__init__()

        self.param_free_norm = nn.InstanceNorm1d(content_nc, affine=False)

        use_bias = True

        self.mlp_shared = nn.Sequential(
            nn.Linear(condition_nc, hidden_nc, bias=use_bias),
            nn.ReLU(),
        )
        self.mlp_gamma = nn.Linear(hidden_nc, content_nc, bias=use_bias)
        self.mlp_beta = nn.Linear(hidden_nc, content_nc, bias=use_bias)

    def forward(self, content, condition):
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(content)

        # Part 2. produce scaling and bias conditioned on feature
        actv = self.mlp_shared(condition)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        gamma = gamma.unsqueeze(-1)
        beta = beta.unsqueeze(-1)
        out = normalized * (1 + gamma) + beta
        return out


class ConvNormRelu(nn.Module):
    def __init__(
        self,
        conv_type="1d",
        in_channels=3,
        out_channels=64,
        downsample=False,
        kernel_size=None,
        stride=None,
        padding=None,
        norm="IN",
        leaky=False,
        adain_condition_nc=None,
        adain_hidden_nc=None,
    ):
        super().__init__()
        if kernel_size is None:
            if downsample:
                kernel_size, stride, padding = 4, 2, 1
            else:
                kernel_size, stride, padding = 3, 1, 1

        if conv_type == "1d":
            self.conv = nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            )
        if norm == "IN":
            self.norm = nn.InstanceNorm1d(out_channels, affine=True)
        elif norm == "ADAIN":
            self.norm = ADAIN(out_channels, adain_condition_nc, adain_hidden_nc)
        elif norm == "NONE":
            self.norm = nn.Identity()
        else:
            raise NotImplementedError
        nn.init.kaiming_normal_(self.conv.weight)

        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True) if leaky else nn.ReLU(inplace=True)

    def forward(self, x, condition=None):
        """

        Args:
            x (_type_): (B, C, L)
            condition (_type_, optional): (B, C)

        Returns:
            _type_: _description_
        """
        x = self.conv(x)
        if isinstance(self.norm, ADAIN):
            x = self.norm(x, condition)
        else:
            x = self.norm(x)
        x = self.act(x)
        return x


class MyConv1d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv1d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm1d(cout),
        )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act(out)
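A shape-level sketch of how `ADAIN` and `ConvNormRelu` above are meant to be called; the tensor sizes are illustrative assumptions, not values taken from this commit:

```python
import torch
from core.networks.building_blocks import ADAIN, ConvNormRelu

feat = torch.randn(2, 64, 50)          # (B, C, L) content features
cond = torch.randn(2, 256)             # (B, C_style) conditioning vector

adain = ADAIN(content_nc=64, condition_nc=256, hidden_nc=128)
out = adain(feat, cond)                # (2, 64, 50): instance-normed, then scaled/shifted per style

block = ConvNormRelu(conv_type="1d", in_channels=64, out_channels=64,
                     norm="ADAIN", adain_condition_nc=256, adain_hidden_nc=128)
out = block(feat, condition=cond)      # conv -> ADAIN -> ReLU, same (B, C, L) layout
```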
core/networks/disentangle_decoder.py
ADDED
@@ -0,0 +1,184 @@
import torch
from torch import nn

from .transformer import (
    PositionalEncoding,
    TransformerDecoderLayer,
    TransformerDecoder,
)
from core.networks.dynamic_fc_decoder import DynamicFCDecoderLayer, DynamicFCDecoder
from core.utils import _reset_parameters


def get_decoder_network(
    network_type,
    d_model,
    nhead,
    dim_feedforward,
    dropout,
    activation,
    normalize_before,
    num_decoder_layers,
    return_intermediate_dec,
    dynamic_K,
    dynamic_ratio,
):
    decoder = None
    if network_type == "TransformerDecoder":
        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
        norm = nn.LayerNorm(d_model)
        decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            norm,
            return_intermediate_dec,
        )
    elif network_type == "DynamicFCDecoder":
        d_style = d_model
        decoder_layer = DynamicFCDecoderLayer(
            d_model,
            nhead,
            d_style,
            dynamic_K,
            dynamic_ratio,
            dim_feedforward,
            dropout,
            activation,
            normalize_before,
        )
        norm = nn.LayerNorm(d_model)
        decoder = DynamicFCDecoder(decoder_layer, num_decoder_layers, norm, return_intermediate_dec)
    else:
        raise ValueError(f"Invalid network_type {network_type}")

    return decoder


class DisentangleDecoder(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_decoder_layers=3,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        pos_embed_len=80,
        upper_face3d_indices=tuple(list(range(19)) + list(range(46, 51))),
        lower_face3d_indices=tuple(range(19, 46)),
        network_type="None",
        dynamic_K=None,
        dynamic_ratio=None,
        **_,
    ) -> None:
        super().__init__()

        self.upper_face3d_indices = upper_face3d_indices
        self.lower_face3d_indices = lower_face3d_indices

        self.upper_decoder = get_decoder_network(
            network_type,
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            activation,
            normalize_before,
            num_decoder_layers,
            return_intermediate_dec,
            dynamic_K,
            dynamic_ratio,
        )
        _reset_parameters(self.upper_decoder)

        self.lower_decoder = get_decoder_network(
            network_type,
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            activation,
            normalize_before,
            num_decoder_layers,
            return_intermediate_dec,
            dynamic_K,
            dynamic_ratio,
        )
        _reset_parameters(self.lower_decoder)

        self.pos_embed = PositionalEncoding(d_model, pos_embed_len)

        tail_hidden_dim = d_model // 2
        self.upper_tail_fc = nn.Sequential(
            nn.Linear(d_model, tail_hidden_dim),
            nn.ReLU(),
            nn.Linear(tail_hidden_dim, tail_hidden_dim),
            nn.ReLU(),
            nn.Linear(tail_hidden_dim, len(upper_face3d_indices)),
        )
        self.lower_tail_fc = nn.Sequential(
            nn.Linear(d_model, tail_hidden_dim),
            nn.ReLU(),
            nn.Linear(tail_hidden_dim, tail_hidden_dim),
            nn.ReLU(),
            nn.Linear(tail_hidden_dim, len(lower_face3d_indices)),
        )

    def forward(self, content, style_code):
        """

        Args:
            content (_type_): (B, num_frames, window, C_dmodel)
            style_code (_type_): (B, C_dmodel)

        Returns:
            face3d: (B, L_clip, C_3dmm)
        """
        B, N, W, C = content.shape
        style = style_code.reshape(B, 1, 1, C).expand(B, N, W, C)
        style = style.permute(2, 0, 1, 3).reshape(W, B * N, C)
        # (W, B*N, C)

        content = content.permute(2, 0, 1, 3).reshape(W, B * N, C)
        # (W, B*N, C)
        tgt = torch.zeros_like(style)
        pos_embed = self.pos_embed(W)
        pos_embed = pos_embed.permute(1, 0, 2)

        upper_face3d_feat = self.upper_decoder(tgt, content, pos=pos_embed, query_pos=style)[0]
        # (W, B*N, C)
        upper_face3d_feat = upper_face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[:, :, W // 2, :]
        # (B, N, C)
        upper_face3d = self.upper_tail_fc(upper_face3d_feat)
        # (B, N, C_exp)

        lower_face3d_feat = self.lower_decoder(tgt, content, pos=pos_embed, query_pos=style)[0]
        lower_face3d_feat = lower_face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[:, :, W // 2, :]
        lower_face3d = self.lower_tail_fc(lower_face3d_feat)
        C_exp = len(self.upper_face3d_indices) + len(self.lower_face3d_indices)
        face3d = torch.zeros(B, N, C_exp).to(upper_face3d)
        face3d[:, :, self.upper_face3d_indices] = upper_face3d
        face3d[:, :, self.lower_face3d_indices] = lower_face3d
        return face3d
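For reference, an input/output shape sketch for `DisentangleDecoder` (dimensions follow `configs/default.py`; batch size, frame count, and the upper/lower index split are illustrative assumptions):

```python
import torch
from core.networks.disentangle_decoder import DisentangleDecoder

decoder = DisentangleDecoder(
    d_model=256, nhead=8, num_decoder_layers=3, dim_feedforward=1024,
    pos_embed_len=11,                            # 2 * WIN_SIZE + 1
    network_type="DynamicFCDecoder", dynamic_K=8, dynamic_ratio=4,
    upper_face3d_indices=list(range(13, 64)),    # illustrative split into 51 + 13 coefficients
    lower_face3d_indices=list(range(0, 13)),
)

content = torch.randn(2, 32, 11, 256)            # (B, num_frames, window, d_model) from ContentEncoder
style_code = torch.randn(2, 256)                 # (B, d_model) from StyleEncoder
face3d = decoder(content, style_code)            # (2, 32, 64) predicted 3DMM expression parameters
```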
core/networks/dynamic_conv.py
ADDED
@@ -0,0 +1,149 @@
import math

import torch
from torch import nn
from torch.nn import functional as F


class Attention(nn.Module):
    def __init__(self, cond_planes, ratio, K, temperature=30, init_weight=True):
        super().__init__()
        # self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.temprature = temperature
        assert cond_planes > ratio
        hidden_planes = cond_planes // ratio
        self.net = nn.Sequential(
            nn.Conv2d(cond_planes, hidden_planes, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden_planes, K, kernel_size=1, bias=False),
        )

        if init_weight:
            self._initialize_weights()

    def update_temprature(self):
        if self.temprature > 1:
            self.temprature -= 1

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, cond):
        """

        Args:
            cond (_type_): (B, C_style)

        Returns:
            _type_: (B, K)
        """
        # att = self.avgpool(cond)  # bs,dim,1,1
        att = cond.view(cond.shape[0], cond.shape[1], 1, 1)
        att = self.net(att).view(cond.shape[0], -1)  # bs,K
        return F.softmax(att / self.temprature, -1)


class DynamicConv(nn.Module):
    def __init__(
        self,
        in_planes,
        out_planes,
        cond_planes,
        kernel_size,
        stride,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        K=4,
        temperature=30,
        ratio=4,
        init_weight=True,
    ):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.cond_planes = cond_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.K = K
        self.init_weight = init_weight
        self.attention = Attention(
            cond_planes=cond_planes, ratio=ratio, K=K, temperature=temperature, init_weight=init_weight
        )

        self.weight = nn.Parameter(
            torch.randn(K, out_planes, in_planes // groups, kernel_size, kernel_size), requires_grad=True
        )
        if bias:
            self.bias = nn.Parameter(torch.randn(K, out_planes), requires_grad=True)
        else:
            self.bias = None

        if self.init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for i in range(self.K):
            nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5))
            if self.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i])
                if fan_in != 0:
                    bound = 1 / math.sqrt(fan_in)
                    nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, cond):
        """

        Args:
            x (_type_): (B, C_in, L, 1)
            cond (_type_): (B, C_style)

        Returns:
            _type_: (B, C_out, L, 1)
        """
        bs, in_planels, h, w = x.shape
        softmax_att = self.attention(cond)  # bs,K
        x = x.view(1, -1, h, w)
        weight = self.weight.view(self.K, -1)  # K,-1
        aggregate_weight = torch.mm(softmax_att, weight).view(
            bs * self.out_planes, self.in_planes // self.groups, self.kernel_size, self.kernel_size
        )  # bs*out_p,in_p,k,k

        if self.bias is not None:
            bias = self.bias.view(self.K, -1)  # K,out_p
            aggregate_bias = torch.mm(softmax_att, bias).view(-1)  # bs*out_p
            output = F.conv2d(
                x,  # 1, bs*in_p, L, 1
                weight=aggregate_weight,
                bias=aggregate_bias,
                stride=self.stride,
                padding=self.padding,
                groups=self.groups * bs,
                dilation=self.dilation,
            )
        else:
            output = F.conv2d(
                x,
                weight=aggregate_weight,
                bias=None,
                stride=self.stride,
                padding=self.padding,
                groups=self.groups * bs,
                dilation=self.dilation,
            )

        output = output.view(bs, self.out_planes, h, w)
        return output
core/networks/dynamic_fc_decoder.py
ADDED
@@ -0,0 +1,140 @@
import torch.nn as nn
import torch

from core.networks.transformer import _get_activation_fn, _get_clones
from core.networks.dynamic_linear import DynamicLinear


class DynamicFCDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        d_style,
        dynamic_K,
        dynamic_ratio,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        # self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear1 = DynamicLinear(d_model, dim_feedforward, d_style, K=dynamic_K, ratio=dynamic_ratio)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # self.linear2 = DynamicLinear(dim_feedforward, d_model, d_style, K=dynamic_K, ratio=dynamic_ratio)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        style,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        # q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(
            query=tgt, key=memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt, style))), style)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt, style))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        style,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        if self.normalize_before:
            raise NotImplementedError

        return self.forward_post(
            tgt, memory, style, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
        )


class DynamicFCDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        style = query_pos[0]
        # (B*N, C)
        output = tgt + pos + query_pos

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                style,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)
core/networks/dynamic_linear.py
ADDED
@@ -0,0 +1,42 @@
import math

import torch
from torch import nn
from torch.nn import functional as F

from core.networks.dynamic_conv import DynamicConv


class DynamicLinear(nn.Module):
    def __init__(self, in_planes, out_planes, cond_planes, bias=True, K=4, temperature=30, ratio=4, init_weight=True):
        super().__init__()

        self.dynamic_conv = DynamicConv(
            in_planes,
            out_planes,
            cond_planes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
            K=K,
            ratio=ratio,
            temperature=temperature,
            init_weight=init_weight,
        )

    def forward(self, x, cond):
        """

        Args:
            x (_type_): (L, B, C_in)
            cond (_type_): (B, C_style)

        Returns:
            _type_: (L, B, C_out)
        """
        x = x.permute(1, 2, 0).unsqueeze(-1)
        out = self.dynamic_conv(x, cond)
        # (B, C_out, L, 1)
        out = out.squeeze().permute(2, 0, 1)
        return out
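A quick shape check for `DynamicLinear` (the sizes are illustrative): it behaves like a per-sample linear layer whose weights are a style-conditioned softmax mixture of `K` candidate kernels, applied over a length-first sequence.

```python
import torch
from core.networks.dynamic_linear import DynamicLinear

layer = DynamicLinear(in_planes=256, out_planes=1024, cond_planes=256, K=8, ratio=4)

x = torch.randn(11, 4, 256)      # (L, B, C_in) token-first layout, as used inside the decoder layers
cond = torch.randn(4, 256)       # (B, C_style) style code that selects the kernel mixture
y = layer(x, cond)               # (11, 4, 1024)
```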
core/networks/generator.py
ADDED
@@ -0,0 +1,213 @@
import torch
from torch import nn

from .transformer import (
    TransformerEncoder,
    TransformerEncoderLayer,
    PositionalEncoding,
    TransformerDecoderLayer,
    TransformerDecoder,
)
from core.utils import _reset_parameters
from core.networks.self_attention_pooling import SelfAttentionPooling


class ContentEncoder(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        pos_embed_len=80,
        ph_embed_dim=128,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        _reset_parameters(self.encoder)

        self.pos_embed = PositionalEncoding(d_model, pos_embed_len)

        self.ph_embedding = nn.Embedding(41, ph_embed_dim)
        self.increase_embed_dim = nn.Linear(ph_embed_dim, d_model)

    def forward(self, x):
        """

        Args:
            x (_type_): (B, num_frames, window)

        Returns:
            content: (B, num_frames, window, C_dmodel)
        """
        x_embedding = self.ph_embedding(x)
        x_embedding = self.increase_embed_dim(x_embedding)
        # (B, N, W, C)
        B, N, W, C = x_embedding.shape
        x_embedding = x_embedding.reshape(B * N, W, C)
        x_embedding = x_embedding.permute(1, 0, 2)
        # (W, B*N, C)

        pos = self.pos_embed(W)
        pos = pos.permute(1, 0, 2)
        # (W, 1, C)

        content = self.encoder(x_embedding, pos=pos)
        # (W, B*N, C)
        content = content.permute(1, 0, 2).reshape(B, N, W, C)
        # (B, N, W, C)

        return content


class StyleEncoder(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        pos_embed_len=80,
        input_dim=128,
        aggregate_method="average",
    ):
        super().__init__()
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
        _reset_parameters(self.encoder)

        self.pos_embed = PositionalEncoding(d_model, pos_embed_len)

        self.increase_embed_dim = nn.Linear(input_dim, d_model)

        self.aggregate_method = None
        if aggregate_method == "self_attention_pooling":
            self.aggregate_method = SelfAttentionPooling(d_model)
        elif aggregate_method == "average":
            pass
        else:
            raise ValueError(f"Invalid aggregate method {aggregate_method}")

    def forward(self, x, pad_mask=None):
        """

        Args:
            x (_type_): (B, num_frames(L), C_exp)
            pad_mask: (B, num_frames)

        Returns:
            style_code: (B, C_model)
        """
        x = self.increase_embed_dim(x)
        # (B, L, C)
        x = x.permute(1, 0, 2)
        # (L, B, C)

        pos = self.pos_embed(x.shape[0])
        pos = pos.permute(1, 0, 2)
        # (L, 1, C)

        style = self.encoder(x, pos=pos, src_key_padding_mask=pad_mask)
        # (L, B, C)

        if self.aggregate_method is not None:
            permute_style = style.permute(1, 0, 2)
            # (B, L, C)
            style_code = self.aggregate_method(permute_style, pad_mask)
            return style_code

        if pad_mask is None:
            style = style.permute(1, 2, 0)
            # (B, C, L)
            style_code = style.mean(2)
            # (B, C)
        else:
            permute_style = style.permute(1, 0, 2)
            # (B, L, C)
            permute_style[pad_mask] = 0
            sum_style_code = permute_style.sum(dim=1)
            # (B, C)
            valid_token_num = (~pad_mask).sum(dim=1).unsqueeze(-1)
            # (B, 1)
            style_code = sum_style_code / valid_token_num
            # (B, C)

        return style_code


class Decoder(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_decoder_layers=3,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        pos_embed_len=80,
        output_dim=64,
        **_,
    ) -> None:
        super().__init__()

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )
        _reset_parameters(self.decoder)

        self.pos_embed = PositionalEncoding(d_model, pos_embed_len)

        tail_hidden_dim = d_model // 2
        self.tail_fc = nn.Sequential(
            nn.Linear(d_model, tail_hidden_dim),
            nn.ReLU(),
            nn.Linear(tail_hidden_dim, tail_hidden_dim),
            nn.ReLU(),
            nn.Linear(tail_hidden_dim, output_dim),
        )

    def forward(self, content, style_code):
        """

        Args:
            content (_type_): (B, num_frames, window, C_dmodel)
            style_code (_type_): (B, C_dmodel)

        Returns:
            face3d: (B, num_frames, C_3dmm)
        """
        B, N, W, C = content.shape
        style = style_code.reshape(B, 1, 1, C).expand(B, N, W, C)
        style = style.permute(2, 0, 1, 3).reshape(W, B * N, C)
        # (W, B*N, C)

        content = content.permute(2, 0, 1, 3).reshape(W, B * N, C)
        # (W, B*N, C)
        tgt = torch.zeros_like(style)
        pos_embed = self.pos_embed(W)
        pos_embed = pos_embed.permute(1, 0, 2)
        face3d_feat = self.decoder(tgt, content, pos=pos_embed, query_pos=style)[0]
        # (W, B*N, C)
        face3d_feat = face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[:, :, W // 2, :]
        # (B, N, C)
        face3d = self.tail_fc(face3d_feat)
        # (B, N, C_exp)
        return face3d
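A shape-level sketch of the two encoders above, using the dimensions from `configs/default.py`; the batch and frame counts are arbitrary, and phoneme IDs are assumed to lie in `[0, 41)`.

```python
import torch
from core.networks.generator import ContentEncoder, StyleEncoder

content_enc = ContentEncoder(d_model=256, nhead=8, num_encoder_layers=3,
                             dim_feedforward=1024, pos_embed_len=11, ph_embed_dim=128)
style_enc = StyleEncoder(d_model=256, nhead=8, num_encoder_layers=3,
                         dim_feedforward=1024, pos_embed_len=256, input_dim=64,
                         aggregate_method="self_attention_pooling")

phonemes = torch.randint(0, 41, (2, 32, 11))   # (B, num_frames, window) phoneme IDs
style_clip = torch.randn(2, 100, 64)           # (B, num_frames, 64) 3DMM expression parameters

content = content_enc(phonemes)                # (2, 32, 11, 256)
style_code = style_enc(style_clip)             # (2, 256)
```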
core/networks/mish.py
ADDED
@@ -0,0 +1,51 @@
"""
Applies the mish function element-wise:
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
"""

# import pytorch
import torch
import torch.nn.functional as F
from torch import nn


@torch.jit.script
def mish(input):
    """
    Applies the mish function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    See additional documentation for mish class.
    """
    return input * torch.tanh(F.softplus(input))


class Mish(nn.Module):
    """
    Applies the mish function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input

    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)

    Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
    """

    def __init__(self):
        """
        Init method.
        """
        super().__init__()

    def forward(self, input):
        """
        Forward pass of the function.
        """
        if torch.__version__ >= "1.9":
            return F.mish(input)
        else:
            return mish(input)
core/networks/self_attention_pooling.py
ADDED
@@ -0,0 +1,43 @@
import torch
import torch.nn as nn
from core.networks.mish import Mish


class SelfAttentionPooling(nn.Module):
    """
    Implementation of SelfAttentionPooling
    Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
    https://arxiv.org/pdf/2008.01077v1.pdf
    """

    def __init__(self, input_dim):
        super(SelfAttentionPooling, self).__init__()
        self.W = nn.Sequential(nn.Linear(input_dim, input_dim), Mish(), nn.Linear(input_dim, 1))
        self.softmax = nn.functional.softmax

    def forward(self, batch_rep, att_mask=None):
        """
        N: batch size, T: sequence length, H: Hidden dimension
        input:
            batch_rep : size (N, T, H)
        attention_weight:
            att_w : size (N, T, 1)
        att_mask:
            att_mask: size (N, T): if True, mask this item.
        return:
            utter_rep: size (N, H)
        """
        att_logits = self.W(batch_rep).squeeze(-1)
        # (N, T)
        if att_mask is not None:
            att_mask_logits = att_mask.to(dtype=batch_rep.dtype) * -100000.0
            # (N, T)
            att_logits = att_mask_logits + att_logits

        att_w = self.softmax(att_logits, dim=-1).unsqueeze(-1)
        utter_rep = torch.sum(batch_rep * att_w, dim=1)

        return utter_rep
core/networks/styletalk.py
ADDED
@@ -0,0 +1,24 @@
import torch.nn as nn

from core.networks import get_network


class StyleTalk(nn.Module):
    def __init__(self, cfg) -> None:
        super().__init__()
        self.cfg = cfg

        content_encoder_class = get_network(cfg.CONTENT_ENCODER_TYPE)
        self.content_encoder = content_encoder_class(**cfg.CONTENT_ENCODER)

        style_encoder_class = get_network(cfg.STYLE_ENCODER_TYPE)
        cfg.defrost()
        cfg.STYLE_ENCODER.input_dim = cfg.DATASET.FACE3D_DIM
        cfg.freeze()
        self.style_encoder = style_encoder_class(**cfg.STYLE_ENCODER)

        decoder_class = get_network(cfg.DECODER_TYPE)
        cfg.defrost()
        cfg.DECODER.output_dim = cfg.DATASET.FACE3D_DIM
        cfg.freeze()
        self.decoder = decoder_class(**cfg.DECODER)
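A minimal assembly sketch, as an assumption about intended usage (what `inference_for_demo.py` presumably wires up): build the model from the default config and run the three sub-networks end to end with dummy tensors.

```python
import torch
from configs.default import get_cfg_defaults
from core.networks.styletalk import StyleTalk

cfg = get_cfg_defaults()
model = StyleTalk(cfg).eval()

phonemes = torch.randint(0, 41, (1, 32, 11))               # (B, num_frames, 2*WIN_SIZE+1)
style_clip = torch.randn(1, 200, cfg.DATASET.FACE3D_DIM)   # (B, L_style, 64) reference 3DMM clip

with torch.no_grad():
    content = model.content_encoder(phonemes)      # (1, 32, 11, 256)
    style_code = model.style_encoder(style_clip)   # (1, 256)
    face3d = model.decoder(content, style_code)    # (1, 32, 64) stylized expression parameters
```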
core/networks/transformer.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import copy
|
6 |
+
|
7 |
+
|
8 |
+
class PositionalEncoding(nn.Module):
|
9 |
+
def __init__(self, d_hid, n_position=200):
|
10 |
+
super(PositionalEncoding, self).__init__()
|
11 |
+
|
12 |
+
# Not a parameter
|
13 |
+
self.register_buffer("pos_table", self._get_sinusoid_encoding_table(n_position, d_hid))
|
14 |
+
|
15 |
+
def _get_sinusoid_encoding_table(self, n_position, d_hid):
|
16 |
+
"""Sinusoid position encoding table"""
|
17 |
+
# TODO: make it with torch instead of numpy
|
18 |
+
|
19 |
+
def get_position_angle_vec(position):
|
20 |
+
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
21 |
+
|
22 |
+
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
23 |
+
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
24 |
+
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
25 |
+
|
26 |
+
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
27 |
+
|
28 |
+
def forward(self, winsize):
|
29 |
+
return self.pos_table[:, :winsize].clone().detach()
|
30 |
+
|
31 |
+
|
32 |
+
def _get_activation_fn(activation):
|
33 |
+
"""Return an activation function given a string"""
|
34 |
+
if activation == "relu":
|
35 |
+
return F.relu
|
36 |
+
if activation == "gelu":
|
37 |
+
return F.gelu
|
38 |
+
if activation == "glu":
|
39 |
+
return F.glu
|
40 |
+
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
41 |
+
|
42 |
+
|
43 |
+
def _get_clones(module, N):
|
44 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
45 |
+
|
46 |
+
|
47 |
+
class Transformer(nn.Module):
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
d_model=512,
|
51 |
+
nhead=8,
|
52 |
+
num_encoder_layers=6,
|
53 |
+
num_decoder_layers=6,
|
54 |
+
dim_feedforward=2048,
|
55 |
+
dropout=0.1,
|
56 |
+
activation="relu",
|
57 |
+
normalize_before=False,
|
58 |
+
return_intermediate_dec=True,
|
59 |
+
):
|
60 |
+
super().__init__()
|
61 |
+
|
62 |
+
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
|
63 |
+
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
64 |
+
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
65 |
+
|
66 |
+
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
|
67 |
+
decoder_norm = nn.LayerNorm(d_model)
|
68 |
+
self.decoder = TransformerDecoder(
|
69 |
+
decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec
|
70 |
+
)
|
71 |
+
|
72 |
+
self._reset_parameters()
|
73 |
+
|
74 |
+
self.d_model = d_model
|
75 |
+
self.nhead = nhead
|
76 |
+
|
77 |
+
def _reset_parameters(self):
|
78 |
+
for p in self.parameters():
|
79 |
+
if p.dim() > 1:
|
80 |
+
nn.init.xavier_uniform_(p)
|
81 |
+
|
82 |
+
def forward(self, opt, src, query_embed, pos_embed):
|
83 |
+
# flatten NxCxHxW to HWxNxC
|
84 |
+
|
85 |
+
src = src.permute(1, 0, 2)
|
86 |
+
pos_embed = pos_embed.permute(1, 0, 2)
|
87 |
+
query_embed = query_embed.permute(1, 0, 2)
|
88 |
+
|
89 |
+
tgt = torch.zeros_like(query_embed)
|
90 |
+
memory = self.encoder(src, pos=pos_embed)
|
91 |
+
|
92 |
+
hs = self.decoder(tgt, memory, pos=pos_embed, query_pos=query_embed)
|
93 |
+
return hs
|
94 |
+
|
95 |
+
|
96 |
+
class TransformerEncoder(nn.Module):
|
97 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
98 |
+
super().__init__()
|
99 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
100 |
+
self.num_layers = num_layers
|
101 |
+
self.norm = norm
|
102 |
+
|
103 |
+
def forward(self, src, mask=None, src_key_padding_mask=None, pos=None):
|
104 |
+
output = src + pos
|
105 |
+
|
106 |
+
for layer in self.layers:
|
107 |
+
output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
|
108 |
+
|
109 |
+
if self.norm is not None:
|
110 |
+
output = self.norm(output)
|
111 |
+
|
112 |
+
return output
|
113 |
+
|
114 |
+
|
115 |
+
class TransformerDecoder(nn.Module):
|
116 |
+
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
117 |
+
super().__init__()
|
118 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
119 |
+
self.num_layers = num_layers
|
120 |
+
self.norm = norm
|
121 |
+
self.return_intermediate = return_intermediate
|
122 |
+
|
123 |
+
def forward(
|
124 |
+
self,
|
125 |
+
tgt,
|
126 |
+
memory,
|
127 |
+
tgt_mask=None,
|
128 |
+
memory_mask=None,
|
129 |
+
tgt_key_padding_mask=None,
|
130 |
+
memory_key_padding_mask=None,
|
131 |
+
        pos=None,
        query_pos=None,
    ):
        output = tgt + pos + query_pos

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        # q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(src, src, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        src2 = self.norm1(src)
        # q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(src2, src2, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        # q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(
            query=tgt, key=memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        tgt2 = self.norm1(tgt)
        # q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(
            query=tgt2, key=memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        if self.normalize_before:
            return self.forward_pre(
                tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
            )
        return self.forward_post(
            tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
        )
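Editorial note on the layers above: normalize_before switches each layer between post-norm (forward_post, attention then residual then LayerNorm) and pre-norm (forward_pre, LayerNorm first) residual ordering. A minimal usage sketch, not part of the commit, assuming the class is importable from core.networks.transformer:

# Usage sketch (assumption: TransformerEncoderLayer is importable from core.networks.transformer)
import torch
from core.networks.transformer import TransformerEncoderLayer

post_norm = TransformerEncoderLayer(d_model=256, nhead=8, normalize_before=False)  # post-norm variant
pre_norm = TransformerEncoderLayer(d_model=256, nhead=8, normalize_before=True)    # pre-norm variant
src = torch.randn(75, 2, 256)   # (sequence_length, batch, d_model), the layout nn.MultiheadAttention expects
out_post = post_norm(src)       # same shape as src
out_pre = pre_norm(src)         # same shape as src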
core/utils.py
ADDED
@@ -0,0 +1,228 @@
import os
import argparse
from collections import defaultdict
import logging

import numpy as np
import torch
from torch import nn
from scipy.io import loadmat

from configs.default import get_cfg_defaults


def _reset_parameters(model):
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)


def get_video_style(video_name, style_type):
    person_id, direction, emotion, level, *_ = video_name.split("_")
    if style_type == "id_dir_emo_level":
        style = "_".join([person_id, direction, emotion, level])
    elif style_type == "emotion":
        style = emotion
    else:
        raise ValueError("Unknown style type")

    return style


def get_style_video_lists(video_list, style_type):
    style2video_list = defaultdict(list)
    for video in video_list:
        style = get_video_style(video, style_type)
        style2video_list[style].append(video)

    return style2video_list


def get_face3d_clip(video_name, video_root_dir, num_frames, start_idx, dtype=torch.float32):
    """Load a clip of 3DMM expression parameters from a .mat or .txt sequence.

    Args:
        video_name (str): file name of the 3DMM sequence inside video_root_dir
        video_root_dir (str): directory containing the 3DMM files
        num_frames (int): number of frames in the clip
        start_idx: "random", "middle", or an int frame index
        dtype (torch.dtype, optional): output dtype. Defaults to torch.float32.

    Raises:
        ValueError: if the file extension is not .mat or .txt
        ValueError: if start_idx is not "random", "middle", or an int

    Returns:
        torch.Tensor: expression clip of shape (num_frames, exp_dim)
    """
    video_path = os.path.join(video_root_dir, video_name)
    if video_path[-3:] == "mat":
        face3d_all = loadmat(video_path)["coeff"]
        face3d_exp = face3d_all[:, 80:144]  # expression 3DMM range
    elif video_path[-3:] == "txt":
        face3d_exp = np.loadtxt(video_path)
    else:
        raise ValueError("Invalid 3DMM file extension")

    length = face3d_exp.shape[0]
    clip_num_frames = num_frames
    if start_idx == "random":
        clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
    elif start_idx == "middle":
        clip_start_idx = (length - clip_num_frames + 1) // 2
    elif isinstance(start_idx, int):
        clip_start_idx = start_idx
    else:
        raise ValueError(f"Invalid start_idx {start_idx}")

    face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
    face3d_clip = torch.tensor(face3d_clip, dtype=dtype)

    return face3d_clip


def get_video_style_clip(video_path, style_max_len, start_idx="random", dtype=torch.float32):
    if video_path[-3:] == "mat":
        face3d_all = loadmat(video_path)["coeff"]
        face3d_exp = face3d_all[:, 80:144]  # expression 3DMM range
    elif video_path[-3:] == "txt":
        face3d_exp = np.loadtxt(video_path)
    else:
        raise ValueError("Invalid 3DMM file extension")

    face3d_exp = torch.tensor(face3d_exp, dtype=dtype)

    length = face3d_exp.shape[0]
    if length >= style_max_len:
        clip_num_frames = style_max_len
        if start_idx == "random":
            clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
        elif start_idx == "middle":
            clip_start_idx = (length - clip_num_frames + 1) // 2
        elif isinstance(start_idx, int):
            clip_start_idx = start_idx
        else:
            raise ValueError(f"Invalid start_idx {start_idx}")

        face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
        pad_mask = torch.tensor([False] * style_max_len)
    else:
        padding = torch.zeros(style_max_len - length, face3d_exp.shape[1])
        face3d_clip = torch.cat((face3d_exp, padding), dim=0)
        pad_mask = torch.tensor([False] * length + [True] * (style_max_len - length))

    return face3d_clip, pad_mask


def get_audio_name_from_video(video_name):
    audio_name = video_name[:-4] + "_seq.json"
    return audio_name


def get_audio_window(audio, win_size):
    """Build a sliding phoneme window for each frame, padding out-of-range frames with silence (31).

    Args:
        audio (numpy.ndarray): (N,) per-frame phoneme indices

    Returns:
        audio_wins (numpy.ndarray): (N, W) with W = 2 * win_size + 1
    """
    num_frames = len(audio)
    ph_frames = []
    for rid in range(0, num_frames):
        ph = []
        for i in range(rid - win_size, rid + win_size + 1):
            if i < 0:
                ph.append(31)
            elif i >= num_frames:
                ph.append(31)
            else:
                ph.append(audio[i])

        ph_frames.append(ph)

    audio_wins = np.array(ph_frames)

    return audio_wins


def setup_config():
    parser = argparse.ArgumentParser(description="voice2pose main program")
    parser.add_argument("--config_file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--resume_from", type=str, default=None, help="the checkpoint to resume from")
    parser.add_argument("--test_only", action="store_true", help="perform testing and evaluation only")
    parser.add_argument("--demo_input", type=str, default=None, help="path to input for demo")
    parser.add_argument("--checkpoint", type=str, default=None, help="the checkpoint to test with")
    parser.add_argument("--tag", type=str, default="", help="tag for the experiment")
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        help="local rank for DistributedDataParallel",
    )
    parser.add_argument(
        "--master_port",
        type=str,
        default="12345",
    )
    args = parser.parse_args()

    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return args, cfg


def setup_logger(base_path, exp_name):
    rootLogger = logging.getLogger()
    rootLogger.setLevel(logging.INFO)

    logFormatter = logging.Formatter("%(asctime)s [%(levelname)-0.5s] %(message)s")

    log_path = "{0}/{1}.log".format(base_path, exp_name)
    fileHandler = logging.FileHandler(log_path)
    fileHandler.setFormatter(logFormatter)
    rootLogger.addHandler(fileHandler)

    consoleHandler = logging.StreamHandler()
    consoleHandler.setFormatter(logFormatter)
    rootLogger.addHandler(consoleHandler)
    rootLogger.handlers[0].setLevel(logging.ERROR)

    logging.info("log path: %s" % log_path)


def get_pose_params(mat_path):
    """Get pose parameters from mat file

    Args:
        mat_path (str): path of mat file

    Returns:
        pose_params (numpy.ndarray): shape (L_video, 9), angle, translation, crop parameters
    """
    mat_dict = loadmat(mat_path)

    np_3dmm = mat_dict["coeff"]
    angles = np_3dmm[:, 224:227]
    translations = np_3dmm[:, 254:257]

    np_trans_params = mat_dict["transform_params"]
    crop = np_trans_params[:, -3:]

    pose_params = np.concatenate((angles, translations, crop), axis=1)

    return pose_params


def obtain_seq_index(index, num_frames, radius):
    seq = list(range(index - radius, index + radius + 1))
    seq = [min(max(item, 0), num_frames - 1) for item in seq]
    return seq
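A small sketch, not part of the commit, of the windowing helpers defined above; the values follow directly from the padding and clamping logic in the code:

# Usage sketch of get_audio_window and obtain_seq_index from core/utils.py
import numpy as np
from core.utils import get_audio_window, obtain_seq_index

phonemes = np.array([5, 12, 7, 7, 31])           # per-frame phoneme indices (31 = SIL)
wins = get_audio_window(phonemes, win_size=2)    # shape (5, 5); out-of-range frames padded with 31
print(wins[0])                                   # [31 31  5 12  7]

print(obtain_seq_index(0, num_frames=5, radius=2))  # [0, 0, 0, 1, 2], indices clamped to the valid range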
demo.mp4
ADDED
Binary file (547 kB)
demo.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7d29bc2d048a0be69193daea065734a3e76abbbe37e5ae4c8903d82f14ad92cb
size 227888
demo_download.mp4
ADDED
Binary file (457 kB)
demo_download.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7d29bc2d048a0be69193daea065734a3e76abbbe37e5ae4c8903d82f14ad92cb
size 227888
env.yaml
ADDED
File without changes
environment.yml
ADDED
@@ -0,0 +1,91 @@
name: styletalk
channels:
  - pytorch
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - blas=1.0=mkl
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.08.22=h06a4308_0
  - certifi=2020.6.20=pyhd3eb1b0_3
  - cudatoolkit=11.1.1=ha002fc5_10
  - ffmpeg=4.2.2=h20bf706_0
  - freetype=2.12.1=h4a9f257_0
  - gmp=6.2.1=h295c915_3
  - gnutls=3.6.15=he1e5248_0
  - intel-openmp=2021.4.0=h06a4308_3561
  - jpeg=9b=h024ee3a_2
  - lame=3.100=h7b6447c_0
  - libedit=3.1.20221030=h5eee18b_0
  - libffi=3.2.1=hf484d3e_1007
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libidn2=2.3.4=h5eee18b_0
  - libopus=1.3.1=h7b6447c_0
  - libpng=1.6.39=h5eee18b_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.0.9=he6b73bb_1
  - libunistring=0.9.10=h27cfd23_0
  - libuv=1.44.2=h5eee18b_0
  - libvpx=1.7.0=h439df22_0
  - mkl=2021.4.0=h06a4308_640
  - mkl-service=2.4.0=py37h7f8727e_0
  - mkl_fft=1.3.1=py37h3e078e5_1
  - mkl_random=1.2.2=py37h51133e4_0
  - ncurses=6.4=h6a678d5_0
  - nettle=3.7.3=hbbd107a_1
  - ninja=1.10.2=h06a4308_5
  - ninja-base=1.10.2=hd09550d_5
  - numpy=1.21.5=py37h6c91a56_3
  - numpy-base=1.21.5=py37ha15fc14_3
  - olefile=0.46=py37_0
  - openh264=2.1.1=h4ff587b_0
  - openssl=1.0.2u=h7b6447c_0
  - pip=22.3.1=py37h06a4308_0
  - python=3.7.0=h6e4f718_3
  - python_abi=3.7=2_cp37m
  - pytorch=1.8.0=py3.7_cuda11.1_cudnn8.0.5_0
  - readline=7.0=h7b6447c_5
  - setuptools=65.6.3=py37h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.33.0=h62c20be_0
  - tk=8.6.12=h1ccaba5_0
  - torchaudio=0.8.0=py37
  - torchvision=0.9.0=py37_cu111
  - typing_extensions=4.1.1=pyh06a4308_0
  - wheel=0.38.4=py37h06a4308_0
  - x264=1!157.20191217=h7b6447c_0
  - xz=5.4.2=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - pip:
    - av==10.0.0
    - beautifulsoup4==4.12.2
    - charset-normalizer==3.3.2
    - ffmpeg-python==0.2.0
    - filelock==3.12.2
    - future==0.18.3
    - gdown==4.7.1
    - idna==3.6
    - imageio==2.18.0
    - joblib==1.3.2
    - networkx==2.6.3
    - opencv-python==4.4.0.46
    - packaging==23.2
    - pillow==9.1.0
    - pysocks==1.7.1
    - pywavelets==1.3.0
    - pyyaml==6.0
    - requests==2.31.0
    - scikit-image==0.19.3
    - scikit-learn==1.0.2
    - scipy==1.7.3
    - soupsieve==2.4.1
    - threadpoolctl==3.1.0
    - tifffile==2021.11.2
    - tqdm==4.66.1
    - urllib3==2.0.7
    - yacs==0.1.8
prefix: /home/pixis/miniconda3/envs/styletalk
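For reference, this environment can typically be recreated with "conda env create -f environment.yml"; the machine-specific prefix line at the end can be removed or overridden with -n/--name or -p/--prefix.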
generators/__pycache__/base_function.cpython-37.pyc
ADDED
Binary file (13.9 kB)
generators/__pycache__/face_model.cpython-37.pyc
ADDED
Binary file (3.97 kB)
generators/__pycache__/flow_util.cpython-37.pyc
ADDED
Binary file (1.95 kB)
generators/base_function.py
ADDED
@@ -0,0 +1,368 @@
import sys
import math

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm


class LayerNorm2d(nn.Module):
    def __init__(self, n_out, affine=True):
        super(LayerNorm2d, self).__init__()
        self.n_out = n_out
        self.affine = affine

        if self.affine:
            self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
            self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))

    def forward(self, x):
        normalized_shape = x.size()[1:]
        if self.affine:
            return F.layer_norm(x, normalized_shape,
                                self.weight.expand(normalized_shape),
                                self.bias.expand(normalized_shape))
        else:
            return F.layer_norm(x, normalized_shape)


class ADAINHourglass(nn.Module):
    def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect):
        super(ADAINHourglass, self).__init__()
        self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect)
        self.decoder = ADAINDecoder(pose_nc, ngf, img_f, encoder_layers, decoder_layers, True, nonlinearity, use_spect)
        self.output_nc = self.decoder.output_nc

    def forward(self, x, z):
        return self.decoder(self.encoder(x, z), z)


class ADAINEncoder(nn.Module):
    def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINEncoder, self).__init__()
        self.layers = layers
        self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1, padding=3)
        for i in range(layers):
            in_channels = min(ngf * (2 ** i), img_f)
            out_channels = min(ngf * (2 ** (i + 1)), img_f)
            model = ADAINEncoderBlock(in_channels, out_channels, pose_nc, nonlinearity, use_spect)
            setattr(self, 'encoder' + str(i), model)
        self.output_nc = out_channels

    def forward(self, x, z):
        out = self.input_layer(x)
        out_list = [out]
        for i in range(self.layers):
            model = getattr(self, 'encoder' + str(i))
            out = model(out, z)
            out_list.append(out)
        return out_list


class ADAINDecoder(nn.Module):
    """docstring for ADAINDecoder"""
    def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers, skip_connect=True,
                 nonlinearity=nn.LeakyReLU(), use_spect=False):

        super(ADAINDecoder, self).__init__()
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.skip_connect = skip_connect
        use_transpose = True

        for i in range(encoder_layers - decoder_layers, encoder_layers)[::-1]:
            in_channels = min(ngf * (2 ** (i + 1)), img_f)
            in_channels = in_channels * 2 if i != (encoder_layers - 1) and self.skip_connect else in_channels
            out_channels = min(ngf * (2 ** i), img_f)
            model = ADAINDecoderBlock(in_channels, out_channels, out_channels, pose_nc, use_transpose, nonlinearity, use_spect)
            setattr(self, 'decoder' + str(i), model)

        self.output_nc = out_channels * 2 if self.skip_connect else out_channels

    def forward(self, x, z):
        out = x.pop() if self.skip_connect else x
        for i in range(self.encoder_layers - self.decoder_layers, self.encoder_layers)[::-1]:
            model = getattr(self, 'decoder' + str(i))
            out = model(out, z)
            out = torch.cat([out, x.pop()], 1) if self.skip_connect else out
        return out


class ADAINEncoderBlock(nn.Module):
    def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINEncoderBlock, self).__init__()
        kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1}
        kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}

        self.conv_0 = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_down), use_spect)
        self.conv_1 = spectral_norm(nn.Conv2d(output_nc, output_nc, **kwargs_fine), use_spect)

        self.norm_0 = ADAIN(input_nc, feature_nc)
        self.norm_1 = ADAIN(output_nc, feature_nc)
        self.actvn = nonlinearity

    def forward(self, x, z):
        x = self.conv_0(self.actvn(self.norm_0(x, z)))
        x = self.conv_1(self.actvn(self.norm_1(x, z)))
        return x


class ADAINDecoderBlock(nn.Module):
    def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINDecoderBlock, self).__init__()
        # Attributes
        self.actvn = nonlinearity
        hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc

        kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        if use_transpose:
            kwargs_up = {'kernel_size': 3, 'stride': 2, 'padding': 1, 'output_padding': 1}
        else:
            kwargs_up = {'kernel_size': 3, 'stride': 1, 'padding': 1}

        # create conv layers
        self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), use_spect)
        if use_transpose:
            self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect)
            self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect)
        else:
            self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect),
                                        nn.Upsample(scale_factor=2))
            self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect),
                                        nn.Upsample(scale_factor=2))
        # define normalization layers
        self.norm_0 = ADAIN(input_nc, feature_nc)
        self.norm_1 = ADAIN(hidden_nc, feature_nc)
        self.norm_s = ADAIN(input_nc, feature_nc)

    def forward(self, x, z):
        x_s = self.shortcut(x, z)
        dx = self.conv_0(self.actvn(self.norm_0(x, z)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, z)))
        out = x_s + dx
        return out

    def shortcut(self, x, z):
        x_s = self.conv_s(self.actvn(self.norm_s(x, z)))
        return x_s


def spectral_norm(module, use_spect=True):
    """use a spectral norm layer to stabilize the training process"""
    if use_spect:
        return SpectralNorm(module)
    else:
        return module


class ADAIN(nn.Module):
    def __init__(self, norm_nc, feature_nc):
        super().__init__()

        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)

        nhidden = 128
        use_bias = True

        self.mlp_shared = nn.Sequential(
            nn.Linear(feature_nc, nhidden, bias=use_bias),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=use_bias)
        self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=use_bias)

    def forward(self, x, feature):

        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on feature
        feature = feature.view(feature.size(0), -1)
        actv = self.mlp_shared(feature)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        gamma = gamma.view(*gamma.size()[:2], 1, 1)
        beta = beta.view(*beta.size()[:2], 1, 1)
        out = normalized * (1 + gamma) + beta
        return out


class FineEncoder(nn.Module):
    """docstring for Encoder"""
    def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineEncoder, self).__init__()
        self.layers = layers
        self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
        for i in range(layers):
            in_channels = min(ngf * (2 ** i), img_f)
            out_channels = min(ngf * (2 ** (i + 1)), img_f)
            model = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
            setattr(self, 'down' + str(i), model)
        self.output_nc = out_channels

    def forward(self, x):
        x = self.first(x)
        out = [x]
        for i in range(self.layers):
            model = getattr(self, 'down' + str(i))
            x = model(x)
            out.append(x)
        return out


class FineDecoder(nn.Module):
    """docstring for FineDecoder"""
    def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineDecoder, self).__init__()
        self.layers = layers
        for i in range(layers)[::-1]:
            in_channels = min(ngf * (2 ** (i + 1)), img_f)
            out_channels = min(ngf * (2 ** i), img_f)
            up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
            res = FineADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
            jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)

            setattr(self, 'up' + str(i), up)
            setattr(self, 'res' + str(i), res)
            setattr(self, 'jump' + str(i), jump)

        self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'tanh')

        self.output_nc = out_channels

    def forward(self, x, z):
        out = x.pop()
        for i in range(self.layers)[::-1]:
            res_model = getattr(self, 'res' + str(i))
            up_model = getattr(self, 'up' + str(i))
            jump_model = getattr(self, 'jump' + str(i))
            out = res_model(out, z)
            out = up_model(out)
            out = jump_model(x.pop()) + out
        out_image = self.final(out)
        return out_image


class FirstBlock2d(nn.Module):
    """
    Downsampling block for use in encoder.
    """
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FirstBlock2d, self).__init__()
        kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)

        if type(norm_layer) == type(None):
            self.model = nn.Sequential(conv, nonlinearity)
        else:
            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)

    def forward(self, x):
        out = self.model(x)
        return out


class DownBlock2d(nn.Module):
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(DownBlock2d, self).__init__()

        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
        pool = nn.AvgPool2d(kernel_size=(2, 2))

        if type(norm_layer) == type(None):
            self.model = nn.Sequential(conv, nonlinearity, pool)
        else:
            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity, pool)

    def forward(self, x):
        out = self.model(x)
        return out


class UpBlock2d(nn.Module):
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(UpBlock2d, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
        if type(norm_layer) == type(None):
            self.model = nn.Sequential(conv, nonlinearity)
        else:
            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)

    def forward(self, x):
        out = self.model(F.interpolate(x, scale_factor=2))
        return out


class FineADAINResBlocks(nn.Module):
    def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineADAINResBlocks, self).__init__()
        self.num_block = num_block
        for i in range(num_block):
            model = FineADAINResBlock2d(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
            setattr(self, 'res' + str(i), model)

    def forward(self, x, z):
        for i in range(self.num_block):
            model = getattr(self, 'res' + str(i))
            x = model(x, z)
        return x


class Jump(nn.Module):
    def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(Jump, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        conv = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)

        if type(norm_layer) == type(None):
            self.model = nn.Sequential(conv, nonlinearity)
        else:
            self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity)

    def forward(self, x):
        out = self.model(x)
        return out


class FineADAINResBlock2d(nn.Module):
    """
    Define a residual block for different types
    """
    def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineADAINResBlock2d, self).__init__()

        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}

        self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
        self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
        self.norm1 = ADAIN(input_nc, feature_nc)
        self.norm2 = ADAIN(input_nc, feature_nc)

        self.actvn = nonlinearity

    def forward(self, x, z):
        dx = self.actvn(self.norm1(self.conv1(x), z))
        dx = self.norm2(self.conv2(x), z)
        out = dx + x
        return out


class FinalBlock2d(nn.Module):
    """
    Define the output layer
    """
    def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'):
        super(FinalBlock2d, self).__init__()

        kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)

        if tanh_or_sigmoid == 'sigmoid':
            out_nonlinearity = nn.Sigmoid()
        else:
            out_nonlinearity = nn.Tanh()

        self.model = nn.Sequential(conv, out_nonlinearity)

    def forward(self, x):
        out = self.model(x)
        return out
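A brief sketch, not part of the commit, of the ADAIN block defined above: a per-sample descriptor vector is mapped to per-channel scale and bias that modulate instance-normalized activations:

# Usage sketch of generators.base_function.ADAIN
import torch
from generators.base_function import ADAIN

adain = ADAIN(norm_nc=64, feature_nc=256)
feature_maps = torch.randn(2, 64, 32, 32)   # (B, C, H, W) activations to modulate
descriptor = torch.randn(2, 256)            # per-sample conditioning vector (e.g. a 3DMM descriptor)
out = adain(feature_maps, descriptor)       # same shape as feature_maps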
generators/face_model.py
ADDED
@@ -0,0 +1,127 @@
import functools
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import generators.flow_util as flow_util
from generators.base_function import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder


class FaceGenerator(nn.Module):
    def __init__(
        self,
        mapping_net,
        warpping_net,
        editing_net,
        common
    ):
        super(FaceGenerator, self).__init__()
        self.mapping_net = MappingNet(**mapping_net)
        self.warpping_net = WarpingNet(**warpping_net, **common)
        self.editing_net = EditingNet(**editing_net, **common)

    def forward(
        self,
        input_image,
        driving_source,
        stage=None
    ):
        if stage == 'warp':
            descriptor = self.mapping_net(driving_source)
            output = self.warpping_net(input_image, descriptor)
        else:
            descriptor = self.mapping_net(driving_source)
            output = self.warpping_net(input_image, descriptor)
            output['fake_image'] = self.editing_net(input_image, output['warp_image'], descriptor)
        return output


class MappingNet(nn.Module):
    def __init__(self, coeff_nc, descriptor_nc, layer):
        super(MappingNet, self).__init__()

        self.layer = layer
        nonlinearity = nn.LeakyReLU(0.1)

        self.first = nn.Sequential(
            torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))

        for i in range(layer):
            net = nn.Sequential(nonlinearity,
                                torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))
            setattr(self, 'encoder' + str(i), net)

        self.pooling = nn.AdaptiveAvgPool1d(1)
        self.output_nc = descriptor_nc

    def forward(self, input_3dmm):
        out = self.first(input_3dmm)
        for i in range(self.layer):
            model = getattr(self, 'encoder' + str(i))
            out = model(out) + out[:, :, 3:-3]
        out = self.pooling(out)
        return out


class WarpingNet(nn.Module):
    def __init__(
        self,
        image_nc,
        descriptor_nc,
        base_nc,
        max_nc,
        encoder_layer,
        decoder_layer,
        use_spect
    ):
        super(WarpingNet, self).__init__()

        nonlinearity = nn.LeakyReLU(0.1)
        norm_layer = functools.partial(LayerNorm2d, affine=True)
        kwargs = {'nonlinearity': nonlinearity, 'use_spect': use_spect}

        self.descriptor_nc = descriptor_nc
        self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc,
                                        max_nc, encoder_layer, decoder_layer, **kwargs)

        self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc),
                                      nonlinearity,
                                      nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3))

        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, input_image, descriptor):
        final_output = {}
        output = self.hourglass(input_image, descriptor)
        final_output['flow_field'] = self.flow_out(output)

        deformation = flow_util.convert_flow_to_deformation(final_output['flow_field'])
        final_output['warp_image'] = flow_util.warp_image(input_image, deformation)
        return final_output


class EditingNet(nn.Module):
    def __init__(
        self,
        image_nc,
        descriptor_nc,
        layer,
        base_nc,
        max_nc,
        num_res_blocks,
        use_spect):
        super(EditingNet, self).__init__()

        nonlinearity = nn.LeakyReLU(0.1)
        norm_layer = functools.partial(LayerNorm2d, affine=True)
        kwargs = {'norm_layer': norm_layer, 'nonlinearity': nonlinearity, 'use_spect': use_spect}
        self.descriptor_nc = descriptor_nc

        # encoder part
        self.encoder = FineEncoder(image_nc * 2, base_nc, max_nc, layer, **kwargs)
        self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs)

    def forward(self, input_image, warp_image, descriptor):
        x = torch.cat([input_image, warp_image], 1)
        x = self.encoder(x)
        gen_image = self.decoder(x, descriptor)
        return gen_image
generators/flow_util.py
ADDED
@@ -0,0 +1,56 @@
import torch


def convert_flow_to_deformation(flow):
    r"""convert flow fields to deformations.

    Args:
        flow (tensor): Flow field obtained by the model
    Returns:
        deformation (tensor): The deformation used for warping
    """
    b, c, h, w = flow.shape
    flow_norm = 2 * torch.cat([flow[:, :1, ...] / (w - 1), flow[:, 1:, ...] / (h - 1)], 1)
    grid = make_coordinate_grid(flow)
    deformation = grid + flow_norm.permute(0, 2, 3, 1)
    return deformation


def make_coordinate_grid(flow):
    r"""obtain a coordinate grid with the same size as the flow field.

    Args:
        flow (tensor): Flow field obtained by the model
    Returns:
        grid (tensor): The grid with the same size as the input flow
    """
    b, c, h, w = flow.shape

    x = torch.arange(w).to(flow)
    y = torch.arange(h).to(flow)

    x = (2 * (x / (w - 1)) - 1)
    y = (2 * (y / (h - 1)) - 1)

    yy = y.view(-1, 1).repeat(1, w)
    xx = x.view(1, -1).repeat(h, 1)

    meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
    meshed = meshed.expand(b, -1, -1, -1)
    return meshed


def warp_image(source_image, deformation):
    r"""warp the input image according to the deformation

    Args:
        source_image (tensor): source images to be warped
        deformation (tensor): deformations used to warp the images; value in range (-1, 1)
    Returns:
        output (tensor): the warped images
    """
    _, h_old, w_old, _ = deformation.shape
    _, _, h, w = source_image.shape
    if h_old != h or w_old != w:
        deformation = deformation.permute(0, 3, 1, 2)
        deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear')
        deformation = deformation.permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(source_image, deformation)
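A short sketch, not part of the commit, of the flow utilities above: a zero flow field yields the identity sampling grid, so warping approximately reproduces the input (up to grid_sample's border and align_corners behavior):

# Usage sketch of generators.flow_util
import torch
from generators import flow_util

images = torch.randn(2, 3, 64, 64)                           # (B, C, H, W)
flow = torch.zeros(2, 2, 64, 64)                             # zero flow -> deformation equals the identity grid
deformation = flow_util.convert_flow_to_deformation(flow)    # (B, H, W, 2), values in [-1, 1]
warped = flow_util.warp_image(images, deformation)           # approximately equal to images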
inference_for_demo.py
ADDED
@@ -0,0 +1,187 @@
import argparse
import cv2
import json
import os

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image

from core.networks.styletalk import StyleTalk
from core.utils import get_audio_window, get_pose_params, get_video_style_clip, obtain_seq_index
from configs.default import get_cfg_defaults


@torch.no_grad()
def get_eval_model(cfg):
    model = StyleTalk(cfg).cuda()
    content_encoder = model.content_encoder
    style_encoder = model.style_encoder
    decoder = model.decoder
    checkpoint = torch.load(cfg.INFERENCE.CHECKPOINT)
    model_state_dict = checkpoint["model_state_dict"]
    content_encoder_dict = {k[16:]: v for k, v in model_state_dict.items() if k[:16] == "content_encoder."}
    content_encoder.load_state_dict(content_encoder_dict, strict=True)
    style_encoder_dict = {k[14:]: v for k, v in model_state_dict.items() if k[:14] == "style_encoder."}
    style_encoder.load_state_dict(style_encoder_dict, strict=True)
    decoder_dict = {k[8:]: v for k, v in model_state_dict.items() if k[:8] == "decoder."}
    decoder.load_state_dict(decoder_dict, strict=True)
    model.eval()
    return content_encoder, style_encoder, decoder


@torch.no_grad()
def render_video(
    net_G, src_img_path, exp_path, wav_path, output_path, silent=False, semantic_radius=13, fps=30, split_size=64
):
    target_exp_seq = np.load(exp_path)

    frame = cv2.imread(src_img_path)
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    src_img_raw = Image.fromarray(frame)
    image_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )
    src_img = image_transform(src_img_raw)

    target_win_exps = []
    for frame_idx in range(len(target_exp_seq)):
        win_indices = obtain_seq_index(frame_idx, target_exp_seq.shape[0], semantic_radius)
        win_exp = torch.tensor(target_exp_seq[win_indices]).permute(1, 0)
        # (73, 27)
        target_win_exps.append(win_exp)

    target_exp_concat = torch.stack(target_win_exps, dim=0)
    target_splited_exps = torch.split(target_exp_concat, split_size, dim=0)
    output_imgs = []
    for win_exp in target_splited_exps:
        win_exp = win_exp.cuda()
        cur_src_img = src_img.expand(win_exp.shape[0], -1, -1, -1).cuda()
        output_dict = net_G(cur_src_img, win_exp)
        output_imgs.append(output_dict["fake_image"].cpu().clamp_(-1, 1))

    output_imgs = torch.cat(output_imgs, 0)
    transformed_imgs = ((output_imgs + 1) / 2 * 255).to(torch.uint8).permute(0, 2, 3, 1)

    if silent:
        torchvision.io.write_video(output_path, transformed_imgs.cpu(), fps)
    else:
        silent_video_path = "silent.mp4"
        torchvision.io.write_video(silent_video_path, transformed_imgs.cpu(), fps)
        os.system(f"ffmpeg -loglevel quiet -y -i {silent_video_path} -i {wav_path} -shortest {output_path}")
        os.remove(silent_video_path)


@torch.no_grad()
def get_netG(checkpoint_path):
    from generators.face_model import FaceGenerator
    import yaml

    with open("configs/renderer_conf.yaml", "r") as f:
        renderer_config = yaml.load(f, Loader=yaml.FullLoader)

    renderer = FaceGenerator(**renderer_config).to(torch.cuda.current_device())

    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    renderer.load_state_dict(checkpoint["net_G_ema"], strict=False)

    renderer.eval()

    return renderer


@torch.no_grad()
def generate_expression_params(
    cfg, audio_path, style_clip_path, pose_path, output_path, content_encoder, style_encoder, decoder
):
    with open(audio_path, "r") as f:
        audio = json.load(f)

    audio_win = get_audio_window(audio, cfg.WIN_SIZE)
    audio_win = torch.tensor(audio_win).cuda()
    content = content_encoder(audio_win.unsqueeze(0))

    style_clip, pad_mask = get_video_style_clip(style_clip_path, style_max_len=256, start_idx=0)
    style_code = style_encoder(
        style_clip.unsqueeze(0).cuda(), pad_mask.unsqueeze(0).cuda() if pad_mask is not None else None
    )

    gen_exp_stack = decoder(content, style_code)
    gen_exp = gen_exp_stack[0].cpu().numpy()

    pose_ext = pose_path[-3:]
    pose = None
    if pose_ext == "npy":
        pose = np.load(pose_path)
    elif pose_ext == "mat":
        pose = get_pose_params(pose_path)
        # (L, 9)

    selected_pose = None
    if len(pose) >= len(gen_exp):
        selected_pose = pose[: len(gen_exp)]
    else:
        selected_pose = pose[-1].unsqueeze(0).repeat(len(gen_exp), 1)
        selected_pose[: len(pose)] = pose

    gen_exp_pose = np.concatenate((gen_exp, selected_pose), axis=1)
    np.save(output_path, gen_exp_pose)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="inference for demo")
    parser.add_argument(
        "--styletalk_checkpoint",
        type=str,
        default="checkpoints/styletalk_checkpoint.pth",
        help="the checkpoint to test with",
    )
    parser.add_argument(
        "--renderer_checkpoint",
        type=str,
        default="checkpoints/renderer_checkpoint.pt",
        help="renderer checkpoint",
    )
    parser.add_argument("--audio_path", type=str, default="", help="path for phoneme")
    parser.add_argument("--style_clip_path", type=str, default="", help="path for style_clip_mat")
    parser.add_argument("--pose_path", type=str, default="", help="path for pose")
    parser.add_argument("--src_img_path", type=str, default="test_images/KristiNoem1_0.jpg")
    parser.add_argument("--wav_path", type=str, default="demo/data/KristiNoem_front_neutral_level1_002.wav")
    parser.add_argument("--output_path", type=str, default="demo_output.npy", help="path for output")
    args = parser.parse_args()

    cfg = get_cfg_defaults()
    cfg.INFERENCE.CHECKPOINT = args.styletalk_checkpoint
    cfg.freeze()
    print(f"checkpoint: {cfg.INFERENCE.CHECKPOINT}")

    # load checkpoint
    with torch.no_grad():
        content_encoder, style_encoder, decoder = get_eval_model(cfg)
        exp_param_path = f"{args.output_path[:-4]}.npy"
        generate_expression_params(
            cfg,
            args.audio_path,
            args.style_clip_path,
            args.pose_path,
            exp_param_path,
            content_encoder,
            style_encoder,
            decoder,
        )

        image_renderer = get_netG(args.renderer_checkpoint)
        render_video(
            image_renderer,
            args.src_img_path,
            exp_param_path,
            args.wav_path,
            args.output_path,
            split_size=4,
        )
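For reference, the script first writes the generated expression-plus-pose parameters to <output>.npy and then renders and muxes the video with ffmpeg. An illustrative invocation using the flags defined above (bracketed paths are placeholders; the same sample .mat is reused for style and pose only for illustration):

python inference_for_demo.py \
    --audio_path <phoneme_seq.json> \
    --style_clip_path samples/source_video/3DMM/Obama_clip1.mat \
    --pose_path samples/source_video/3DMM/Obama_clip1.mat \
    --src_img_path <source_image.jpg> \
    --wav_path <speech.wav> \
    --output_path demo_output.mp4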
media/first_page.png
ADDED
Git LFS Details
phindex.json
ADDED
@@ -0,0 +1 @@
{"AA": 0, "AE": 1, "AH": 2, "AO": 3, "AW": 4, "AY": 5, "B": 6, "CH": 7, "D": 8, "DH": 9, "EH": 10, "ER": 11, "EY": 12, "F": 13, "G": 14, "HH": 15, "IH": 16, "IY": 17, "JH": 18, "K": 19, "L": 20, "M": 21, "N": 22, "NG": 23, "NSN": 24, "OW": 25, "OY": 26, "P": 27, "R": 28, "S": 29, "SH": 30, "SIL": 31, "T": 32, "TH": 33, "UH": 34, "UW": 35, "V": 36, "W": 37, "Y": 38, "Z": 39, "ZH": 40}
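Note on the mapping above: it assigns ARPAbet phoneme labels to integer ids; the "SIL" entry (31) matches the silence padding value used by get_audio_window in core/utils.py. A minimal sketch, not part of the commit:

# Usage sketch: load the phoneme-to-id table
import json

with open("phindex.json") as f:
    phoneme_to_id = json.load(f)
print(phoneme_to_id["SIL"])  # 31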
requirements.txt
ADDED
@@ -0,0 +1,11 @@
yacs==0.1.8
scipy==1.7.3
scikit-image==0.19.3
scikit-learn==1.0.2
PyYAML==6.0
Pillow==9.1.0
numpy==1.21.5
opencv-python==4.4.0.46
imageio==2.18.0
ffmpeg-python==0.2.0
av==10.0.0
samples/source_video/3DMM/KristiNoem.mat
ADDED
Binary file (566 kB)
samples/source_video/3DMM/Obama_clip1.mat
ADDED
Binary file (629 kB)
samples/source_video/3DMM/Obama_clip2.mat
ADDED
Binary file (943 kB)
samples/source_video/3DMM/Obama_clip3.mat
ADDED
Binary file (629 kB)