Update handler.py
Browse filesAdding voice enrollment feature and serialization of inputs & outputs.
- handler.py +39 -68
handler.py
CHANGED
|
@@ -3,7 +3,6 @@ import torch
|
|
| 3 |
import numpy as np
|
| 4 |
import librosa
|
| 5 |
import soundfile as sf
|
| 6 |
-
import uuid
|
| 7 |
import traceback
|
| 8 |
import base64
|
| 9 |
import io
|
|
@@ -45,12 +44,6 @@ class EndpointHandler:
|
|
| 45 |
except Exception as e:
|
| 46 |
raise RuntimeError(f"Failed to load SNAC model: {e}")
|
| 47 |
|
| 48 |
-
self.ENROLLMENT_DIR = "enrollments"
|
| 49 |
-
|
| 50 |
-
# Move to devices
|
| 51 |
-
self.voice_id = None
|
| 52 |
-
|
| 53 |
-
|
| 54 |
# Set up functions to format and encode text/audio
|
| 55 |
def encode_text(self, text):
|
| 56 |
return self.tokenizer.encode(text, return_tensors="pt", add_special_tokens=False)
|
|
@@ -91,9 +84,8 @@ class EndpointHandler:
|
|
| 91 |
base64-encoded audio data
|
| 92 |
|
| 93 |
Returns:
|
| 94 |
-
-
|
| 95 |
"""
|
| 96 |
-
os.makedirs(self.ENROLLMENT_DIR, exist_ok=True)
|
| 97 |
enrollment_data = []
|
| 98 |
|
| 99 |
for text, base64_audio in enrollment_pairs:
|
|
@@ -103,30 +95,15 @@ class EndpointHandler:
|
|
| 103 |
"text_ids": text_ids,
|
| 104 |
"audio_codes": audio_codes
|
| 105 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
-
#
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
torch.save(enrollment_data, save_path)
|
| 111 |
-
|
| 112 |
-
self.voice_id = voice_id
|
| 113 |
-
|
| 114 |
-
return voice_id
|
| 115 |
-
|
| 116 |
-
def load_enrollment_by_id(self, voice_id):
|
| 117 |
-
"""
|
| 118 |
-
Load encoded text/audio token blocks using voice ID
|
| 119 |
-
|
| 120 |
-
Returns:
|
| 121 |
-
- enrollment_data: list of dicts {text_ids, audio_codes}
|
| 122 |
-
"""
|
| 123 |
-
path = os.path.join(self.ENROLLMENT_DIR, f"{voice_id}.pt")
|
| 124 |
-
if not os.path.exists(path):
|
| 125 |
-
raise FileNotFoundError(f"Voice ID '{voice_id}' not found.")
|
| 126 |
-
|
| 127 |
-
enrollment_data = torch.load(path, map_location="cpu")
|
| 128 |
-
|
| 129 |
-
return enrollment_data
|
| 130 |
|
| 131 |
def prepare_audio_tokens_for_decoder(self, audio_codes_list):
|
| 132 |
"""
|
|
@@ -149,7 +126,6 @@ class EndpointHandler:
|
|
| 149 |
|
| 150 |
return modified_audio_codes_list
|
| 151 |
|
| 152 |
-
|
| 153 |
# Convert audio sample to codes and reconstruct
|
| 154 |
def tokenize_audio(self, waveform):
|
| 155 |
waveform = torch.from_numpy(waveform).unsqueeze(0).unsqueeze(0).to(self.device)
|
|
@@ -174,39 +150,26 @@ class EndpointHandler:
|
|
| 174 |
"""
|
| 175 |
Preprocess input data before inference
|
| 176 |
"""
|
| 177 |
-
self.voice_cloning = data.get("
|
| 178 |
-
|
| 179 |
-
if isinstance(data, dict) and "inputs" in data:
|
| 180 |
-
target_text = data["inputs"]
|
| 181 |
-
parameters = data.get("parameters", {})
|
| 182 |
-
else:
|
| 183 |
-
target_text = data
|
| 184 |
-
parameters = {}
|
| 185 |
|
| 186 |
# Extract parameters from request
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
temperature = float(parameters.get("temperature", 0.6))
|
| 188 |
top_p = float(parameters.get("top_p", 0.95))
|
| 189 |
max_new_tokens = int(parameters.get("max_new_tokens", 1200))
|
| 190 |
repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
|
| 191 |
-
enrollments = parameters.get("enrollments", [])
|
| 192 |
-
voice_id = parameters.get("voice_id", None)
|
| 193 |
|
| 194 |
if self.voice_cloning:
|
|
|
|
| 195 |
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
if voice_id:
|
| 200 |
-
if not os.path.exists(enrollment_path(voice_id)):
|
| 201 |
-
raise ValueError(f"Voice ID '{voice_id}' not found in {self.ENROLLMENT_DIR}")
|
| 202 |
-
enrollment_data = self.load_enrollment_by_id(voice_id)
|
| 203 |
-
|
| 204 |
-
elif enrollments:
|
| 205 |
-
voice_id = self.enroll_user(enrollments)
|
| 206 |
-
enrollment_data = self.load_enrollment_by_id(voice_id)
|
| 207 |
-
|
| 208 |
else:
|
| 209 |
-
|
|
|
|
| 210 |
|
| 211 |
# Process pre-tokenized enrollment_data
|
| 212 |
input_sequence = []
|
|
@@ -238,7 +201,7 @@ class EndpointHandler:
|
|
| 238 |
"""Handle standard text-to-speech"""
|
| 239 |
|
| 240 |
# Extract parameters from request
|
| 241 |
-
voice =
|
| 242 |
prompt = f"{voice}: {target_text}"
|
| 243 |
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
|
| 244 |
|
|
@@ -297,10 +260,20 @@ class EndpointHandler:
|
|
| 297 |
Main entry point for the handler
|
| 298 |
"""
|
| 299 |
try:
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
# Catch that error, baby
|
| 305 |
except Exception as e:
|
| 306 |
traceback.print_exc()
|
|
@@ -434,10 +407,8 @@ class EndpointHandler:
|
|
| 434 |
# Encode WAV bytes as base64
|
| 435 |
audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
|
| 436 |
|
| 437 |
-
return {
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
"voice_id": self.voice_id
|
| 443 |
-
}
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import librosa
|
| 5 |
import soundfile as sf
|
|
|
|
| 6 |
import traceback
|
| 7 |
import base64
|
| 8 |
import io
|
|
|
|
| 44 |
except Exception as e:
|
| 45 |
raise RuntimeError(f"Failed to load SNAC model: {e}")
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
# Set up functions to format and encode text/audio
|
| 48 |
def encode_text(self, text):
|
| 49 |
return self.tokenizer.encode(text, return_tensors="pt", add_special_tokens=False)
|
|
|
|
| 84 |
base64-encoded audio data
|
| 85 |
|
| 86 |
Returns:
|
| 87 |
+
- cloning_features (str): serialized enrollment data
|
| 88 |
"""
|
|
|
|
| 89 |
enrollment_data = []
|
| 90 |
|
| 91 |
for text, base64_audio in enrollment_pairs:
|
|
|
|
| 95 |
"text_ids": text_ids,
|
| 96 |
"audio_codes": audio_codes
|
| 97 |
})
|
| 98 |
+
|
| 99 |
+
# Serialize enrollment data
|
| 100 |
+
buffer = io.BytesIO()
|
| 101 |
+
torch.save(enrollment_data, buffer)
|
| 102 |
+
buffer.seek(0)
|
| 103 |
|
| 104 |
+
# Encode as base64 string and assign to attribute
|
| 105 |
+
cloning_features = base64.b64encode(buffer.read()).decode('utf-8')
|
| 106 |
+
return cloning_features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
def prepare_audio_tokens_for_decoder(self, audio_codes_list):
|
| 109 |
"""
|
|
|
|
| 126 |
|
| 127 |
return modified_audio_codes_list
|
| 128 |
|
|
|
|
| 129 |
# Convert audio sample to codes and reconstruct
|
| 130 |
def tokenize_audio(self, waveform):
|
| 131 |
waveform = torch.from_numpy(waveform).unsqueeze(0).unsqueeze(0).to(self.device)
|
|
|
|
| 150 |
"""
|
| 151 |
Preprocess input data before inference
|
| 152 |
"""
|
| 153 |
+
self.voice_cloning = data.get("clone", False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
# Extract parameters from request
|
| 156 |
+
target_text = data["inputs"]
|
| 157 |
+
parameters = data.get("parameters", {})
|
| 158 |
+
cloning_features = data.get("cloning_features", None)
|
| 159 |
+
|
| 160 |
temperature = float(parameters.get("temperature", 0.6))
|
| 161 |
top_p = float(parameters.get("top_p", 0.95))
|
| 162 |
max_new_tokens = int(parameters.get("max_new_tokens", 1200))
|
| 163 |
repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
|
|
|
|
|
|
|
| 164 |
|
| 165 |
if self.voice_cloning:
|
| 166 |
+
"""Handle voice cloning using cloning features"""
|
| 167 |
|
| 168 |
+
if not cloning_features:
|
| 169 |
+
raise ValueError("No cloning features were provided")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
else:
|
| 171 |
+
# Decode back into tensors
|
| 172 |
+
enrollment_data = torch.load(io.BytesIO(base64.b64decode(cloning_features)))
|
| 173 |
|
| 174 |
# Process pre-tokenized enrollment_data
|
| 175 |
input_sequence = []
|
|
|
|
| 201 |
"""Handle standard text-to-speech"""
|
| 202 |
|
| 203 |
# Extract parameters from request
|
| 204 |
+
voice = data.get("voice", "Eniola")
|
| 205 |
prompt = f"{voice}: {target_text}"
|
| 206 |
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
|
| 207 |
|
|
|
|
| 260 |
Main entry point for the handler
|
| 261 |
"""
|
| 262 |
try:
|
| 263 |
+
enroll_user = data.get("enroll_user", False)
|
| 264 |
+
|
| 265 |
+
if enroll_user:
|
| 266 |
+
# We extract cloning features for enrollment
|
| 267 |
+
enrollment_pairs = data.get("enrollments", [])
|
| 268 |
+
cloning_features = self.enroll_user(enrollment_pairs)
|
| 269 |
+
return {"cloning_features": cloning_features}
|
| 270 |
+
else:
|
| 271 |
+
# We want to generate speech using preset cloning features
|
| 272 |
+
preprocessed_inputs = self.preprocess(data)
|
| 273 |
+
model_outputs = self.inference(preprocessed_inputs)
|
| 274 |
+
response = self.postprocess(model_outputs)
|
| 275 |
+
return response
|
| 276 |
+
|
| 277 |
# Catch that error, baby
|
| 278 |
except Exception as e:
|
| 279 |
traceback.print_exc()
|
|
|
|
| 407 |
# Encode WAV bytes as base64
|
| 408 |
audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
|
| 409 |
|
| 410 |
+
return {
|
| 411 |
+
"audio_sample": audio_sample,
|
| 412 |
+
"audio_b64": audio_b64,
|
| 413 |
+
"sample_rate": 24000,
|
| 414 |
+
}
|
|
|
|
|
|