okezieowen committed
Commit 8dbeb08 · verified · 1 Parent(s): 270fefb

Update handler.py

Adding voice enrollment feature and serialization of inputs & outputs.

Files changed (1)
  1. handler.py +39 -68
handler.py CHANGED
@@ -3,7 +3,6 @@ import torch
 import numpy as np
 import librosa
 import soundfile as sf
-import uuid
 import traceback
 import base64
 import io
@@ -45,12 +44,6 @@ class EndpointHandler:
         except Exception as e:
             raise RuntimeError(f"Failed to load SNAC model: {e}")

-        self.ENROLLMENT_DIR = "enrollments"
-
-        # Move to devices
-        self.voice_id = None
-
-
     # Set up functions to format and encode text/audio
     def encode_text(self, text):
         return self.tokenizer.encode(text, return_tensors="pt", add_special_tokens=False)
@@ -91,9 +84,8 @@ class EndpointHandler:
             base64-encoded audio data

         Returns:
-        - voice_id (str): ID you can later use to synthesize speech
+        - cloning_features (str): serialized enrollment data
         """
-        os.makedirs(self.ENROLLMENT_DIR, exist_ok=True)
         enrollment_data = []

         for text, base64_audio in enrollment_pairs:
@@ -103,30 +95,15 @@
                 "text_ids": text_ids,
                 "audio_codes": audio_codes
             })
+
+        # Serialize enrollment data
+        buffer = io.BytesIO()
+        torch.save(enrollment_data, buffer)
+        buffer.seek(0)

-        # Generate unique voice ID
-        voice_id = f"voice_{uuid.uuid4().hex[:6]}"
-        save_path = os.path.join(self.ENROLLMENT_DIR, f"{voice_id}.pt")
-        torch.save(enrollment_data, save_path)
-
-        self.voice_id = voice_id
-
-        return voice_id
-
-    def load_enrollment_by_id(self, voice_id):
-        """
-        Load encoded text/audio token blocks using voice ID
-
-        Returns:
-        - enrollment_data: list of dicts {text_ids, audio_codes}
-        """
-        path = os.path.join(self.ENROLLMENT_DIR, f"{voice_id}.pt")
-        if not os.path.exists(path):
-            raise FileNotFoundError(f"Voice ID '{voice_id}' not found.")
-
-        enrollment_data = torch.load(path, map_location="cpu")
-
-        return enrollment_data
+        # Encode as base64 string and return it
+        cloning_features = base64.b64encode(buffer.read()).decode('utf-8')
+        return cloning_features

     def prepare_audio_tokens_for_decoder(self, audio_codes_list):
         """
@@ -149,7 +126,6 @@

         return modified_audio_codes_list

-
     # Convert audio sample to codes and reconstruct
     def tokenize_audio(self, waveform):
         waveform = torch.from_numpy(waveform).unsqueeze(0).unsqueeze(0).to(self.device)
@@ -174,39 +150,26 @@
         """
         Preprocess input data before inference
         """
-        self.voice_cloning = data.get("parameters", {}).get("clone", False)
-
-        if isinstance(data, dict) and "inputs" in data:
-            target_text = data["inputs"]
-            parameters = data.get("parameters", {})
-        else:
-            target_text = data
-            parameters = {}
+        self.voice_cloning = data.get("clone", False)

         # Extract parameters from request
+        target_text = data["inputs"]
+        parameters = data.get("parameters", {})
+        cloning_features = data.get("cloning_features", None)
+
         temperature = float(parameters.get("temperature", 0.6))
         top_p = float(parameters.get("top_p", 0.95))
         max_new_tokens = int(parameters.get("max_new_tokens", 1200))
         repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
-        enrollments = parameters.get("enrollments", [])
-        voice_id = parameters.get("voice_id", None)

         if self.voice_cloning:
+            """Handle voice cloning using cloning features"""

-            # Validate voice cloning input
-            enrollment_path = lambda vid: os.path.join(self.ENROLLMENT_DIR, f"{vid}.pt")
-
-            if voice_id:
-                if not os.path.exists(enrollment_path(voice_id)):
-                    raise ValueError(f"Voice ID '{voice_id}' not found in {self.ENROLLMENT_DIR}")
-                enrollment_data = self.load_enrollment_by_id(voice_id)
-
-            elif enrollments:
-                voice_id = self.enroll_user(enrollments)
-                enrollment_data = self.load_enrollment_by_id(voice_id)
-
+            if not cloning_features:
+                raise ValueError("No cloning features were provided")
             else:
-                raise ValueError("You must provide either a valid voice_id or enrollment pairs.")
+                # Decode back into tensors
+                enrollment_data = torch.load(io.BytesIO(base64.b64decode(cloning_features)))

             # Process pre-tokenized enrollment_data
             input_sequence = []
@@ -238,7 +201,7 @@
         """Handle standard text-to-speech"""

         # Extract parameters from request
-        voice = parameters.get("voice", "tara")
+        voice = data.get("voice", "Eniola")
         prompt = f"{voice}: {target_text}"
         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
@@ -297,10 +260,20 @@
         Main entry point for the handler
         """
         try:
-            preprocessed_inputs = self.preprocess(data)
-            model_outputs = self.inference(preprocessed_inputs)
-            response = self.postprocess(model_outputs)
-            return response
+            enroll_user = data.get("enroll_user", False)
+
+            if enroll_user:
+                # We extract cloning features for enrollment
+                enrollment_pairs = data.get("enrollments", [])
+                cloning_features = self.enroll_user(enrollment_pairs)
+                return {"cloning_features": cloning_features}
+            else:
+                # We want to generate speech using preset cloning features
+                preprocessed_inputs = self.preprocess(data)
+                model_outputs = self.inference(preprocessed_inputs)
+                response = self.postprocess(model_outputs)
+                return response
+
         # Catch that error, baby
         except Exception as e:
             traceback.print_exc()
@@ -434,10 +407,8 @@
         # Encode WAV bytes as base64
         audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')

-        return {
-            "generated_ids": generated_ids.tolist(),
-            "audio_sample": audio_sample,
-            "audio_b64": audio_b64,
-            "sample_rate": 24000,
-            "voice_id": self.voice_id
-        }
+        return {
+            "audio_sample": audio_sample,
+            "audio_b64": audio_b64,
+            "sample_rate": 24000,
+        }
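The heart of this change is an in-memory serialization round-trip that replaces the enrollments/*.pt files on disk. A self-contained sketch of that round-trip, with dummy tensors standing in for the real {text_ids, audio_codes} dicts built by enroll_user() (the tensor values below are illustrative only):

import base64
import io
import torch

# Dummy enrollment data in the same shape enroll_user() builds:
# a list of dicts holding token tensors.
enrollment_data = [{
    "text_ids": torch.tensor([[101, 2023, 102]]),
    "audio_codes": [torch.tensor([[7, 8, 9]])],
}]

# Serialize: torch.save into an in-memory buffer, then base64-encode,
# as enroll_user() now does.
buffer = io.BytesIO()
torch.save(enrollment_data, buffer)
buffer.seek(0)
cloning_features = base64.b64encode(buffer.read()).decode("utf-8")

# Deserialize: the inverse step performed in preprocess().
restored = torch.load(io.BytesIO(base64.b64decode(cloning_features)))
assert torch.equal(restored[0]["text_ids"], enrollment_data[0]["text_ids"])

Because the features travel as an opaque base64 string, a client can cache them and replay them on later requests while the endpoint keeps no per-voice state; the trade-off is a larger request body on every synthesis call.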
 
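Taken together, the handler now exposes a stateless two-step API: one call to extract cloning features, then synthesis calls that carry those features along. A minimal client sketch, assuming a standard Inference Endpoint that passes the JSON payload straight into __call__ (the URL, token, and audio file are placeholders, not part of this commit):

import base64
import requests

URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HEADERS = {"Authorization": "Bearer <HF_TOKEN>"}             # placeholder

# Step 1: enroll. "enrollments" is a list of (transcript, base64 audio)
# pairs, matching what enroll_user() iterates over.
with open("sample.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

enroll_payload = {
    "enroll_user": True,
    "enrollments": [["Hello, this is my voice.", audio_b64]],
}
features = requests.post(URL, headers=HEADERS, json=enroll_payload).json()["cloning_features"]

# Step 2: synthesize with the returned features; no server-side lookup occurs.
tts_payload = {
    "inputs": "Welcome back!",
    "clone": True,
    "cloning_features": features,
    "parameters": {"temperature": 0.6, "top_p": 0.95},
}
out = requests.post(URL, headers=HEADERS, json=tts_payload).json()
# Per the updated postprocess(), out includes "audio_b64" and "sample_rate" (24000).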