ameerazam08 committed
Commit 9a973f2 · 1 Parent(s): 8f7f7c3

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +2 -0
  2. README.md +70 -0
  3. checkpoints/renderer_checkpoint.pt +3 -0
  4. checkpoints/styletalk_checkpoint.pth +3 -0
  5. configs/__pycache__/default.cpython-37.pyc +0 -0
  6. configs/default.py +65 -0
  7. configs/renderer_conf.yaml +17 -0
  8. core/__pycache__/utils.cpython-37.pyc +0 -0
  9. core/networks/__init__.py +9 -0
  10. core/networks/__pycache__/__init__.cpython-37.pyc +0 -0
  11. core/networks/__pycache__/disentangle_decoder.cpython-37.pyc +0 -0
  12. core/networks/__pycache__/dynamic_conv.cpython-37.pyc +0 -0
  13. core/networks/__pycache__/dynamic_fc_decoder.cpython-37.pyc +0 -0
  14. core/networks/__pycache__/dynamic_linear.cpython-37.pyc +0 -0
  15. core/networks/__pycache__/generator.cpython-37.pyc +0 -0
  16. core/networks/__pycache__/mish.cpython-37.pyc +0 -0
  17. core/networks/__pycache__/self_attention_pooling.cpython-37.pyc +0 -0
  18. core/networks/__pycache__/styletalk.cpython-37.pyc +0 -0
  19. core/networks/__pycache__/transformer.cpython-37.pyc +0 -0
  20. core/networks/building_blocks.py +112 -0
  21. core/networks/disentangle_decoder.py +184 -0
  22. core/networks/dynamic_conv.py +149 -0
  23. core/networks/dynamic_fc_decoder.py +140 -0
  24. core/networks/dynamic_linear.py +42 -0
  25. core/networks/generator.py +213 -0
  26. core/networks/mish.py +51 -0
  27. core/networks/self_attention_pooling.py +43 -0
  28. core/networks/styletalk.py +24 -0
  29. core/networks/transformer.py +300 -0
  30. core/utils.py +228 -0
  31. demo.mp4 +0 -0
  32. demo.npy +3 -0
  33. demo_download.mp4 +0 -0
  34. demo_download.npy +3 -0
  35. env.yaml +0 -0
  36. environment.yml +91 -0
  37. generators/__pycache__/base_function.cpython-37.pyc +0 -0
  38. generators/__pycache__/face_model.cpython-37.pyc +0 -0
  39. generators/__pycache__/flow_util.cpython-37.pyc +0 -0
  40. generators/base_function.py +368 -0
  41. generators/face_model.py +127 -0
  42. generators/flow_util.py +56 -0
  43. inference_for_demo.py +187 -0
  44. media/first_page.png +3 -0
  45. phindex.json +1 -0
  46. requirements.txt +11 -0
  47. samples/source_video/3DMM/KristiNoem.mat +0 -0
  48. samples/source_video/3DMM/Obama_clip1.mat +0 -0
  49. samples/source_video/3DMM/Obama_clip2.mat +0 -0
  50. samples/source_video/3DMM/Obama_clip3.mat +0 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ media/first_page.png filter=lfs diff=lfs merge=lfs -text
+ samples/source_video/wav/intro.wav filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,70 @@
+
+ # StyleTalk
+
+ The official repository of the AAAI 2023 paper [StyleTalk: One-shot Talking Head Generation with Controllable Speaking Styles](https://arxiv.org/abs/2301.01081).
+
+ <p align='center'>
+ <b>
+ <a href="https://arxiv.org/abs/2301.01081">Paper</a>
+ |
+ <a href="https://drive.google.com/file/d/19WRhBHYVWRIH8_zo332l00fLXfUE96-k/view?usp=share_link">Supp. Materials</a>
+ |
+ <a href="https://youtu.be/mO2Tjcwr4u8">Video</a>
+ </b>
+ </p>
+
+ <p align='center'>
+ <img src='media/first_page.png' width='700'/>
+ </p>
+
+ The proposed **StyleTalk** can generate talking head videos with speaking styles specified by arbitrary style reference videos.
+
+ # News
+ * April 14th, 2023. The code is available.
+
+ # Get Started
+
+ ## Installation
+
+ Clone this repo, install conda, and run:
+
+ ```bash
+ conda create -n styletalk python=3.7.0
+ conda activate styletalk
+ pip install -r requirements.txt
+ conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge
+ conda update ffmpeg
+ ```
+
+ The code has been tested with CUDA 11.1 on an RTX 3090 GPU.
+
+ ## Data Preprocessing
+ Our method takes 3DMM parameters (\*.mat) and phoneme labels (\*_seq.json) as input. Follow [PIRenderer](https://github.com/RenYurui/PIRender) to extract the 3DMM parameters and [AVCT](https://github.com/FuxiVirtualHuman/AAAI22-one-shot-talking-face) to extract the phoneme labels. Some preprocessed data can be found in the `samples` folder.
+
+
+ ## Inference
+ Download the checkpoints for [StyleTalk](https://drive.google.com/file/d/1z54FymEiyPQ0mPGrVePt8GMtDe-E2RmN/view?usp=share_link) and the [Renderer](https://drive.google.com/file/d/1wFAtFQjybKI3hwRWvtcBDl4tpZzlDkja/view?usp=share_link) and put them into `./checkpoints`.
+
+ Run the demo:
+
+ ```bash
+ python inference_for_demo.py \
+ --audio_path samples/source_video/phoneme/reagan_clip1_seq.json \
+ --style_clip_path samples/style_clips/3DMM/happyenglish_clip1.mat \
+ --pose_path samples/source_video/3DMM/reagan_clip1.mat \
+ --src_img_path samples/source_video/image/andrew_clip_1.png \
+ --wav_path samples/source_video/wav/reagan_clip1.wav \
+ --output_path demo.mp4
+ ```
+
+ Change `audio_path`, `style_clip_path`, `pose_path`, `src_img_path`, `wav_path`, and `output_path` to generate more results.
+
+ # Acknowledgement
+ Some code is borrowed from the following projects:
+ * [AVCT](https://github.com/FuxiVirtualHuman/AAAI22-one-shot-talking-face)
+ * [PIRenderer](https://github.com/RenYurui/PIRender)
+ * [Deep3DFaceRecon_pytorch](https://github.com/sicxu/Deep3DFaceRecon_pytorch)
+ * [Speech Drives Templates](https://github.com/ShenhanQian/SpeechDrivesTemplates)
+ * [FOMM video preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing)
+
+ Thanks for their contributions!
checkpoints/renderer_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a67014839d42d592255c9fc3b3ceecbcd62c27ce0c0a89ed6628292447404242
+ size 335281551
checkpoints/styletalk_checkpoint.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c3bd52d0080d52440e25de44930b539d21c7f1431102c2811fafb30838e9812e
+ size 275485145
configs/__pycache__/default.cpython-37.pyc ADDED
Binary file (1.92 kB). View file
 
configs/default.py ADDED
@@ -0,0 +1,65 @@
+ from yacs.config import CfgNode as CN
+
+
+ _C = CN()
+ _C.TAG = "style_id_emotion"
+ _C.DECODER_TYPE = "DisentangleDecoder"
+ _C.CONTENT_ENCODER_TYPE = "ContentEncoder"
+ _C.STYLE_ENCODER_TYPE = "StyleEncoder"
+ _C.DISCRIMINATOR_TYPE = "Discriminator"
+
+
+ _C.WIN_SIZE = 5
+ _C.D_MODEL = 256
+
+ _C.DATASET = CN()
+ _C.DATASET.FACE3D_DIM = 64
+
+ _C.CONTENT_ENCODER = CN()
+ _C.CONTENT_ENCODER.d_model = _C.D_MODEL
+ _C.CONTENT_ENCODER.nhead = 8
+ _C.CONTENT_ENCODER.num_encoder_layers = 3
+ _C.CONTENT_ENCODER.dim_feedforward = 4 * _C.D_MODEL
+ _C.CONTENT_ENCODER.dropout = 0.1
+ _C.CONTENT_ENCODER.activation = "relu"
+ _C.CONTENT_ENCODER.normalize_before = False
+ _C.CONTENT_ENCODER.pos_embed_len = 2 * _C.WIN_SIZE + 1
+ _C.CONTENT_ENCODER.ph_embed_dim = 128
+
+ _C.STYLE_ENCODER = CN()
+ _C.STYLE_ENCODER.d_model = _C.D_MODEL
+ _C.STYLE_ENCODER.nhead = 8
+ _C.STYLE_ENCODER.num_encoder_layers = 3
+ _C.STYLE_ENCODER.dim_feedforward = 4 * _C.D_MODEL
+ _C.STYLE_ENCODER.dropout = 0.1
+ _C.STYLE_ENCODER.activation = "relu"
+ _C.STYLE_ENCODER.normalize_before = False
+ _C.STYLE_ENCODER.pos_embed_len = 256
+ _C.STYLE_ENCODER.aggregate_method = "self_attention_pooling"  # average | self_attention_pooling
+ # _C.STYLE_ENCODER.input_dim = _C.DATASET.FACE3D_DIM
+
+ _C.DECODER = CN()
+ _C.DECODER.d_model = _C.D_MODEL
+ _C.DECODER.nhead = 8
+ _C.DECODER.num_decoder_layers = 3
+ _C.DECODER.dim_feedforward = 4 * _C.D_MODEL
+ _C.DECODER.dropout = 0.1
+ _C.DECODER.activation = "relu"
+ _C.DECODER.normalize_before = False
+ _C.DECODER.return_intermediate_dec = False
+ _C.DECODER.pos_embed_len = 2 * _C.WIN_SIZE + 1
+ _C.DECODER.network_type = "DynamicFCDecoder"
+ _C.DECODER.dynamic_K = 8
+ _C.DECODER.dynamic_ratio = 4
+ # fmt: off
+ _C.DECODER.upper_face3d_indices = [6, 8, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
+ # fmt: on
+ _C.DECODER.lower_face3d_indices = [0, 1, 2, 3, 4, 5, 7, 9, 10, 11, 12, 13, 14]
+
+ _C.INFERENCE = CN()
+ _C.INFERENCE.CHECKPOINT = ""
+
+
+ def get_cfg_defaults():
+     """Get a yacs CfgNode object with default values for my_project."""
+     return _C.clone()
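
The defaults above are consumed through `get_cfg_defaults()` and then overridden from a config file or command-line options (see `setup_config()` in `core/utils.py`). A rough, untested sketch of that override flow, not part of the commit:

```python
# Illustrative only: override a few defaults with yacs, the way setup_config() does.
from configs.default import get_cfg_defaults

cfg = get_cfg_defaults()
# cfg.merge_from_file("my_experiment.yaml")  # hypothetical config file
cfg.merge_from_list(["DECODER.network_type", "TransformerDecoder", "TAG", "my_run"])
cfg.freeze()
print(cfg.DECODER.network_type)  # -> TransformerDecoder
```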
configs/renderer_conf.yaml ADDED
@@ -0,0 +1,17 @@
+ common:
+   descriptor_nc: 256
+   image_nc: 3
+   max_nc: 256
+   use_spect: false
+ editing_net:
+   base_nc: 64
+   layer: 3
+   num_res_blocks: 2
+ mapping_net:
+   coeff_nc: 73
+   descriptor_nc: 256
+   layer: 3
+ warpping_net:
+   base_nc: 32
+   decoder_layer: 3
+   encoder_layer: 5
core/__pycache__/utils.cpython-37.pyc ADDED
Binary file (6.26 kB). View file
 
core/networks/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from core.networks.generator import ContentEncoder, StyleEncoder, Decoder
+ from core.networks.disentangle_decoder import DisentangleDecoder
+
+ def get_network(name: str):
+     obj = globals().get(name)
+     if obj is None:
+         raise KeyError("Unknown Network: %s" % name)
+     else:
+         return obj
core/networks/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (533 Bytes). View file
 
core/networks/__pycache__/disentangle_decoder.cpython-37.pyc ADDED
Binary file (3.17 kB). View file
 
core/networks/__pycache__/dynamic_conv.cpython-37.pyc ADDED
Binary file (3.78 kB). View file
 
core/networks/__pycache__/dynamic_fc_decoder.cpython-37.pyc ADDED
Binary file (3.16 kB). View file
 
core/networks/__pycache__/dynamic_linear.cpython-37.pyc ADDED
Binary file (1.3 kB). View file
 
core/networks/__pycache__/generator.cpython-37.pyc ADDED
Binary file (4.85 kB). View file
 
core/networks/__pycache__/mish.cpython-37.pyc ADDED
Binary file (1.7 kB). View file
 
core/networks/__pycache__/self_attention_pooling.cpython-37.pyc ADDED
Binary file (1.61 kB). View file
 
core/networks/__pycache__/styletalk.cpython-37.pyc ADDED
Binary file (998 Bytes). View file
 
core/networks/__pycache__/transformer.cpython-37.pyc ADDED
Binary file (9.3 kB). View file
 
core/networks/building_blocks.py ADDED
@@ -0,0 +1,112 @@
+ from torch import nn
+
+
+ class ADAIN(nn.Module):
+     def __init__(self, content_nc, condition_nc, hidden_nc):
+         super().__init__()
+
+         self.param_free_norm = nn.InstanceNorm1d(content_nc, affine=False)
+
+         use_bias = True
+
+         self.mlp_shared = nn.Sequential(
+             nn.Linear(condition_nc, hidden_nc, bias=use_bias),
+             nn.ReLU(),
+         )
+         self.mlp_gamma = nn.Linear(hidden_nc, content_nc, bias=use_bias)
+         self.mlp_beta = nn.Linear(hidden_nc, content_nc, bias=use_bias)
+
+     def forward(self, content, condition):
+
+         # Part 1. generate parameter-free normalized activations
+         normalized = self.param_free_norm(content)
+
+         # Part 2. produce scaling and bias conditioned on feature
+         actv = self.mlp_shared(condition)
+         gamma = self.mlp_gamma(actv)
+         beta = self.mlp_beta(actv)
+
+         # apply scale and bias
+         gamma = gamma.unsqueeze(-1)
+         beta = beta.unsqueeze(-1)
+         out = normalized * (1 + gamma) + beta
+         return out
+
+
+ class ConvNormRelu(nn.Module):
+     def __init__(
+         self,
+         conv_type="1d",
+         in_channels=3,
+         out_channels=64,
+         downsample=False,
+         kernel_size=None,
+         stride=None,
+         padding=None,
+         norm="IN",
+         leaky=False,
+         adain_condition_nc=None,
+         adain_hidden_nc=None,
+     ):
+         super().__init__()
+         if kernel_size is None:
+             if downsample:
+                 kernel_size, stride, padding = 4, 2, 1
+             else:
+                 kernel_size, stride, padding = 3, 1, 1
+
+         if conv_type == "1d":
+             self.conv = nn.Conv1d(
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 padding,
+                 bias=False,
+             )
+         if norm == "IN":
+             self.norm = nn.InstanceNorm1d(out_channels, affine=True)
+         elif norm == "ADAIN":
+             self.norm = ADAIN(out_channels, adain_condition_nc, adain_hidden_nc)
+         elif norm == "NONE":
+             self.norm = nn.Identity()
+         else:
+             raise NotImplementedError
+         nn.init.kaiming_normal_(self.conv.weight)
+
+         self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True) if leaky else nn.ReLU(inplace=True)
+
+     def forward(self, x, condition=None):
+         """
+
+         Args:
+             x (_type_): (B, C, L)
+             condition (_type_, optional): (B, C)
+
+         Returns:
+             _type_: _description_
+         """
+         x = self.conv(x)
+         if isinstance(self.norm, ADAIN):
+             x = self.norm(x, condition)
+         else:
+             x = self.norm(x)
+         x = self.act(x)
+         return x
+
+
+ class MyConv1d(nn.Module):
+     def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.conv_block = nn.Sequential(
+             nn.Conv1d(cin, cout, kernel_size, stride, padding),
+             nn.BatchNorm1d(cout),
+         )
+         self.act = nn.ReLU()
+         self.residual = residual
+
+     def forward(self, x):
+         out = self.conv_block(x)
+         if self.residual:
+             out += x
+         return self.act(out)
core/networks/disentangle_decoder.py ADDED
@@ -0,0 +1,184 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from .transformer import (
5
+ PositionalEncoding,
6
+ TransformerDecoderLayer,
7
+ TransformerDecoder,
8
+ )
9
+ from core.networks.dynamic_fc_decoder import DynamicFCDecoderLayer, DynamicFCDecoder
10
+ from core.utils import _reset_parameters
11
+
12
+
13
+ def get_decoder_network(
14
+ network_type,
15
+ d_model,
16
+ nhead,
17
+ dim_feedforward,
18
+ dropout,
19
+ activation,
20
+ normalize_before,
21
+ num_decoder_layers,
22
+ return_intermediate_dec,
23
+ dynamic_K,
24
+ dynamic_ratio,
25
+ ):
26
+ decoder = None
27
+ if network_type == "TransformerDecoder":
28
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
29
+ norm = nn.LayerNorm(d_model)
30
+ decoder = TransformerDecoder(
31
+ decoder_layer,
32
+ num_decoder_layers,
33
+ norm,
34
+ return_intermediate_dec,
35
+ )
36
+ elif network_type == "DynamicFCDecoder":
37
+ d_style = d_model
38
+ decoder_layer = DynamicFCDecoderLayer(
39
+ d_model,
40
+ nhead,
41
+ d_style,
42
+ dynamic_K,
43
+ dynamic_ratio,
44
+ dim_feedforward,
45
+ dropout,
46
+ activation,
47
+ normalize_before,
48
+ )
49
+ norm = nn.LayerNorm(d_model)
50
+ decoder = DynamicFCDecoder(decoder_layer, num_decoder_layers, norm, return_intermediate_dec)
51
+ else:
52
+ raise ValueError(f"Invalid network_type {network_type}")
53
+
54
+ return decoder
55
+
56
+
57
+ class DisentangleDecoder(nn.Module):
58
+ def __init__(
59
+ self,
60
+ d_model=512,
61
+ nhead=8,
62
+ num_decoder_layers=3,
63
+ dim_feedforward=2048,
64
+ dropout=0.1,
65
+ activation="relu",
66
+ normalize_before=False,
67
+ return_intermediate_dec=False,
68
+ pos_embed_len=80,
69
+ upper_face3d_indices=tuple(list(range(19)) + list(range(46, 51))),
70
+ lower_face3d_indices=tuple(range(19, 46)),
71
+ network_type="None",
72
+ dynamic_K=None,
73
+ dynamic_ratio=None,
74
+ **_,
75
+ ) -> None:
76
+ super().__init__()
77
+
78
+ self.upper_face3d_indices = upper_face3d_indices
79
+ self.lower_face3d_indices = lower_face3d_indices
80
+
81
+ # upper_decoder_layer = TransformerDecoderLayer(
82
+ # d_model, nhead, dim_feedforward, dropout, activation, normalize_before
83
+ # )
84
+ # upper_decoder_norm = nn.LayerNorm(d_model)
85
+ # self.upper_decoder = TransformerDecoder(
86
+ # upper_decoder_layer,
87
+ # num_decoder_layers,
88
+ # upper_decoder_norm,
89
+ # return_intermediate=return_intermediate_dec,
90
+ # )
91
+ self.upper_decoder = get_decoder_network(
92
+ network_type,
93
+ d_model,
94
+ nhead,
95
+ dim_feedforward,
96
+ dropout,
97
+ activation,
98
+ normalize_before,
99
+ num_decoder_layers,
100
+ return_intermediate_dec,
101
+ dynamic_K,
102
+ dynamic_ratio,
103
+ )
104
+ _reset_parameters(self.upper_decoder)
105
+
106
+ # lower_decoder_layer = TransformerDecoderLayer(
107
+ # d_model, nhead, dim_feedforward, dropout, activation, normalize_before
108
+ # )
109
+ # lower_decoder_norm = nn.LayerNorm(d_model)
110
+ # self.lower_decoder = TransformerDecoder(
111
+ # lower_decoder_layer,
112
+ # num_decoder_layers,
113
+ # lower_decoder_norm,
114
+ # return_intermediate=return_intermediate_dec,
115
+ # )
116
+ self.lower_decoder = get_decoder_network(
117
+ network_type,
118
+ d_model,
119
+ nhead,
120
+ dim_feedforward,
121
+ dropout,
122
+ activation,
123
+ normalize_before,
124
+ num_decoder_layers,
125
+ return_intermediate_dec,
126
+ dynamic_K,
127
+ dynamic_ratio,
128
+ )
129
+ _reset_parameters(self.lower_decoder)
130
+
131
+ self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
132
+
133
+ tail_hidden_dim = d_model // 2
134
+ self.upper_tail_fc = nn.Sequential(
135
+ nn.Linear(d_model, tail_hidden_dim),
136
+ nn.ReLU(),
137
+ nn.Linear(tail_hidden_dim, tail_hidden_dim),
138
+ nn.ReLU(),
139
+ nn.Linear(tail_hidden_dim, len(upper_face3d_indices)),
140
+ )
141
+ self.lower_tail_fc = nn.Sequential(
142
+ nn.Linear(d_model, tail_hidden_dim),
143
+ nn.ReLU(),
144
+ nn.Linear(tail_hidden_dim, tail_hidden_dim),
145
+ nn.ReLU(),
146
+ nn.Linear(tail_hidden_dim, len(lower_face3d_indices)),
147
+ )
148
+
149
+ def forward(self, content, style_code):
150
+ """
151
+
152
+ Args:
153
+ content (_type_): (B, num_frames, window, C_dmodel)
154
+ style_code (_type_): (B, C_dmodel)
155
+
156
+ Returns:
157
+ face3d: (B, L_clip, C_3dmm)
158
+ """
159
+ B, N, W, C = content.shape
160
+ style = style_code.reshape(B, 1, 1, C).expand(B, N, W, C)
161
+ style = style.permute(2, 0, 1, 3).reshape(W, B * N, C)
162
+ # (W, B*N, C)
163
+
164
+ content = content.permute(2, 0, 1, 3).reshape(W, B * N, C)
165
+ # (W, B*N, C)
166
+ tgt = torch.zeros_like(style)
167
+ pos_embed = self.pos_embed(W)
168
+ pos_embed = pos_embed.permute(1, 0, 2)
169
+
170
+ upper_face3d_feat = self.upper_decoder(tgt, content, pos=pos_embed, query_pos=style)[0]
171
+ # (W, B*N, C)
172
+ upper_face3d_feat = upper_face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[:, :, W // 2, :]
173
+ # (B, N, C)
174
+ upper_face3d = self.upper_tail_fc(upper_face3d_feat)
175
+ # (B, N, C_exp)
176
+
177
+ lower_face3d_feat = self.lower_decoder(tgt, content, pos=pos_embed, query_pos=style)[0]
178
+ lower_face3d_feat = lower_face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[:, :, W // 2, :]
179
+ lower_face3d = self.lower_tail_fc(lower_face3d_feat)
180
+ C_exp = len(self.upper_face3d_indices) + len(self.lower_face3d_indices)
181
+ face3d = torch.zeros(B, N, C_exp).to(upper_face3d)
182
+ face3d[:, :, self.upper_face3d_indices] = upper_face3d
183
+ face3d[:, :, self.lower_face3d_indices] = lower_face3d
184
+ return face3d
core/networks/dynamic_conv.py ADDED
@@ -0,0 +1,149 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ class Attention(nn.Module):
9
+ def __init__(self, cond_planes, ratio, K, temperature=30, init_weight=True):
10
+ super().__init__()
11
+ # self.avgpool = nn.AdaptiveAvgPool2d(1)
12
+ self.temprature = temperature
13
+ assert cond_planes > ratio
14
+ hidden_planes = cond_planes // ratio
15
+ self.net = nn.Sequential(
16
+ nn.Conv2d(cond_planes, hidden_planes, kernel_size=1, bias=False),
17
+ nn.ReLU(),
18
+ nn.Conv2d(hidden_planes, K, kernel_size=1, bias=False),
19
+ )
20
+
21
+ if init_weight:
22
+ self._initialize_weights()
23
+
24
+ def update_temprature(self):
25
+ if self.temprature > 1:
26
+ self.temprature -= 1
27
+
28
+ def _initialize_weights(self):
29
+ for m in self.modules():
30
+ if isinstance(m, nn.Conv2d):
31
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
32
+ if m.bias is not None:
33
+ nn.init.constant_(m.bias, 0)
34
+ if isinstance(m, nn.BatchNorm2d):
35
+ nn.init.constant_(m.weight, 1)
36
+ nn.init.constant_(m.bias, 0)
37
+
38
+ def forward(self, cond):
39
+ """
40
+
41
+ Args:
42
+ cond (_type_): (B, C_style)
43
+
44
+ Returns:
45
+ _type_: (B, K)
46
+ """
47
+
48
+ # att = self.avgpool(cond) # bs,dim,1,1
49
+ att = cond.view(cond.shape[0], cond.shape[1], 1, 1)
50
+ att = self.net(att).view(cond.shape[0], -1) # bs,K
51
+ return F.softmax(att / self.temprature, -1)
52
+
53
+
54
+ class DynamicConv(nn.Module):
55
+ def __init__(
56
+ self,
57
+ in_planes,
58
+ out_planes,
59
+ cond_planes,
60
+ kernel_size,
61
+ stride,
62
+ padding=0,
63
+ dilation=1,
64
+ groups=1,
65
+ bias=True,
66
+ K=4,
67
+ temperature=30,
68
+ ratio=4,
69
+ init_weight=True,
70
+ ):
71
+ super().__init__()
72
+ self.in_planes = in_planes
73
+ self.out_planes = out_planes
74
+ self.cond_planes = cond_planes
75
+ self.kernel_size = kernel_size
76
+ self.stride = stride
77
+ self.padding = padding
78
+ self.dilation = dilation
79
+ self.groups = groups
80
+ self.bias = bias
81
+ self.K = K
82
+ self.init_weight = init_weight
83
+ self.attention = Attention(
84
+ cond_planes=cond_planes, ratio=ratio, K=K, temperature=temperature, init_weight=init_weight
85
+ )
86
+
87
+ self.weight = nn.Parameter(
88
+ torch.randn(K, out_planes, in_planes // groups, kernel_size, kernel_size), requires_grad=True
89
+ )
90
+ if bias:
91
+ self.bias = nn.Parameter(torch.randn(K, out_planes), requires_grad=True)
92
+ else:
93
+ self.bias = None
94
+
95
+ if self.init_weight:
96
+ self._initialize_weights()
97
+
98
+ def _initialize_weights(self):
99
+ for i in range(self.K):
100
+ nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5))
101
+ if self.bias is not None:
102
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i])
103
+ if fan_in != 0:
104
+ bound = 1 / math.sqrt(fan_in)
105
+ nn.init.uniform_(self.bias, -bound, bound)
106
+
107
+ def forward(self, x, cond):
108
+ """
109
+
110
+ Args:
111
+ x (_type_): (B, C_in, L, 1)
112
+ cond (_type_): (B, C_style)
113
+
114
+ Returns:
115
+ _type_: (B, C_out, L, 1)
116
+ """
117
+ bs, in_planels, h, w = x.shape
118
+ softmax_att = self.attention(cond) # bs,K
119
+ x = x.view(1, -1, h, w)
120
+ weight = self.weight.view(self.K, -1) # K,-1
121
+ aggregate_weight = torch.mm(softmax_att, weight).view(
122
+ bs * self.out_planes, self.in_planes // self.groups, self.kernel_size, self.kernel_size
123
+ ) # bs*out_p,in_p,k,k
124
+
125
+ if self.bias is not None:
126
+ bias = self.bias.view(self.K, -1) # K,out_p
127
+ aggregate_bias = torch.mm(softmax_att, bias).view(-1) # bs*out_p
128
+ output = F.conv2d(
129
+ x, # 1, bs*in_p, L, 1
130
+ weight=aggregate_weight,
131
+ bias=aggregate_bias,
132
+ stride=self.stride,
133
+ padding=self.padding,
134
+ groups=self.groups * bs,
135
+ dilation=self.dilation,
136
+ )
137
+ else:
138
+ output = F.conv2d(
139
+ x,
140
+ weight=aggregate_weight,
141
+ bias=None,
142
+ stride=self.stride,
143
+ padding=self.padding,
144
+ groups=self.groups * bs,
145
+ dilation=self.dilation,
146
+ )
147
+
148
+ output = output.view(bs, self.out_planes, h, w)
149
+ return output
core/networks/dynamic_fc_decoder.py ADDED
@@ -0,0 +1,140 @@
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ from core.networks.transformer import _get_activation_fn, _get_clones
5
+ from core.networks.dynamic_linear import DynamicLinear
6
+
7
+
8
+ class DynamicFCDecoderLayer(nn.Module):
9
+ def __init__(
10
+ self,
11
+ d_model,
12
+ nhead,
13
+ d_style,
14
+ dynamic_K,
15
+ dynamic_ratio,
16
+ dim_feedforward=2048,
17
+ dropout=0.1,
18
+ activation="relu",
19
+ normalize_before=False,
20
+ ):
21
+ super().__init__()
22
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
23
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
24
+ # Implementation of Feedforward model
25
+ # self.linear1 = nn.Linear(d_model, dim_feedforward)
26
+ self.linear1 = DynamicLinear(d_model, dim_feedforward, d_style, K=dynamic_K, ratio=dynamic_ratio)
27
+ self.dropout = nn.Dropout(dropout)
28
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
29
+ # self.linear2 = DynamicLinear(dim_feedforward, d_model, d_style, K=dynamic_K, ratio=dynamic_ratio)
30
+
31
+ self.norm1 = nn.LayerNorm(d_model)
32
+ self.norm2 = nn.LayerNorm(d_model)
33
+ self.norm3 = nn.LayerNorm(d_model)
34
+ self.dropout1 = nn.Dropout(dropout)
35
+ self.dropout2 = nn.Dropout(dropout)
36
+ self.dropout3 = nn.Dropout(dropout)
37
+
38
+ self.activation = _get_activation_fn(activation)
39
+ self.normalize_before = normalize_before
40
+
41
+ def with_pos_embed(self, tensor, pos):
42
+ return tensor if pos is None else tensor + pos
43
+
44
+ def forward_post(
45
+ self,
46
+ tgt,
47
+ memory,
48
+ style,
49
+ tgt_mask=None,
50
+ memory_mask=None,
51
+ tgt_key_padding_mask=None,
52
+ memory_key_padding_mask=None,
53
+ pos=None,
54
+ query_pos=None,
55
+ ):
56
+ # q = k = self.with_pos_embed(tgt, query_pos)
57
+ tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
58
+ tgt = tgt + self.dropout1(tgt2)
59
+ tgt = self.norm1(tgt)
60
+ tgt2 = self.multihead_attn(
61
+ query=tgt, key=memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask
62
+ )[0]
63
+ tgt = tgt + self.dropout2(tgt2)
64
+ tgt = self.norm2(tgt)
65
+ # tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt, style))), style)
66
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt, style))))
67
+ tgt = tgt + self.dropout3(tgt2)
68
+ tgt = self.norm3(tgt)
69
+ return tgt
70
+
71
+ def forward(
72
+ self,
73
+ tgt,
74
+ memory,
75
+ style,
76
+ tgt_mask=None,
77
+ memory_mask=None,
78
+ tgt_key_padding_mask=None,
79
+ memory_key_padding_mask=None,
80
+ pos=None,
81
+ query_pos=None,
82
+ ):
83
+ if self.normalize_before:
84
+ raise NotImplementedError
85
+
86
+ return self.forward_post(
87
+ tgt, memory, style, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
88
+ )
89
+
90
+
91
+ class DynamicFCDecoder(nn.Module):
92
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
93
+ super().__init__()
94
+ self.layers = _get_clones(decoder_layer, num_layers)
95
+ self.num_layers = num_layers
96
+ self.norm = norm
97
+ self.return_intermediate = return_intermediate
98
+
99
+ def forward(
100
+ self,
101
+ tgt,
102
+ memory,
103
+ tgt_mask=None,
104
+ memory_mask=None,
105
+ tgt_key_padding_mask=None,
106
+ memory_key_padding_mask=None,
107
+ pos=None,
108
+ query_pos=None,
109
+ ):
110
+ style = query_pos[0]
111
+ # (B*N, C)
112
+ output = tgt + pos + query_pos
113
+
114
+ intermediate = []
115
+
116
+ for layer in self.layers:
117
+ output = layer(
118
+ output,
119
+ memory,
120
+ style,
121
+ tgt_mask=tgt_mask,
122
+ memory_mask=memory_mask,
123
+ tgt_key_padding_mask=tgt_key_padding_mask,
124
+ memory_key_padding_mask=memory_key_padding_mask,
125
+ pos=pos,
126
+ query_pos=query_pos,
127
+ )
128
+ if self.return_intermediate:
129
+ intermediate.append(self.norm(output))
130
+
131
+ if self.norm is not None:
132
+ output = self.norm(output)
133
+ if self.return_intermediate:
134
+ intermediate.pop()
135
+ intermediate.append(output)
136
+
137
+ if self.return_intermediate:
138
+ return torch.stack(intermediate)
139
+
140
+ return output.unsqueeze(0)
core/networks/dynamic_linear.py ADDED
@@ -0,0 +1,42 @@
+ import math
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from core.networks.dynamic_conv import DynamicConv
+
+
+ class DynamicLinear(nn.Module):
+     def __init__(self, in_planes, out_planes, cond_planes, bias=True, K=4, temperature=30, ratio=4, init_weight=True):
+         super().__init__()
+
+         self.dynamic_conv = DynamicConv(
+             in_planes,
+             out_planes,
+             cond_planes,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             bias=bias,
+             K=K,
+             ratio=ratio,
+             temperature=temperature,
+             init_weight=init_weight,
+         )
+
+     def forward(self, x, cond):
+         """
+
+         Args:
+             x (_type_): (L, B, C_in)
+             cond (_type_): (B, C_style)
+
+         Returns:
+             _type_: (L, B, C_out)
+         """
+         x = x.permute(1, 2, 0).unsqueeze(-1)
+         out = self.dynamic_conv(x, cond)
+         # (B, C_out, L, 1)
+         out = out.squeeze().permute(2, 0, 1)
+         return out
core/networks/generator.py ADDED
@@ -0,0 +1,213 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from .transformer import (
5
+ TransformerEncoder,
6
+ TransformerEncoderLayer,
7
+ PositionalEncoding,
8
+ TransformerDecoderLayer,
9
+ TransformerDecoder,
10
+ )
11
+ from core.utils import _reset_parameters
12
+ from core.networks.self_attention_pooling import SelfAttentionPooling
13
+
14
+
15
+ class ContentEncoder(nn.Module):
16
+ def __init__(
17
+ self,
18
+ d_model=512,
19
+ nhead=8,
20
+ num_encoder_layers=6,
21
+ dim_feedforward=2048,
22
+ dropout=0.1,
23
+ activation="relu",
24
+ normalize_before=False,
25
+ pos_embed_len=80,
26
+ ph_embed_dim=128,
27
+ ):
28
+ super().__init__()
29
+
30
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
31
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
32
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
33
+
34
+ _reset_parameters(self.encoder)
35
+
36
+ self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
37
+
38
+ self.ph_embedding = nn.Embedding(41, ph_embed_dim)
39
+ self.increase_embed_dim = nn.Linear(ph_embed_dim, d_model)
40
+
41
+ def forward(self, x):
42
+ """
43
+
44
+ Args:
45
+ x (_type_): (B, num_frames, window)
46
+
47
+ Returns:
48
+ content: (B, num_frames, window, C_dmodel)
49
+ """
50
+ x_embedding = self.ph_embedding(x)
51
+ x_embedding = self.increase_embed_dim(x_embedding)
52
+ # (B, N, W, C)
53
+ B, N, W, C = x_embedding.shape
54
+ x_embedding = x_embedding.reshape(B * N, W, C)
55
+ x_embedding = x_embedding.permute(1, 0, 2)
56
+ # (W, B*N, C)
57
+
58
+ pos = self.pos_embed(W)
59
+ pos = pos.permute(1, 0, 2)
60
+ # (W, 1, C)
61
+
62
+ content = self.encoder(x_embedding, pos=pos)
63
+ # (W, B*N, C)
64
+ content = content.permute(1, 0, 2).reshape(B, N, W, C)
65
+ # (B, N, W, C)
66
+
67
+ return content
68
+
69
+
70
+ class StyleEncoder(nn.Module):
71
+ def __init__(
72
+ self,
73
+ d_model=512,
74
+ nhead=8,
75
+ num_encoder_layers=6,
76
+ dim_feedforward=2048,
77
+ dropout=0.1,
78
+ activation="relu",
79
+ normalize_before=False,
80
+ pos_embed_len=80,
81
+ input_dim=128,
82
+ aggregate_method="average",
83
+ ):
84
+ super().__init__()
85
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
86
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
87
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
88
+ _reset_parameters(self.encoder)
89
+
90
+ self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
91
+
92
+ self.increase_embed_dim = nn.Linear(input_dim, d_model)
93
+
94
+ self.aggregate_method = None
95
+ if aggregate_method == "self_attention_pooling":
96
+ self.aggregate_method = SelfAttentionPooling(d_model)
97
+ elif aggregate_method == "average":
98
+ pass
99
+ else:
100
+ raise ValueError(f"Invalid aggregate method {aggregate_method}")
101
+
102
+ def forward(self, x, pad_mask=None):
103
+ """
104
+
105
+ Args:
106
+ x (_type_): (B, num_frames(L), C_exp)
107
+ pad_mask: (B, num_frames)
108
+
109
+ Returns:
110
+ style_code: (B, C_model)
111
+ """
112
+ x = self.increase_embed_dim(x)
113
+ # (B, L, C)
114
+ x = x.permute(1, 0, 2)
115
+ # (L, B, C)
116
+
117
+ pos = self.pos_embed(x.shape[0])
118
+ pos = pos.permute(1, 0, 2)
119
+ # (L, 1, C)
120
+
121
+ style = self.encoder(x, pos=pos, src_key_padding_mask=pad_mask)
122
+ # (L, B, C)
123
+
124
+ if self.aggregate_method is not None:
125
+ permute_style = style.permute(1, 0, 2)
126
+ # (B, L, C)
127
+ style_code = self.aggregate_method(permute_style, pad_mask)
128
+ return style_code
129
+
130
+ if pad_mask is None:
131
+ style = style.permute(1, 2, 0)
132
+ # (B, C, L)
133
+ style_code = style.mean(2)
134
+ # (B, C)
135
+ else:
136
+ permute_style = style.permute(1, 0, 2)
137
+ # (B, L, C)
138
+ permute_style[pad_mask] = 0
139
+ sum_style_code = permute_style.sum(dim=1)
140
+ # (B, C)
141
+ valid_token_num = (~pad_mask).sum(dim=1).unsqueeze(-1)
142
+ # (B, 1)
143
+ style_code = sum_style_code / valid_token_num
144
+ # (B, C)
145
+
146
+ return style_code
147
+
148
+
149
+ class Decoder(nn.Module):
150
+ def __init__(
151
+ self,
152
+ d_model=512,
153
+ nhead=8,
154
+ num_decoder_layers=3,
155
+ dim_feedforward=2048,
156
+ dropout=0.1,
157
+ activation="relu",
158
+ normalize_before=False,
159
+ return_intermediate_dec=False,
160
+ pos_embed_len=80,
161
+ output_dim=64,
162
+ **_,
163
+ ) -> None:
164
+ super().__init__()
165
+
166
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
167
+ decoder_norm = nn.LayerNorm(d_model)
168
+ self.decoder = TransformerDecoder(
169
+ decoder_layer,
170
+ num_decoder_layers,
171
+ decoder_norm,
172
+ return_intermediate=return_intermediate_dec,
173
+ )
174
+ _reset_parameters(self.decoder)
175
+
176
+ self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
177
+
178
+ tail_hidden_dim = d_model // 2
179
+ self.tail_fc = nn.Sequential(
180
+ nn.Linear(d_model, tail_hidden_dim),
181
+ nn.ReLU(),
182
+ nn.Linear(tail_hidden_dim, tail_hidden_dim),
183
+ nn.ReLU(),
184
+ nn.Linear(tail_hidden_dim, output_dim),
185
+ )
186
+
187
+ def forward(self, content, style_code):
188
+ """
189
+
190
+ Args:
191
+ content (_type_): (B, num_frames, window, C_dmodel)
192
+ style_code (_type_): (B, C_dmodel)
193
+
194
+ Returns:
195
+ face3d: (B, num_frames, C_3dmm)
196
+ """
197
+ B, N, W, C = content.shape
198
+ style = style_code.reshape(B, 1, 1, C).expand(B, N, W, C)
199
+ style = style.permute(2, 0, 1, 3).reshape(W, B * N, C)
200
+ # (W, B*N, C)
201
+
202
+ content = content.permute(2, 0, 1, 3).reshape(W, B * N, C)
203
+ # (W, B*N, C)
204
+ tgt = torch.zeros_like(style)
205
+ pos_embed = self.pos_embed(W)
206
+ pos_embed = pos_embed.permute(1, 0, 2)
207
+ face3d_feat = self.decoder(tgt, content, pos=pos_embed, query_pos=style)[0]
208
+ # (W, B*N, C)
209
+ face3d_feat = face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[:, :, W // 2, :]
210
+ # (B, N, C)
211
+ face3d = self.tail_fc(face3d_feat)
212
+ # (B, N, C_exp)
213
+ return face3d
core/networks/mish.py ADDED
@@ -0,0 +1,51 @@
+ """
+ Applies the mish function element-wise:
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+ """
+
+ # import pytorch
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ @torch.jit.script
+ def mish(input):
+     """
+     Applies the mish function element-wise:
+     mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+     See additional documentation for mish class.
+     """
+     return input * torch.tanh(F.softplus(input))
+
+ class Mish(nn.Module):
+     """
+     Applies the mish function element-wise:
+     mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+
+     Shape:
+         - Input: (N, *) where * means, any number of additional
+           dimensions
+         - Output: (N, *), same shape as the input
+
+     Examples:
+         >>> m = Mish()
+         >>> input = torch.randn(2)
+         >>> output = m(input)
+
+     Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
+     """
+
+     def __init__(self):
+         """
+         Init method.
+         """
+         super().__init__()
+
+     def forward(self, input):
+         """
+         Forward pass of the function.
+         """
+         if torch.__version__ >= "1.9":
+             return F.mish(input)
+         else:
+             return mish(input)
core/networks/self_attention_pooling.py ADDED
@@ -0,0 +1,43 @@
+ import torch
+ import torch.nn as nn
+ from core.networks.mish import Mish
+
+
+ class SelfAttentionPooling(nn.Module):
+     """
+     Implementation of SelfAttentionPooling
+     Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
+     https://arxiv.org/pdf/2008.01077v1.pdf
+     """
+
+     def __init__(self, input_dim):
+         super(SelfAttentionPooling, self).__init__()
+         self.W = nn.Sequential(nn.Linear(input_dim, input_dim), Mish(), nn.Linear(input_dim, 1))
+         self.softmax = nn.functional.softmax
+
+     def forward(self, batch_rep, att_mask=None):
+         """
+         N: batch size, T: sequence length, H: Hidden dimension
+         input:
+             batch_rep : size (N, T, H)
+         attention_weight:
+             att_w : size (N, T, 1)
+         att_mask:
+             att_mask: size (N, T): if True, mask this item.
+         return:
+             utter_rep: size (N, H)
+         """
+
+         att_logits = self.W(batch_rep).squeeze(-1)
+         # (N, T)
+         if att_mask is not None:
+             att_mask_logits = att_mask.to(dtype=batch_rep.dtype) * -100000.0
+             # (N, T)
+             att_logits = att_mask_logits + att_logits
+
+         att_w = self.softmax(att_logits, dim=-1).unsqueeze(-1)
+         utter_rep = torch.sum(batch_rep * att_w, dim=1)
+
+         return utter_rep
core/networks/styletalk.py ADDED
@@ -0,0 +1,24 @@
+ import torch.nn as nn
+
+ from core.networks import get_network
+
+
+ class StyleTalk(nn.Module):
+     def __init__(self, cfg) -> None:
+         super().__init__()
+         self.cfg = cfg
+
+         content_encoder_class = get_network(cfg.CONTENT_ENCODER_TYPE)
+         self.content_encoder = content_encoder_class(**cfg.CONTENT_ENCODER)
+
+         style_encoder_class = get_network(cfg.STYLE_ENCODER_TYPE)
+         cfg.defrost()
+         cfg.STYLE_ENCODER.input_dim = cfg.DATASET.FACE3D_DIM
+         cfg.freeze()
+         self.style_encoder = style_encoder_class(**cfg.STYLE_ENCODER)
+
+         decoder_class = get_network(cfg.DECODER_TYPE)
+         cfg.defrost()
+         cfg.DECODER.output_dim = cfg.DATASET.FACE3D_DIM
+         cfg.freeze()
+         self.decoder = decoder_class(**cfg.DECODER)
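
For orientation, a rough sketch (not part of the commit, untested) of how these pieces appear to fit together with the defaults from `configs/default.py`; the toy tensors below are hypothetical and only chosen to match the shapes documented in the module docstrings:

```python
import torch

from configs.default import get_cfg_defaults
from core.networks.styletalk import StyleTalk

cfg = get_cfg_defaults()
model = StyleTalk(cfg)

# Hypothetical inputs: windows of phoneme ids and a style clip of expression coefficients.
phonemes = torch.randint(0, 41, (1, 16, 2 * cfg.WIN_SIZE + 1))   # (B, num_frames, window)
style_clip = torch.randn(1, 64, cfg.DATASET.FACE3D_DIM)          # (B, L, C_exp)

content = model.content_encoder(phonemes)      # (B, num_frames, window, D_MODEL)
style_code = model.style_encoder(style_clip)   # (B, D_MODEL)
face3d = model.decoder(content, style_code)    # (B, num_frames, FACE3D_DIM)
```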
core/networks/transformer.py ADDED
@@ -0,0 +1,300 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+ import copy
6
+
7
+
8
+ class PositionalEncoding(nn.Module):
9
+ def __init__(self, d_hid, n_position=200):
10
+ super(PositionalEncoding, self).__init__()
11
+
12
+ # Not a parameter
13
+ self.register_buffer("pos_table", self._get_sinusoid_encoding_table(n_position, d_hid))
14
+
15
+ def _get_sinusoid_encoding_table(self, n_position, d_hid):
16
+ """Sinusoid position encoding table"""
17
+ # TODO: make it with torch instead of numpy
18
+
19
+ def get_position_angle_vec(position):
20
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
21
+
22
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
23
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
24
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
25
+
26
+ return torch.FloatTensor(sinusoid_table).unsqueeze(0)
27
+
28
+ def forward(self, winsize):
29
+ return self.pos_table[:, :winsize].clone().detach()
30
+
31
+
32
+ def _get_activation_fn(activation):
33
+ """Return an activation function given a string"""
34
+ if activation == "relu":
35
+ return F.relu
36
+ if activation == "gelu":
37
+ return F.gelu
38
+ if activation == "glu":
39
+ return F.glu
40
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
41
+
42
+
43
+ def _get_clones(module, N):
44
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
45
+
46
+
47
+ class Transformer(nn.Module):
48
+ def __init__(
49
+ self,
50
+ d_model=512,
51
+ nhead=8,
52
+ num_encoder_layers=6,
53
+ num_decoder_layers=6,
54
+ dim_feedforward=2048,
55
+ dropout=0.1,
56
+ activation="relu",
57
+ normalize_before=False,
58
+ return_intermediate_dec=True,
59
+ ):
60
+ super().__init__()
61
+
62
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
63
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
64
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
65
+
66
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
67
+ decoder_norm = nn.LayerNorm(d_model)
68
+ self.decoder = TransformerDecoder(
69
+ decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec
70
+ )
71
+
72
+ self._reset_parameters()
73
+
74
+ self.d_model = d_model
75
+ self.nhead = nhead
76
+
77
+ def _reset_parameters(self):
78
+ for p in self.parameters():
79
+ if p.dim() > 1:
80
+ nn.init.xavier_uniform_(p)
81
+
82
+ def forward(self, opt, src, query_embed, pos_embed):
83
+ # flatten NxCxHxW to HWxNxC
84
+
85
+ src = src.permute(1, 0, 2)
86
+ pos_embed = pos_embed.permute(1, 0, 2)
87
+ query_embed = query_embed.permute(1, 0, 2)
88
+
89
+ tgt = torch.zeros_like(query_embed)
90
+ memory = self.encoder(src, pos=pos_embed)
91
+
92
+ hs = self.decoder(tgt, memory, pos=pos_embed, query_pos=query_embed)
93
+ return hs
94
+
95
+
96
+ class TransformerEncoder(nn.Module):
97
+ def __init__(self, encoder_layer, num_layers, norm=None):
98
+ super().__init__()
99
+ self.layers = _get_clones(encoder_layer, num_layers)
100
+ self.num_layers = num_layers
101
+ self.norm = norm
102
+
103
+ def forward(self, src, mask=None, src_key_padding_mask=None, pos=None):
104
+ output = src + pos
105
+
106
+ for layer in self.layers:
107
+ output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
108
+
109
+ if self.norm is not None:
110
+ output = self.norm(output)
111
+
112
+ return output
113
+
114
+
115
+ class TransformerDecoder(nn.Module):
116
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
117
+ super().__init__()
118
+ self.layers = _get_clones(decoder_layer, num_layers)
119
+ self.num_layers = num_layers
120
+ self.norm = norm
121
+ self.return_intermediate = return_intermediate
122
+
123
+ def forward(
124
+ self,
125
+ tgt,
126
+ memory,
127
+ tgt_mask=None,
128
+ memory_mask=None,
129
+ tgt_key_padding_mask=None,
130
+ memory_key_padding_mask=None,
131
+ pos=None,
132
+ query_pos=None,
133
+ ):
134
+ output = tgt + pos + query_pos
135
+
136
+ intermediate = []
137
+
138
+ for layer in self.layers:
139
+ output = layer(
140
+ output,
141
+ memory,
142
+ tgt_mask=tgt_mask,
143
+ memory_mask=memory_mask,
144
+ tgt_key_padding_mask=tgt_key_padding_mask,
145
+ memory_key_padding_mask=memory_key_padding_mask,
146
+ pos=pos,
147
+ query_pos=query_pos,
148
+ )
149
+ if self.return_intermediate:
150
+ intermediate.append(self.norm(output))
151
+
152
+ if self.norm is not None:
153
+ output = self.norm(output)
154
+ if self.return_intermediate:
155
+ intermediate.pop()
156
+ intermediate.append(output)
157
+
158
+ if self.return_intermediate:
159
+ return torch.stack(intermediate)
160
+
161
+ return output.unsqueeze(0)
162
+
163
+
164
+ class TransformerEncoderLayer(nn.Module):
165
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False):
166
+ super().__init__()
167
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
168
+ # Implementation of Feedforward model
169
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
170
+ self.dropout = nn.Dropout(dropout)
171
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
172
+
173
+ self.norm1 = nn.LayerNorm(d_model)
174
+ self.norm2 = nn.LayerNorm(d_model)
175
+ self.dropout1 = nn.Dropout(dropout)
176
+ self.dropout2 = nn.Dropout(dropout)
177
+
178
+ self.activation = _get_activation_fn(activation)
179
+ self.normalize_before = normalize_before
180
+
181
+ def with_pos_embed(self, tensor, pos):
182
+ return tensor if pos is None else tensor + pos
183
+
184
+ def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
185
+ # q = k = self.with_pos_embed(src, pos)
186
+ src2 = self.self_attn(src, src, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
187
+ src = src + self.dropout1(src2)
188
+ src = self.norm1(src)
189
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
190
+ src = src + self.dropout2(src2)
191
+ src = self.norm2(src)
192
+ return src
193
+
194
+ def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
195
+ src2 = self.norm1(src)
196
+ # q = k = self.with_pos_embed(src2, pos)
197
+ src2 = self.self_attn(src2, src2, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
198
+ src = src + self.dropout1(src2)
199
+ src2 = self.norm2(src)
200
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
201
+ src = src + self.dropout2(src2)
202
+ return src
203
+
204
+ def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
205
+ if self.normalize_before:
206
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
207
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
208
+
209
+
210
+ class TransformerDecoderLayer(nn.Module):
211
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False):
212
+ super().__init__()
213
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
214
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
215
+ # Implementation of Feedforward model
216
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
217
+ self.dropout = nn.Dropout(dropout)
218
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
219
+
220
+ self.norm1 = nn.LayerNorm(d_model)
221
+ self.norm2 = nn.LayerNorm(d_model)
222
+ self.norm3 = nn.LayerNorm(d_model)
223
+ self.dropout1 = nn.Dropout(dropout)
224
+ self.dropout2 = nn.Dropout(dropout)
225
+ self.dropout3 = nn.Dropout(dropout)
226
+
227
+ self.activation = _get_activation_fn(activation)
228
+ self.normalize_before = normalize_before
229
+
230
+ def with_pos_embed(self, tensor, pos):
231
+ return tensor if pos is None else tensor + pos
232
+
233
+ def forward_post(
234
+ self,
235
+ tgt,
236
+ memory,
237
+ tgt_mask=None,
238
+ memory_mask=None,
239
+ tgt_key_padding_mask=None,
240
+ memory_key_padding_mask=None,
241
+ pos=None,
242
+ query_pos=None,
243
+ ):
244
+ # q = k = self.with_pos_embed(tgt, query_pos)
245
+ tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
246
+ tgt = tgt + self.dropout1(tgt2)
247
+ tgt = self.norm1(tgt)
248
+ tgt2 = self.multihead_attn(
249
+ query=tgt, key=memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask
250
+ )[0]
251
+ tgt = tgt + self.dropout2(tgt2)
252
+ tgt = self.norm2(tgt)
253
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
254
+ tgt = tgt + self.dropout3(tgt2)
255
+ tgt = self.norm3(tgt)
256
+ return tgt
257
+
258
+ def forward_pre(
259
+ self,
260
+ tgt,
261
+ memory,
262
+ tgt_mask=None,
263
+ memory_mask=None,
264
+ tgt_key_padding_mask=None,
265
+ memory_key_padding_mask=None,
266
+ pos=None,
267
+ query_pos=None,
268
+ ):
269
+ tgt2 = self.norm1(tgt)
270
+ # q = k = self.with_pos_embed(tgt2, query_pos)
271
+ tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
272
+ tgt = tgt + self.dropout1(tgt2)
273
+ tgt2 = self.norm2(tgt)
274
+ tgt2 = self.multihead_attn(
275
+ query=tgt2, key=memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask
276
+ )[0]
277
+ tgt = tgt + self.dropout2(tgt2)
278
+ tgt2 = self.norm3(tgt)
279
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
280
+ tgt = tgt + self.dropout3(tgt2)
281
+ return tgt
282
+
283
+ def forward(
284
+ self,
285
+ tgt,
286
+ memory,
287
+ tgt_mask=None,
288
+ memory_mask=None,
289
+ tgt_key_padding_mask=None,
290
+ memory_key_padding_mask=None,
291
+ pos=None,
292
+ query_pos=None,
293
+ ):
294
+ if self.normalize_before:
295
+ return self.forward_pre(
296
+ tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
297
+ )
298
+ return self.forward_post(
299
+ tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
300
+ )
core/utils.py ADDED
@@ -0,0 +1,228 @@
1
+ import os
2
+ import argparse
3
+ from collections import defaultdict
4
+ import logging
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ from scipy.io import loadmat
10
+
11
+ from configs.default import get_cfg_defaults
12
+
13
+
14
+ def _reset_parameters(model):
15
+ for p in model.parameters():
16
+ if p.dim() > 1:
17
+ nn.init.xavier_uniform_(p)
18
+
19
+
20
+ def get_video_style(video_name, style_type):
21
+ person_id, direction, emotion, level, *_ = video_name.split("_")
22
+ if style_type == "id_dir_emo_level":
23
+ style = "_".join([person_id, direction, emotion, level])
24
+ elif style_type == "emotion":
25
+ style = emotion
26
+ else:
27
+ raise ValueError("Unknown style type")
28
+
29
+ return style
30
+
31
+
32
+ def get_style_video_lists(video_list, style_type):
33
+ style2video_list = defaultdict(list)
34
+ for video in video_list:
35
+ style = get_video_style(video, style_type)
36
+ style2video_list[style].append(video)
37
+
38
+ return style2video_list
39
+
40
+
41
+ def get_face3d_clip(video_name, video_root_dir, num_frames, start_idx, dtype=torch.float32):
42
+ """_summary_
43
+
44
+ Args:
45
+ video_name (_type_): _description_
46
+ video_root_dir (_type_): _description_
47
+ num_frames (_type_): _description_
48
+ start_idx (_type_): "random" , middle, int
49
+ dtype (_type_, optional): _description_. Defaults to torch.float32.
50
+
51
+ Raises:
52
+ ValueError: _description_
53
+ ValueError: _description_
54
+
55
+ Returns:
56
+ _type_: _description_
57
+ """
58
+ video_path = os.path.join(video_root_dir, video_name)
59
+ if video_path[-3:] == "mat":
60
+ face3d_all = loadmat(video_path)["coeff"]
61
+ face3d_exp = face3d_all[:, 80:144] # expression 3DMM range
62
+ elif video_path[-3:] == "txt":
63
+ face3d_exp = np.loadtxt(video_path)
64
+ else:
65
+ raise ValueError("Invalid 3DMM file extension")
66
+
67
+ length = face3d_exp.shape[0]
68
+ clip_num_frames = num_frames
69
+ if start_idx == "random":
70
+ clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
71
+ elif start_idx == "middle":
72
+ clip_start_idx = (length - clip_num_frames + 1) // 2
73
+ elif isinstance(start_idx, int):
74
+ clip_start_idx = start_idx
75
+ else:
76
+ raise ValueError(f"Invalid start_idx {start_idx}")
77
+
78
+ face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
79
+ face3d_clip = torch.tensor(face3d_clip, dtype=dtype)
80
+
81
+ return face3d_clip
82
+
83
+
84
+ def get_video_style_clip(video_path, style_max_len, start_idx="random", dtype=torch.float32):
85
+ if video_path[-3:] == "mat":
86
+ face3d_all = loadmat(video_path)["coeff"]
87
+ face3d_exp = face3d_all[:, 80:144] # expression 3DMM range
88
+ elif video_path[-3:] == "txt":
89
+ face3d_exp = np.loadtxt(video_path)
90
+ else:
91
+ raise ValueError("Invalid 3DMM file extension")
92
+
93
+ face3d_exp = torch.tensor(face3d_exp, dtype=dtype)
94
+
95
+ length = face3d_exp.shape[0]
96
+ if length >= style_max_len:
97
+ clip_num_frames = style_max_len
98
+ if start_idx == "random":
99
+ clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
100
+ elif start_idx == "middle":
101
+ clip_start_idx = (length - clip_num_frames + 1) // 2
102
+ elif isinstance(start_idx, int):
103
+ clip_start_idx = start_idx
104
+ else:
105
+ raise ValueError(f"Invalid start_idx {start_idx}")
106
+
107
+ face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
108
+ pad_mask = torch.tensor([False] * style_max_len)
109
+ else:
110
+ padding = torch.zeros(style_max_len - length, face3d_exp.shape[1])
111
+ face3d_clip = torch.cat((face3d_exp, padding), dim=0)
112
+ pad_mask = torch.tensor([False] * length + [True] * (style_max_len - length))
113
+
114
+ return face3d_clip, pad_mask
115
+
116
+
117
+ def get_audio_name_from_video(video_name):
118
+ audio_name = video_name[:-4] + "_seq.json"
119
+ return audio_name
120
+
121
+
122
+ def get_audio_window(audio, win_size):
123
+ """
124
+
125
+ Args:
126
+ audio (numpy.ndarray): (N,)
127
+
128
+ Returns:
129
+ audio_wins (numpy.ndarray): (N, W)
130
+ """
131
+ num_frames = len(audio)
132
+ ph_frames = []
133
+ for rid in range(0, num_frames):
134
+ ph = []
135
+ for i in range(rid - win_size, rid + win_size + 1):
136
+ if i < 0:
137
+ ph.append(31)
138
+ elif i >= num_frames:
139
+ ph.append(31)
140
+ else:
141
+ ph.append(audio[i])
142
+
143
+ ph_frames.append(ph)
144
+
145
+ audio_wins = np.array(ph_frames)
146
+
147
+ return audio_wins
148
+
149
+
150
+ def setup_config():
151
+ parser = argparse.ArgumentParser(description="voice2pose main program")
152
+ parser.add_argument("--config_file", default="", metavar="FILE", help="path to config file")
153
+ parser.add_argument("--resume_from", type=str, default=None, help="the checkpoint to resume from")
154
+ parser.add_argument("--test_only", action="store_true", help="perform testing and evaluation only")
155
+ parser.add_argument("--demo_input", type=str, default=None, help="path to input for demo")
156
+ parser.add_argument("--checkpoint", type=str, default=None, help="the checkpoint to test with")
157
+ parser.add_argument("--tag", type=str, default="", help="tag for the experiment")
158
+ parser.add_argument(
159
+ "opts",
160
+ help="Modify config options using the command-line",
161
+ default=None,
162
+ nargs=argparse.REMAINDER,
163
+ )
164
+ parser.add_argument(
165
+ "--local_rank",
166
+ type=int,
167
+ help="local rank for DistributedDataParallel",
168
+ )
169
+ parser.add_argument(
170
+ "--master_port",
171
+ type=str,
172
+ default="12345",
173
+ )
174
+ args = parser.parse_args()
175
+
176
+ cfg = get_cfg_defaults()
177
+ cfg.merge_from_file(args.config_file)
178
+ cfg.merge_from_list(args.opts)
179
+ cfg.freeze()
180
+ return args, cfg
181
+
182
+
183
+ def setup_logger(base_path, exp_name):
184
+ rootLogger = logging.getLogger()
185
+ rootLogger.setLevel(logging.INFO)
186
+
187
+ logFormatter = logging.Formatter("%(asctime)s [%(levelname)-0.5s] %(message)s")
188
+
189
+ log_path = "{0}/{1}.log".format(base_path, exp_name)
190
+ fileHandler = logging.FileHandler(log_path)
191
+ fileHandler.setFormatter(logFormatter)
192
+ rootLogger.addHandler(fileHandler)
193
+
194
+ consoleHandler = logging.StreamHandler()
195
+ consoleHandler.setFormatter(logFormatter)
196
+ rootLogger.addHandler(consoleHandler)
197
+ rootLogger.handlers[0].setLevel(logging.ERROR)
198
+
199
+ logging.info("log path: %s" % log_path)
200
+
201
+
202
+ def get_pose_params(mat_path):
203
+ """Get pose parameters from mat file
204
+
205
+ Args:
206
+ mat_path (str): path of mat file
207
+
208
+ Returns:
209
+ pose_params (numpy.ndarray): shape (L_video, 9); angles, translations, crop parameters
210
+ """
211
+ mat_dict = loadmat(mat_path)
212
+
213
+ np_3dmm = mat_dict["coeff"]
214
+ angles = np_3dmm[:, 224:227]
215
+ translations = np_3dmm[:, 254:257]
216
+
217
+ np_trans_params = mat_dict["transform_params"]
218
+ crop = np_trans_params[:, -3:]
219
+
220
+ pose_params = np.concatenate((angles, translations, crop), axis=1)
221
+
222
+ return pose_params
223
+
224
+
225
+ def obtain_seq_index(index, num_frames, radius):
226
+ seq = list(range(index - radius, index + radius + 1))
227
+ seq = [min(max(item, 0), num_frames - 1) for item in seq]
228
+ return seq
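
A minimal sketch (not part of the repository) of how the two windowing helpers above behave, assuming the repository root is on `PYTHONPATH`; the phoneme values are hypothetical:

```python
from core.utils import get_audio_window, obtain_seq_index

# Per-frame phoneme indices for a 5-frame clip (index 31 is SIL in phindex.json).
phoneme_ids = [31, 12, 12, 27, 31]

# Each row is a window of 2 * win_size + 1 phoneme indices centred on a frame;
# positions before/after the clip are padded with 31 (SIL).
wins = get_audio_window(phoneme_ids, win_size=5)
print(wins.shape)  # (5, 11)

# obtain_seq_index clamps a window of frame indices to the valid range; this is
# how render_video in inference_for_demo.py builds its 27-frame conditioning windows.
print(obtain_seq_index(index=0, num_frames=100, radius=13))
# [0, 0, ..., 0, 1, 2, ..., 13]
```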
demo.mp4 ADDED
Binary file (547 kB). View file
 
demo.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d29bc2d048a0be69193daea065734a3e76abbbe37e5ae4c8903d82f14ad92cb
3
+ size 227888
demo_download.mp4 ADDED
Binary file (457 kB). View file
 
demo_download.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d29bc2d048a0be69193daea065734a3e76abbbe37e5ae4c8903d82f14ad92cb
3
+ size 227888
env.yaml ADDED
File without changes
environment.yml ADDED
@@ -0,0 +1,91 @@
1
+ name: styletalk
2
+ channels:
3
+ - pytorch
4
+ - conda-forge
5
+ - defaults
6
+ dependencies:
7
+ - _libgcc_mutex=0.1=main
8
+ - _openmp_mutex=5.1=1_gnu
9
+ - blas=1.0=mkl
10
+ - bzip2=1.0.8=h7b6447c_0
11
+ - ca-certificates=2023.08.22=h06a4308_0
12
+ - certifi=2020.6.20=pyhd3eb1b0_3
13
+ - cudatoolkit=11.1.1=ha002fc5_10
14
+ - ffmpeg=4.2.2=h20bf706_0
15
+ - freetype=2.12.1=h4a9f257_0
16
+ - gmp=6.2.1=h295c915_3
17
+ - gnutls=3.6.15=he1e5248_0
18
+ - intel-openmp=2021.4.0=h06a4308_3561
19
+ - jpeg=9b=h024ee3a_2
20
+ - lame=3.100=h7b6447c_0
21
+ - libedit=3.1.20221030=h5eee18b_0
22
+ - libffi=3.2.1=hf484d3e_1007
23
+ - libgcc-ng=11.2.0=h1234567_1
24
+ - libgomp=11.2.0=h1234567_1
25
+ - libidn2=2.3.4=h5eee18b_0
26
+ - libopus=1.3.1=h7b6447c_0
27
+ - libpng=1.6.39=h5eee18b_0
28
+ - libstdcxx-ng=11.2.0=h1234567_1
29
+ - libtasn1=4.19.0=h5eee18b_0
30
+ - libtiff=4.0.9=he6b73bb_1
31
+ - libunistring=0.9.10=h27cfd23_0
32
+ - libuv=1.44.2=h5eee18b_0
33
+ - libvpx=1.7.0=h439df22_0
34
+ - mkl=2021.4.0=h06a4308_640
35
+ - mkl-service=2.4.0=py37h7f8727e_0
36
+ - mkl_fft=1.3.1=py37h3e078e5_1
37
+ - mkl_random=1.2.2=py37h51133e4_0
38
+ - ncurses=6.4=h6a678d5_0
39
+ - nettle=3.7.3=hbbd107a_1
40
+ - ninja=1.10.2=h06a4308_5
41
+ - ninja-base=1.10.2=hd09550d_5
42
+ - numpy=1.21.5=py37h6c91a56_3
43
+ - numpy-base=1.21.5=py37ha15fc14_3
44
+ - olefile=0.46=py37_0
45
+ - openh264=2.1.1=h4ff587b_0
46
+ - openssl=1.0.2u=h7b6447c_0
47
+ - pip=22.3.1=py37h06a4308_0
48
+ - python=3.7.0=h6e4f718_3
49
+ - python_abi=3.7=2_cp37m
50
+ - pytorch=1.8.0=py3.7_cuda11.1_cudnn8.0.5_0
51
+ - readline=7.0=h7b6447c_5
52
+ - setuptools=65.6.3=py37h06a4308_0
53
+ - six=1.16.0=pyhd3eb1b0_1
54
+ - sqlite=3.33.0=h62c20be_0
55
+ - tk=8.6.12=h1ccaba5_0
56
+ - torchaudio=0.8.0=py37
57
+ - torchvision=0.9.0=py37_cu111
58
+ - typing_extensions=4.1.1=pyh06a4308_0
59
+ - wheel=0.38.4=py37h06a4308_0
60
+ - x264=1!157.20191217=h7b6447c_0
61
+ - xz=5.4.2=h5eee18b_0
62
+ - zlib=1.2.13=h5eee18b_0
63
+ - pip:
64
+ - av==10.0.0
65
+ - beautifulsoup4==4.12.2
66
+ - charset-normalizer==3.3.2
67
+ - ffmpeg-python==0.2.0
68
+ - filelock==3.12.2
69
+ - future==0.18.3
70
+ - gdown==4.7.1
71
+ - idna==3.6
72
+ - imageio==2.18.0
73
+ - joblib==1.3.2
74
+ - networkx==2.6.3
75
+ - opencv-python==4.4.0.46
76
+ - packaging==23.2
77
+ - pillow==9.1.0
78
+ - pysocks==1.7.1
79
+ - pywavelets==1.3.0
80
+ - pyyaml==6.0
81
+ - requests==2.31.0
82
+ - scikit-image==0.19.3
83
+ - scikit-learn==1.0.2
84
+ - scipy==1.7.3
85
+ - soupsieve==2.4.1
86
+ - threadpoolctl==3.1.0
87
+ - tifffile==2021.11.2
88
+ - tqdm==4.66.1
89
+ - urllib3==2.0.7
90
+ - yacs==0.1.8
91
+ prefix: /home/pixis/miniconda3/envs/styletalk
generators/__pycache__/base_function.cpython-37.pyc ADDED
Binary file (13.9 kB). View file
 
generators/__pycache__/face_model.cpython-37.pyc ADDED
Binary file (3.97 kB). View file
 
generators/__pycache__/flow_util.cpython-37.pyc ADDED
Binary file (1.95 kB). View file
 
generators/base_function.py ADDED
@@ -0,0 +1,368 @@
1
+ import sys
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from torch.autograd import Function
8
+ from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
9
+
10
+
11
+ class LayerNorm2d(nn.Module):
12
+ def __init__(self, n_out, affine=True):
13
+ super(LayerNorm2d, self).__init__()
14
+ self.n_out = n_out
15
+ self.affine = affine
16
+
17
+ if self.affine:
18
+ self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
19
+ self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))
20
+
21
+ def forward(self, x):
22
+ normalized_shape = x.size()[1:]
23
+ if self.affine:
24
+ return F.layer_norm(x, normalized_shape, \
25
+ self.weight.expand(normalized_shape),
26
+ self.bias.expand(normalized_shape))
27
+
28
+ else:
29
+ return F.layer_norm(x, normalized_shape)
30
+
31
+ class ADAINHourglass(nn.Module):
32
+ def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect):
33
+ super(ADAINHourglass, self).__init__()
34
+ self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect)
35
+ self.decoder = ADAINDecoder(pose_nc, ngf, img_f, encoder_layers, decoder_layers, True, nonlinearity, use_spect)
36
+ self.output_nc = self.decoder.output_nc
37
+
38
+ def forward(self, x, z):
39
+ return self.decoder(self.encoder(x, z), z)
40
+
41
+
42
+
43
+ class ADAINEncoder(nn.Module):
44
+ def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity=nn.LeakyReLU(), use_spect=False):
45
+ super(ADAINEncoder, self).__init__()
46
+ self.layers = layers
47
+ self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1, padding=3)
48
+ for i in range(layers):
49
+ in_channels = min(ngf * (2**i), img_f)
50
+ out_channels = min(ngf *(2**(i+1)), img_f)
51
+ model = ADAINEncoderBlock(in_channels, out_channels, pose_nc, nonlinearity, use_spect)
52
+ setattr(self, 'encoder' + str(i), model)
53
+ self.output_nc = out_channels
54
+
55
+ def forward(self, x, z):
56
+ out = self.input_layer(x)
57
+ out_list = [out]
58
+ for i in range(self.layers):
59
+ model = getattr(self, 'encoder' + str(i))
60
+ out = model(out, z)
61
+ out_list.append(out)
62
+ return out_list
63
+
64
+ class ADAINDecoder(nn.Module):
65
+ """Decoder half of the ADAINHourglass."""
66
+ def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers, skip_connect=True,
67
+ nonlinearity=nn.LeakyReLU(), use_spect=False):
68
+
69
+ super(ADAINDecoder, self).__init__()
70
+ self.encoder_layers = encoder_layers
71
+ self.decoder_layers = decoder_layers
72
+ self.skip_connect = skip_connect
73
+ use_transpose = True
74
+
75
+ for i in range(encoder_layers-decoder_layers, encoder_layers)[::-1]:
76
+ in_channels = min(ngf * (2**(i+1)), img_f)
77
+ in_channels = in_channels*2 if i != (encoder_layers-1) and self.skip_connect else in_channels
78
+ out_channels = min(ngf * (2**i), img_f)
79
+ model = ADAINDecoderBlock(in_channels, out_channels, out_channels, pose_nc, use_transpose, nonlinearity, use_spect)
80
+ setattr(self, 'decoder' + str(i), model)
81
+
82
+ self.output_nc = out_channels*2 if self.skip_connect else out_channels
83
+
84
+ def forward(self, x, z):
85
+ out = x.pop() if self.skip_connect else x
86
+ for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]:
87
+ model = getattr(self, 'decoder' + str(i))
88
+ out = model(out, z)
89
+ out = torch.cat([out, x.pop()], 1) if self.skip_connect else out
90
+ return out
91
+
92
+ class ADAINEncoderBlock(nn.Module):
93
+ def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.LeakyReLU(), use_spect=False):
94
+ super(ADAINEncoderBlock, self).__init__()
95
+ kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1}
96
+ kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
97
+
98
+ self.conv_0 = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_down), use_spect)
99
+ self.conv_1 = spectral_norm(nn.Conv2d(output_nc, output_nc, **kwargs_fine), use_spect)
100
+
101
+
102
+ self.norm_0 = ADAIN(input_nc, feature_nc)
103
+ self.norm_1 = ADAIN(output_nc, feature_nc)
104
+ self.actvn = nonlinearity
105
+
106
+ def forward(self, x, z):
107
+ x = self.conv_0(self.actvn(self.norm_0(x, z)))
108
+ x = self.conv_1(self.actvn(self.norm_1(x, z)))
109
+ return x
110
+
111
+ class ADAINDecoderBlock(nn.Module):
112
+ def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), use_spect=False):
113
+ super(ADAINDecoderBlock, self).__init__()
114
+ # Attributes
115
+ self.actvn = nonlinearity
116
+ hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc
117
+
118
+ kwargs_fine = {'kernel_size':3, 'stride':1, 'padding':1}
119
+ if use_transpose:
120
+ kwargs_up = {'kernel_size':3, 'stride':2, 'padding':1, 'output_padding':1}
121
+ else:
122
+ kwargs_up = {'kernel_size':3, 'stride':1, 'padding':1}
123
+
124
+ # create conv layers
125
+ self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), use_spect)
126
+ if use_transpose:
127
+ self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect)
128
+ self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect)
129
+ else:
130
+ self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect),
131
+ nn.Upsample(scale_factor=2))
132
+ self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect),
133
+ nn.Upsample(scale_factor=2))
134
+ # define normalization layers
135
+ self.norm_0 = ADAIN(input_nc, feature_nc)
136
+ self.norm_1 = ADAIN(hidden_nc, feature_nc)
137
+ self.norm_s = ADAIN(input_nc, feature_nc)
138
+
139
+ def forward(self, x, z):
140
+ x_s = self.shortcut(x, z)
141
+ dx = self.conv_0(self.actvn(self.norm_0(x, z)))
142
+ dx = self.conv_1(self.actvn(self.norm_1(dx, z)))
143
+ out = x_s + dx
144
+ return out
145
+
146
+ def shortcut(self, x, z):
147
+ x_s = self.conv_s(self.actvn(self.norm_s(x, z)))
148
+ return x_s
149
+
150
+
151
+ def spectral_norm(module, use_spect=True):
152
+ """optionally wrap the module in spectral normalization to stabilize training"""
153
+ if use_spect:
154
+ return SpectralNorm(module)
155
+ else:
156
+ return module
157
+
158
+
159
+ class ADAIN(nn.Module):
160
+ def __init__(self, norm_nc, feature_nc):
161
+ super().__init__()
162
+
163
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
164
+
165
+ nhidden = 128
166
+ use_bias=True
167
+
168
+ self.mlp_shared = nn.Sequential(
169
+ nn.Linear(feature_nc, nhidden, bias=use_bias),
170
+ nn.ReLU()
171
+ )
172
+ self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=use_bias)
173
+ self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=use_bias)
174
+
175
+ def forward(self, x, feature):
176
+
177
+ # Part 1. generate parameter-free normalized activations
178
+ normalized = self.param_free_norm(x)
179
+
180
+ # Part 2. produce scaling and bias conditioned on feature
181
+ feature = feature.view(feature.size(0), -1)
182
+ actv = self.mlp_shared(feature)
183
+ gamma = self.mlp_gamma(actv)
184
+ beta = self.mlp_beta(actv)
185
+
186
+ # apply scale and bias
187
+ gamma = gamma.view(*gamma.size()[:2], 1,1)
188
+ beta = beta.view(*beta.size()[:2], 1,1)
189
+ out = normalized * (1 + gamma) + beta
190
+ return out
191
+
192
+
193
+ class FineEncoder(nn.Module):
194
+ """Multi-scale encoder used by the editing network."""
195
+ def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
196
+ super(FineEncoder, self).__init__()
197
+ self.layers = layers
198
+ self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
199
+ for i in range(layers):
200
+ in_channels = min(ngf*(2**i), img_f)
201
+ out_channels = min(ngf*(2**(i+1)), img_f)
202
+ model = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
203
+ setattr(self, 'down' + str(i), model)
204
+ self.output_nc = out_channels
205
+
206
+ def forward(self, x):
207
+ x = self.first(x)
208
+ out=[x]
209
+ for i in range(self.layers):
210
+ model = getattr(self, 'down'+str(i))
211
+ x = model(x)
212
+ out.append(x)
213
+ return out
214
+
215
+ class FineDecoder(nn.Module):
216
+ """Decoder of the editing network, with ADAIN residual blocks and skip connections."""
217
+ def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
218
+ super(FineDecoder, self).__init__()
219
+ self.layers = layers
220
+ for i in range(layers)[::-1]:
221
+ in_channels = min(ngf*(2**(i+1)), img_f)
222
+ out_channels = min(ngf*(2**i), img_f)
223
+ up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
224
+ res = FineADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
225
+ jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)
226
+
227
+ setattr(self, 'up' + str(i), up)
228
+ setattr(self, 'res' + str(i), res)
229
+ setattr(self, 'jump' + str(i), jump)
230
+
231
+ self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'tanh')
232
+
233
+ self.output_nc = out_channels
234
+
235
+ def forward(self, x, z):
236
+ out = x.pop()
237
+ for i in range(self.layers)[::-1]:
238
+ res_model = getattr(self, 'res' + str(i))
239
+ up_model = getattr(self, 'up' + str(i))
240
+ jump_model = getattr(self, 'jump' + str(i))
241
+ out = res_model(out, z)
242
+ out = up_model(out)
243
+ out = jump_model(x.pop()) + out
244
+ out_image = self.final(out)
245
+ return out_image
246
+
247
+ class FirstBlock2d(nn.Module):
248
+ """
249
+ First (stride-1) convolution block of the encoder.
250
+ """
251
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
252
+ super(FirstBlock2d, self).__init__()
253
+ kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
254
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
255
+
256
+ if type(norm_layer) == type(None):
257
+ self.model = nn.Sequential(conv, nonlinearity)
258
+ else:
259
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)
260
+
261
+
262
+ def forward(self, x):
263
+ out = self.model(x)
264
+ return out
265
+
266
+ class DownBlock2d(nn.Module):
267
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
268
+ super(DownBlock2d, self).__init__()
269
+
270
+
271
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
272
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
273
+ pool = nn.AvgPool2d(kernel_size=(2, 2))
274
+
275
+ if type(norm_layer) == type(None):
276
+ self.model = nn.Sequential(conv, nonlinearity, pool)
277
+ else:
278
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity, pool)
279
+
280
+ def forward(self, x):
281
+ out = self.model(x)
282
+ return out
283
+
284
+ class UpBlock2d(nn.Module):
285
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
286
+ super(UpBlock2d, self).__init__()
287
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
288
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
289
+ if type(norm_layer) == type(None):
290
+ self.model = nn.Sequential(conv, nonlinearity)
291
+ else:
292
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)
293
+
294
+ def forward(self, x):
295
+ out = self.model(F.interpolate(x, scale_factor=2))
296
+ return out
297
+
298
+ class FineADAINResBlocks(nn.Module):
299
+ def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
300
+ super(FineADAINResBlocks, self).__init__()
301
+ self.num_block = num_block
302
+ for i in range(num_block):
303
+ model = FineADAINResBlock2d(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
304
+ setattr(self, 'res'+str(i), model)
305
+
306
+ def forward(self, x, z):
307
+ for i in range(self.num_block):
308
+ model = getattr(self, 'res'+str(i))
309
+ x = model(x, z)
310
+ return x
311
+
312
+ class Jump(nn.Module):
313
+ def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
314
+ super(Jump, self).__init__()
315
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
316
+ conv = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
317
+
318
+ if type(norm_layer) == type(None):
319
+ self.model = nn.Sequential(conv, nonlinearity)
320
+ else:
321
+ self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity)
322
+
323
+ def forward(self, x):
324
+ out = self.model(x)
325
+ return out
326
+
327
+ class FineADAINResBlock2d(nn.Module):
328
+ """
329
+ Residual block with ADAIN-conditioned normalization
330
+ """
331
+ def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
332
+ super(FineADAINResBlock2d, self).__init__()
333
+
334
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
335
+
336
+ self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
337
+ self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
338
+ self.norm1 = ADAIN(input_nc, feature_nc)
339
+ self.norm2 = ADAIN(input_nc, feature_nc)
340
+
341
+ self.actvn = nonlinearity
342
+
343
+
344
+ def forward(self, x, z):
345
+ dx = self.actvn(self.norm1(self.conv1(x), z))
346
+ dx = self.norm2(self.conv2(x), z)
347
+ out = dx + x
348
+ return out
349
+
350
+ class FinalBlock2d(nn.Module):
351
+ """
352
+ Define the output layer
353
+ """
354
+ def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'):
355
+ super(FinalBlock2d, self).__init__()
356
+
357
+ kwargs = {'kernel_size': 7, 'stride': 1, 'padding':3}
358
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
359
+
360
+ if tanh_or_sigmoid == 'sigmoid':
361
+ out_nonlinearity = nn.Sigmoid()
362
+ else:
363
+ out_nonlinearity = nn.Tanh()
364
+
365
+ self.model = nn.Sequential(conv, out_nonlinearity)
366
+ def forward(self, x):
367
+ out = self.model(x)
368
+ return out
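
A minimal sketch (hypothetical shapes, not from the repository) of the ADAIN layer defined above, which is the conditioning mechanism used throughout the warping and editing networks: the descriptor is mapped to per-channel scale and shift values that modulate instance-normalized activations.

```python
import torch

from generators.base_function import ADAIN

adain = ADAIN(norm_nc=64, feature_nc=256)

x = torch.randn(2, 64, 32, 32)       # feature map to be modulated
descriptor = torch.randn(2, 256, 1)  # e.g. the pooled MappingNet output (flattened internally)

# out = instance_norm(x) * (1 + gamma(descriptor)) + beta(descriptor)
out = adain(x, descriptor)
print(out.shape)  # torch.Size([2, 64, 32, 32])
```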
generators/face_model.py ADDED
@@ -0,0 +1,127 @@
1
+ import functools
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ import generators.flow_util as flow_util
9
+ from generators.base_function import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder
10
+
11
+ class FaceGenerator(nn.Module):
12
+ def __init__(
13
+ self,
14
+ mapping_net,
15
+ warpping_net,
16
+ editing_net,
17
+ common
18
+ ):
19
+ super(FaceGenerator, self).__init__()
20
+ self.mapping_net = MappingNet(**mapping_net)
21
+ self.warpping_net = WarpingNet(**warpping_net, **common)
22
+ self.editing_net = EditingNet(**editing_net, **common)
23
+
24
+ def forward(
25
+ self,
26
+ input_image,
27
+ driving_source,
28
+ stage=None
29
+ ):
30
+ if stage == 'warp':
31
+ descriptor = self.mapping_net(driving_source)
32
+ output = self.warpping_net(input_image, descriptor)
33
+ else:
34
+ descriptor = self.mapping_net(driving_source)
35
+ output = self.warpping_net(input_image, descriptor)
36
+ output['fake_image'] = self.editing_net(input_image, output['warp_image'], descriptor)
37
+ return output
38
+
39
+ class MappingNet(nn.Module):
40
+ def __init__(self, coeff_nc, descriptor_nc, layer):
41
+ super( MappingNet, self).__init__()
42
+
43
+ self.layer = layer
44
+ nonlinearity = nn.LeakyReLU(0.1)
45
+
46
+ self.first = nn.Sequential(
47
+ torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))
48
+
49
+ for i in range(layer):
50
+ net = nn.Sequential(nonlinearity,
51
+ torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))
52
+ setattr(self, 'encoder' + str(i), net)
53
+
54
+ self.pooling = nn.AdaptiveAvgPool1d(1)
55
+ self.output_nc = descriptor_nc
56
+
57
+ def forward(self, input_3dmm):
58
+ out = self.first(input_3dmm)
59
+ for i in range(self.layer):
60
+ model = getattr(self, 'encoder' + str(i))
61
+ out = model(out) + out[:,:,3:-3]
62
+ out = self.pooling(out)
63
+ return out
64
+
65
+ class WarpingNet(nn.Module):
66
+ def __init__(
67
+ self,
68
+ image_nc,
69
+ descriptor_nc,
70
+ base_nc,
71
+ max_nc,
72
+ encoder_layer,
73
+ decoder_layer,
74
+ use_spect
75
+ ):
76
+ super( WarpingNet, self).__init__()
77
+
78
+ nonlinearity = nn.LeakyReLU(0.1)
79
+ norm_layer = functools.partial(LayerNorm2d, affine=True)
80
+ kwargs = {'nonlinearity':nonlinearity, 'use_spect':use_spect}
81
+
82
+ self.descriptor_nc = descriptor_nc
83
+ self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc,
84
+ max_nc, encoder_layer, decoder_layer, **kwargs)
85
+
86
+ self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc),
87
+ nonlinearity,
88
+ nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3))
89
+
90
+ self.pool = nn.AdaptiveAvgPool2d(1)
91
+
92
+ def forward(self, input_image, descriptor):
93
+ final_output={}
94
+ output = self.hourglass(input_image, descriptor)
95
+ final_output['flow_field'] = self.flow_out(output)
96
+
97
+ deformation = flow_util.convert_flow_to_deformation(final_output['flow_field'])
98
+ final_output['warp_image'] = flow_util.warp_image(input_image, deformation)
99
+ return final_output
100
+
101
+
102
+ class EditingNet(nn.Module):
103
+ def __init__(
104
+ self,
105
+ image_nc,
106
+ descriptor_nc,
107
+ layer,
108
+ base_nc,
109
+ max_nc,
110
+ num_res_blocks,
111
+ use_spect):
112
+ super(EditingNet, self).__init__()
113
+
114
+ nonlinearity = nn.LeakyReLU(0.1)
115
+ norm_layer = functools.partial(LayerNorm2d, affine=True)
116
+ kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect}
117
+ self.descriptor_nc = descriptor_nc
118
+
119
+ # encoder part
120
+ self.encoder = FineEncoder(image_nc*2, base_nc, max_nc, layer, **kwargs)
121
+ self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs)
122
+
123
+ def forward(self, input_image, warp_image, descriptor):
124
+ x = torch.cat([input_image, warp_image], 1)
125
+ x = self.encoder(x)
126
+ gen_image = self.decoder(x, descriptor)
127
+ return gen_image
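
A minimal sketch of MappingNet with hypothetical hyper-parameters (the real values are set in configs/renderer_conf.yaml): it turns a window of 3DMM expression and pose coefficients into the descriptor that conditions the warping and editing networks. The 73-channel, 27-frame window matches the windows built in inference_for_demo.render_video (64 expression + 9 pose coefficients, semantic_radius=13).

```python
import torch

from generators.face_model import MappingNet

# coeff_nc=73, descriptor_nc=256, layer=3 are assumptions, not values read from the config.
mapping_net = MappingNet(coeff_nc=73, descriptor_nc=256, layer=3)

win_coeffs = torch.randn(4, 73, 27)   # (batch, coefficients, window length)
descriptor = mapping_net(win_coeffs)  # temporal convs + residuals, then average pooling
print(descriptor.shape)  # torch.Size([4, 256, 1])
```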
generators/flow_util.py ADDED
@@ -0,0 +1,56 @@
1
+ import torch
2
+
3
+ def convert_flow_to_deformation(flow):
4
+ r"""convert flow fields to deformations.
5
+
6
+ Args:
7
+ flow (tensor): Flow field obtained by the model
8
+ Returns:
9
+ deformation (tensor): The deformation used for warping
10
+ """
11
+ b,c,h,w = flow.shape
12
+ flow_norm = 2 * torch.cat([flow[:,:1,...]/(w-1),flow[:,1:,...]/(h-1)], 1)
13
+ grid = make_coordinate_grid(flow)
14
+ deformation = grid + flow_norm.permute(0,2,3,1)
15
+ return deformation
16
+
17
+ def make_coordinate_grid(flow):
18
+ r"""obtain a coordinate grid with the same size as the flow field.
19
+
20
+ Args:
21
+ flow (tensor): Flow field obtained by the model
22
+ Returns:
23
+ grid (tensor): The grid with the same size as the input flow
24
+ """
25
+ b,c,h,w = flow.shape
26
+
27
+ x = torch.arange(w).to(flow)
28
+ y = torch.arange(h).to(flow)
29
+
30
+ x = (2 * (x / (w - 1)) - 1)
31
+ y = (2 * (y / (h - 1)) - 1)
32
+
33
+ yy = y.view(-1, 1).repeat(1, w)
34
+ xx = x.view(1, -1).repeat(h, 1)
35
+
36
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
37
+ meshed = meshed.expand(b, -1, -1, -1)
38
+ return meshed
39
+
40
+
41
+ def warp_image(source_image, deformation):
42
+ r"""warp the input image according to the deformation
43
+
44
+ Args:
45
+ source_image (tensor): source images to be warped
46
+ deformation (tensor): deformations used to warp the images; value in range (-1, 1)
47
+ Returns:
48
+ output (tensor): the warped images
49
+ """
50
+ _, h_old, w_old, _ = deformation.shape
51
+ _, _, h, w = source_image.shape
52
+ if h_old != h or w_old != w:
53
+ deformation = deformation.permute(0, 3, 1, 2)
54
+ deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear')
55
+ deformation = deformation.permute(0, 2, 3, 1)
56
+ return torch.nn.functional.grid_sample(source_image, deformation)
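
A minimal sketch (hypothetical tensors) of the two helpers above: WarpingNet predicts a pixel-space flow field, which is converted to a normalized sampling grid and used to warp the source image; a zero flow corresponds (up to resampling) to an identity warp.

```python
import torch

from generators.flow_util import convert_flow_to_deformation, warp_image

image = torch.rand(1, 3, 64, 64)
flow = torch.zeros(1, 2, 64, 64)  # pixel-space flow, as predicted by WarpingNet.flow_out

deformation = convert_flow_to_deformation(flow)  # sampling grid with values in [-1, 1]
warped = warp_image(image, deformation)          # grid_sample under the hood

print(deformation.shape, warped.shape)  # torch.Size([1, 64, 64, 2]) torch.Size([1, 3, 64, 64])
```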
inference_for_demo.py ADDED
@@ -0,0 +1,187 @@
1
+ import argparse
2
+ import cv2
3
+ import json
4
+ import os
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torchvision
9
+ import torchvision.transforms as transforms
10
+ from PIL import Image
11
+
12
+ from core.networks.styletalk import StyleTalk
13
+ from core.utils import get_audio_window, get_pose_params, get_video_style_clip, obtain_seq_index
14
+ from configs.default import get_cfg_defaults
15
+
16
+
17
+ @torch.no_grad()
18
+ def get_eval_model(cfg):
19
+ model = StyleTalk(cfg).cuda()
20
+ content_encoder = model.content_encoder
21
+ style_encoder = model.style_encoder
22
+ decoder = model.decoder
23
+ checkpoint = torch.load(cfg.INFERENCE.CHECKPOINT)
24
+ model_state_dict = checkpoint["model_state_dict"]
25
+ content_encoder_dict = {k[16:]: v for k, v in model_state_dict.items() if k[:16] == "content_encoder."}
26
+ content_encoder.load_state_dict(content_encoder_dict, strict=True)
27
+ style_encoder_dict = {k[14:]: v for k, v in model_state_dict.items() if k[:14] == "style_encoder."}
28
+ style_encoder.load_state_dict(style_encoder_dict, strict=True)
29
+ decoder_dict = {k[8:]: v for k, v in model_state_dict.items() if k[:8] == "decoder."}
30
+ decoder.load_state_dict(decoder_dict, strict=True)
31
+ model.eval()
32
+ return content_encoder, style_encoder, decoder
33
+
34
+
35
+ @torch.no_grad()
36
+ def render_video(
37
+ net_G, src_img_path, exp_path, wav_path, output_path, silent=False, semantic_radius=13, fps=30, split_size=64
38
+ ):
39
+
40
+ target_exp_seq = np.load(exp_path)
41
+
42
+ frame = cv2.imread(src_img_path)
43
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
44
+ src_img_raw = Image.fromarray(frame)
45
+ image_transform = transforms.Compose(
46
+ [
47
+ transforms.ToTensor(),
48
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
49
+ ]
50
+ )
51
+ src_img = image_transform(src_img_raw)
52
+
53
+ target_win_exps = []
54
+ for frame_idx in range(len(target_exp_seq)):
55
+ win_indices = obtain_seq_index(frame_idx, target_exp_seq.shape[0], semantic_radius)
56
+ win_exp = torch.tensor(target_exp_seq[win_indices]).permute(1, 0)
57
+ # (73, 27)
58
+ target_win_exps.append(win_exp)
59
+
60
+ target_exp_concat = torch.stack(target_win_exps, dim=0)
61
+ target_splited_exps = torch.split(target_exp_concat, split_size, dim=0)
62
+ output_imgs = []
63
+ for win_exp in target_splited_exps:
64
+ win_exp = win_exp.cuda()
65
+ cur_src_img = src_img.expand(win_exp.shape[0], -1, -1, -1).cuda()
66
+ output_dict = net_G(cur_src_img, win_exp)
67
+ output_imgs.append(output_dict["fake_image"].cpu().clamp_(-1, 1))
68
+
69
+ output_imgs = torch.cat(output_imgs, 0)
70
+ transformed_imgs = ((output_imgs + 1) / 2 * 255).to(torch.uint8).permute(0, 2, 3, 1)
71
+
72
+ if silent:
73
+ torchvision.io.write_video(output_path, transformed_imgs.cpu(), fps)
74
+ else:
75
+ silent_video_path = "silent.mp4"
76
+ torchvision.io.write_video(silent_video_path, transformed_imgs.cpu(), fps)
77
+ os.system(f"ffmpeg -loglevel quiet -y -i {silent_video_path} -i {wav_path} -shortest {output_path}")
78
+ os.remove(silent_video_path)
79
+
80
+
81
+ @torch.no_grad()
82
+ def get_netG(checkpoint_path):
83
+ from generators.face_model import FaceGenerator
84
+ import yaml
85
+
86
+ with open("configs/renderer_conf.yaml", "r") as f:
87
+ renderer_config = yaml.load(f, Loader=yaml.FullLoader)
88
+
89
+ renderer = FaceGenerator(**renderer_config).to(torch.cuda.current_device())
90
+
91
+ checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
92
+ renderer.load_state_dict(checkpoint["net_G_ema"], strict=False)
93
+
94
+ renderer.eval()
95
+
96
+ return renderer
97
+
98
+
99
+ @torch.no_grad()
100
+ def generate_expression_params(
101
+ cfg, audio_path, style_clip_path, pose_path, output_path, content_encoder, style_encoder, decoder
102
+ ):
103
+ with open(audio_path, "r") as f:
104
+ audio = json.load(f)
105
+
106
+ audio_win = get_audio_window(audio, cfg.WIN_SIZE)
107
+ audio_win = torch.tensor(audio_win).cuda()
108
+ content = content_encoder(audio_win.unsqueeze(0))
109
+
110
+ style_clip, pad_mask = get_video_style_clip(style_clip_path, style_max_len=256, start_idx=0)
111
+ style_code = style_encoder(
112
+ style_clip.unsqueeze(0).cuda(), pad_mask.unsqueeze(0).cuda() if pad_mask is not None else None
113
+ )
114
+
115
+ gen_exp_stack = decoder(content, style_code)
116
+ gen_exp = gen_exp_stack[0].cpu().numpy()
117
+
118
+ pose_ext = pose_path[-3:]
119
+ pose = None
120
+ if pose_ext == "npy":
121
+ pose = np.load(pose_path)
122
+ elif pose_ext == "mat":
123
+ pose = get_pose_params(pose_path)
124
+ # (L, 9)
125
+
126
+ selected_pose = None
127
+ if len(pose) >= len(gen_exp):
128
+ selected_pose = pose[: len(gen_exp)]
129
+ else:
130
+ selected_pose = np.tile(pose[-1:], (len(gen_exp), 1))  # pose is a numpy array, so repeat the last row with np.tile
131
+ selected_pose[: len(pose)] = pose
132
+
133
+ gen_exp_pose = np.concatenate((gen_exp, selected_pose), axis=1)
134
+ np.save(output_path, gen_exp_pose)
135
+
136
+
137
+ if __name__ == "__main__":
138
+ parser = argparse.ArgumentParser(description="inference for demo")
139
+ parser.add_argument(
140
+ "--styletalk_checkpoint",
141
+ type=str,
142
+ default="checkpoints/styletalk_checkpoint.pth",
143
+ help="the checkpoint to test with",
144
+ )
145
+ parser.add_argument(
146
+ "--renderer_checkpoint",
147
+ type=str,
148
+ default="checkpoints/renderer_checkpoint.pt",
149
+ help="renderer checkpoint",
150
+ )
151
+ parser.add_argument("--audio_path", type=str, default="", help="path for phoneme")
152
+ parser.add_argument("--style_clip_path", type=str, default="", help="path for style_clip_mat")
153
+ parser.add_argument("--pose_path", type=str, default="", help="path for pose")
154
+ parser.add_argument("--src_img_path", type=str, default="test_images/KristiNoem1_0.jpg")
155
+ parser.add_argument("--wav_path", type=str, default="demo/data/KristiNoem_front_neutral_level1_002.wav")
156
+ parser.add_argument("--output_path", type=str, default="demo_output.npy", help="path for output")
157
+ args = parser.parse_args()
158
+
159
+ cfg = get_cfg_defaults()
160
+ cfg.INFERENCE.CHECKPOINT = args.styletalk_checkpoint
161
+ cfg.freeze()
162
+ print(f"checkpoint: {cfg.INFERENCE.CHECKPOINT}")
163
+
164
+ # load checkpoint
165
+ with torch.no_grad():
166
+ content_encoder, style_encoder, decoder = get_eval_model(cfg)
167
+ exp_param_path = f"{args.output_path[:-4]}.npy"
168
+ generate_expression_params(
169
+ cfg,
170
+ args.audio_path,
171
+ args.style_clip_path,
172
+ args.pose_path,
173
+ exp_param_path,
174
+ content_encoder,
175
+ style_encoder,
176
+ decoder,
177
+ )
178
+
179
+ image_renderer = get_netG(args.renderer_checkpoint)
180
+ render_video(
181
+ image_renderer,
182
+ args.src_img_path,
183
+ exp_param_path,
184
+ args.wav_path,
185
+ args.output_path,
186
+ split_size=4,
187
+ )
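
For reference, a sketch of calling the two stages of the demo programmatically rather than through the command line. The checkpoint paths and the 3DMM/wav samples exist in this upload, while the phoneme json and source image paths are placeholders that must point to your own preprocessed inputs:

```python
from configs.default import get_cfg_defaults
from inference_for_demo import generate_expression_params, get_eval_model, get_netG, render_video

cfg = get_cfg_defaults()
cfg.INFERENCE.CHECKPOINT = "checkpoints/styletalk_checkpoint.pth"
cfg.freeze()

# Stage 1: phonemes + style clip + pose -> per-frame (64 expression + 9 pose) parameters.
content_encoder, style_encoder, decoder = get_eval_model(cfg)
generate_expression_params(
    cfg,
    audio_path="my_audio_seq.json",                            # placeholder phoneme sequence
    style_clip_path="samples/source_video/3DMM/Obama_clip1.mat",
    pose_path="samples/source_video/3DMM/Obama_clip1.mat",
    output_path="demo_output.npy",
    content_encoder=content_encoder,
    style_encoder=style_encoder,
    decoder=decoder,
)

# Stage 2: render the predicted parameters onto a source image and mux in the audio.
renderer = get_netG("checkpoints/renderer_checkpoint.pt")
render_video(
    renderer,
    src_img_path="my_source_image.jpg",                        # placeholder source image
    exp_path="demo_output.npy",
    wav_path="samples/source_video/wav/intro.wav",
    output_path="demo_output.mp4",
    split_size=4,
)
```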
media/first_page.png ADDED

Git LFS Details

  • SHA256: f9aa8175737d6b7bcd8b2520f62fb21969287f0646f954ee973a655e049d626f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.88 MB
phindex.json ADDED
@@ -0,0 +1 @@
1
+ {"AA": 0, "AE": 1, "AH": 2, "AO": 3, "AW": 4, "AY": 5, "B": 6, "CH": 7, "D": 8, "DH": 9, "EH": 10, "ER": 11, "EY": 12, "F": 13, "G": 14, "HH": 15, "IH": 16, "IY": 17, "JH": 18, "K": 19, "L": 20, "M": 21, "N": 22, "NG": 23, "NSN": 24, "OW": 25, "OY": 26, "P": 27, "R": 28, "S": 29, "SH": 30, "SIL": 31, "T": 32, "TH": 33, "UH": 34, "UW": 35, "V": 36, "W": 37, "Y": 38, "Z": 39, "ZH": 40}
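
phindex.json maps CMU-style phonemes to the integer ids consumed by the content encoder; note that id 31 ("SIL") is also the padding value used by get_audio_window in core/utils.py. A minimal sketch (hypothetical phoneme sequence) of turning per-frame phonemes into the id list that the demo's --audio_path json appears to contain:

```python
import json

with open("phindex.json") as f:
    phone_to_id = json.load(f)

frame_phonemes = ["SIL", "HH", "AH", "L", "OW", "SIL"]  # hypothetical per-frame phonemes
frame_ids = [phone_to_id[p] for p in frame_phonemes]
print(frame_ids)  # [31, 15, 2, 20, 25, 31]
```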
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ yacs==0.1.8
2
+ scipy==1.7.3
3
+ scikit-image==0.19.3
4
+ scikit-learn==1.0.2
5
+ PyYAML==6.0
6
+ Pillow==9.1.0
7
+ numpy==1.21.5
8
+ opencv-python==4.4.0.46
9
+ imageio==2.18.0
10
+ ffmpeg-python==0.2.0
11
+ av==10.0.0
samples/source_video/3DMM/KristiNoem.mat ADDED
Binary file (566 kB). View file
 
samples/source_video/3DMM/Obama_clip1.mat ADDED
Binary file (629 kB). View file
 
samples/source_video/3DMM/Obama_clip2.mat ADDED
Binary file (943 kB). View file
 
samples/source_video/3DMM/Obama_clip3.mat ADDED
Binary file (629 kB). View file