File size: 1,945 Bytes
c4d7aed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import os

class ResidualBlock(nn.Module):
    """Stack of dilated 1-D convolution pairs, each wrapped in a residual connection.

    For every dilation ``d`` in ``dilation`` the block computes
    ``leaky_relu -> Conv1d(dilation=d) -> leaky_relu -> Conv1d(dilation=1)``
    (LeakyReLU slope 0.1) and adds the result back onto its input. Padding is
    chosen so every convolution preserves the length of the last axis, which
    the residual addition requires.

    Args:
        channels: Number of input and output channels (unchanged by the block).
        kernel_size: Odd convolution kernel width. Defaults to 3.
        dilation: One dilation factor per residual stage. Defaults to (1, 3, 5).

    Raises:
        ValueError: If ``kernel_size`` is even. No symmetric padding keeps the
            length unchanged in that case, so the residual add could never work.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be odd to preserve sequence length, got {kernel_size}"
            )
        # "Same" padding for a stride-1 dilated conv is d * (k - 1) / 2 per side.
        # A fixed padding of d would only be correct for kernel_size == 3.
        self.convs1 = nn.ModuleList([
            nn.Conv1d(channels, channels, kernel_size, 1, dilation=d,
                      padding=d * (kernel_size - 1) // 2)
            for d in dilation
        ])
        self.convs2 = nn.ModuleList([
            nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                      padding=(kernel_size - 1) // 2)
            for _ in dilation
        ])

    def forward(self, x):
        """Apply every residual stage in order.

        Args:
            x: Input accepted by ``nn.Conv1d`` with ``channels`` channels;
                presumably (batch, channels, time) — TODO confirm with callers.

        Returns:
            Tensor with the same shape as ``x``.
        """
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, 0.1)
            xt = c1(xt)
            xt = F.leaky_relu(xt, 0.1)
            xt = c2(xt)
            x = xt + x
        return x

class RVCModel(nn.Module):
    """Convolutional encoder/decoder built from a config dict.

    Layers created in ``__init__``, with ``C = config["model"]["upsample_initial_channel"]``:
      * encoder: ``Conv1d(128 -> C, kernel 7, padding 3)`` followed by three
        ``ResidualBlock(C)``
      * decoder: ``Conv1d(C -> 128, kernel 7, padding 3)``

    Args:
        config: Mapping with a ``"model"`` entry providing
            ``"upsample_initial_channel"``. It is stored unchanged on ``self.config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        model_cfg = config["model"]

        # NOTE(review): 128 in/out channels is hard-coded; presumably the feature
        # (e.g. mel) dimension — confirm before changing configs.
        self.encoder = nn.Sequential(
            nn.Conv1d(128, model_cfg["upsample_initial_channel"], 7, 1, 3),
            *[ResidualBlock(model_cfg["upsample_initial_channel"]) for _ in range(3)]
        )

        self.decoder = nn.Sequential(
            nn.Conv1d(model_cfg["upsample_initial_channel"], 128, 7, 1, 3),
        )

    def forward(self, x):
        """Run the encoder and then the decoder.

        Args:
            x: Conv1d input with 128 channels; presumably (batch, 128, time) —
                TODO confirm with callers.

        Returns:
            The decoder output, which has 128 channels.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

    def convert_voice(self, audio_path):
        """Placeholder: returns ``audio_path`` unchanged.

        No audio is read, processed or written here.
        NOTE(review): stub — the real conversion pipeline is not implemented.
        """
        return audio_path

    @classmethod
    def from_pretrained(cls, model_path):
        """Build a model from a directory holding ``config.json`` and optionally ``model.pth``.

        Args:
            model_path: Directory containing ``config.json``. It may also contain
                ``model.pth`` holding a state dict.

        Returns:
            A new ``cls(config)`` instance. If ``model.pth`` exists, its weights are
            loaded onto the CPU. Otherwise the parameters keep their initial values.

        Raises:
            FileNotFoundError: If ``config.json`` does not exist.
            json.JSONDecodeError: If ``config.json`` is not valid JSON.
            KeyError: If the config lacks keys required by ``__init__``.
        """
        config_path = os.path.join(model_path, "config.json")
        # JSON text is UTF-8; don't let the platform's locale encoding decide.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        model = cls(config)

        model_file = os.path.join(model_path, "model.pth")
        # A missing checkpoint is tolerated on purpose: the model stays at its
        # initial weights.
        if os.path.exists(model_file):
            # SECURITY(review): torch.load unpickles the file, which can execute
            # arbitrary code. Only load trusted checkpoints, or pass
            # weights_only=True (torch >= 1.13) for untrusted input.
            model.load_state_dict(torch.load(model_file, map_location="cpu"))

        return model