Morgan Funtowicz committed on
Commit 69cb715 · 1 Parent(s): ead11a7

chore(qa): format

Files changed (1)
  1. endpoint.py +61 -38
endpoint.py CHANGED
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
 
 
 def chunk_audio_with_duration(
-    audio: np.ndarray, maximum_duration_sec: int, sampling_rate: int
+    audio: np.ndarray, maximum_duration_sec: int, sampling_rate: int
 ) -> Sequence[np.ndarray]:
     """
     Chunk a mono audio timeseries so that each chunk is as long as `maximum_duration_sec`.
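The body of chunk_audio_with_duration is outside this hunk. As an illustration only, a minimal duration-based chunking of a mono array could be done with plain slicing (hypothetical re-implementation, not the code in this commit):

import numpy as np
from typing import Sequence

def chunk_audio_with_duration_sketch(
    audio: np.ndarray, maximum_duration_sec: int, sampling_rate: int
) -> Sequence[np.ndarray]:
    # Each chunk holds at most `maximum_duration_sec` seconds of samples;
    # only the final chunk may be shorter.
    max_samples = maximum_duration_sec * sampling_rate
    return [audio[i : i + max_samples] for i in range(0, len(audio), max_samples)]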
@@ -62,10 +62,10 @@ def compression_ratio(text: str) -> float:
 
 
 def create_prompt(
-    audio: np.ndarray,
-    sampling_rate: int,
-    language: int,
-    timestamp_marker: int,
+    audio: np.ndarray,
+    sampling_rate: int,
+    language: int,
+    timestamp_marker: int,
 ):
     """
     Generate the right prompt with the specific parameters to submit for inference over Whisper
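create_prompt's body is also outside the hunk. For context, vLLM treats Whisper as an encoder/decoder model: the audio rides along as multi-modal data on the encoder side while the decoder is seeded with Whisper's task prefix (<|startoftranscript|>, the language token, <|transcribe|>, then either a timestamp token or <|notimestamps|>). A hedged sketch, with the dict layout taken from vLLM's offline Whisper example and the helper name purely illustrative:

def build_whisper_prompt_sketch(audio, sampling_rate: int, language: str = "en"):
    # Layout follows vLLM's offline Whisper example; the real create_prompt in this
    # file receives pre-converted token ids rather than a text decoder prompt.
    return {
        "encoder_prompt": {
            "prompt": "",
            "multi_modal_data": {"audio": (audio, sampling_rate)},
        },
        "decoder_prompt": f"<|startoftranscript|><|{language}|><|transcribe|><|notimestamps|>",
    }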
@@ -92,14 +92,14 @@ def create_prompt(
 
 
 def create_params(
-    max_tokens: int, temperature: float, is_verbose: bool
+    max_tokens: int, temperature: float, is_verbose: bool
 ) -> SamplingParams:
     """
-
-    :param max_tokens:
-    :param temperature:
-    :param is_verbose:
-    :return:
+    Create sampling parameters to submit for inference through vLLM `generate`
+    :param max_tokens: Maximum number of tokens to generate
+    :param temperature: Sampling temperature for the softmax
+    :param is_verbose: Flag indicating whether the response is required to contain verbose output
+    :return: `SamplingParams`
     """
     return SamplingParams.from_optional(
         # output_kind=RequestOutputKind.FINAL_ONLY,  # Change if streaming
@@ -107,7 +107,7 @@ def create_params(
         skip_special_tokens=False,
         detokenize=False,
         temperature=temperature,
-        logprobs=100 if is_verbose else None,
+        logprobs=1 if is_verbose else None,
     )
 
 
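Dropping logprobs from 100 to 1 keeps per-token logprobs available for get_avg_logprob on the verbose path while avoiding the cost of returning the top-100 candidates. A hedged call sketch; values other than those visible in the diff are assumptions:

from vllm import SamplingParams

params = SamplingParams.from_optional(
    max_tokens=448,              # assumption: Whisper's decoder accepts at most 448 tokens
    temperature=0.0,             # assumption: greedy decoding
    skip_special_tokens=False,   # keep timestamp/special tokens for downstream parsing
    detokenize=False,            # raw token ids are decoded later by process_chunk(s)
    logprobs=1,                  # one logprob per generated token is enough for avg_logprob
)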
@@ -122,13 +122,23 @@ def get_avg_logprob(logprobs: "SampleLogprobs") -> float:
 
 
 def process_chunk(
-    tokenizer: "PreTrainedTokenizer",
-    ids: np.ndarray,
-    logprobs: "SampleLogprobs",
-    request: TranscriptionRequest,
-    segment_offset: int,
-    timestamp_offset: int,
+    tokenizer: "PreTrainedTokenizer",
+    ids: np.ndarray,
+    logprobs: "SampleLogprobs",
+    request: TranscriptionRequest,
+    segment_offset: int,
+    timestamp_offset: int,
 ) -> Generator:
+    """
+    Decode a single transcribed audio chunk and generate all the associated segments
+    :param tokenizer:
+    :param ids:
+    :param logprobs:
+    :param request:
+    :param segment_offset:
+    :param timestamp_offset:
+    :return:
+    """
     # Some constants
     k_timestamp_token = lru_cache(tokenizer.convert_tokens_to_ids)(f"<|0.00|>")
 
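process_chunk resolves segments against k_timestamp_token: Whisper emits timestamp tokens in 0.02 s steps starting at <|0.00|>, so turning a timestamp token back into seconds is a simple offset computation. A hedged helper illustrating the arithmetic; the actual bookkeeping inside process_chunk may differ:

def timestamp_token_to_seconds(
    token_id: int, k_timestamp_token: int, timestamp_offset: float = 0.0
) -> float:
    # <|0.00|> maps to 0.0 s, <|0.02|> to 0.02 s, and so on; `timestamp_offset`
    # shifts the value by the start time of the current audio chunk.
    return timestamp_offset + (token_id - k_timestamp_token) * 0.02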
@@ -187,10 +197,17 @@ def process_chunk(
 
 
 def process_chunks(
-    tokenizer: "PreTrainedTokenizer",
-    chunks: List["RequestOutput"],
-    request: TranscriptionRequest,
+    tokenizer: "PreTrainedTokenizer",
+    chunks: List["RequestOutput"],
+    request: TranscriptionRequest,
 ) -> Tuple[List[Segment], str]:
+    """
+    Iterate over all the audio chunks' outputs and consolidate them into segment(s), whether the response is verbose or not
+    :param tokenizer: The tokenizer to use for decoding tokens
+    :param chunks: Transcribed audio chunks
+    :param request: Received request from the user
+    :return: `Tuple[List[Segment], str]` holding all the consolidated segments along with the full transcribed text
+    """
     # k_nospeech_token = tokenizer.convert_tokens_to_ids("<|nospeech|>")
     # k_sot_token = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
     materialized_segments, materialized_segments_tokens_acc = [], []
@@ -205,7 +222,7 @@ def process_chunks(
         logprobs = generation.logprobs
 
         for segment, _is_continuation in process_chunk(
-            tokenizer, ids, logprobs, request, segment_offset, time_offset
+            tokenizer, ids, logprobs, request, segment_offset, time_offset
         ):
             materialized_segments.append(segment)
 
@@ -237,19 +254,19 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
                 device="auto",
                 dtype="bfloat16",
                 kv_cache_dtype="fp8",
-                enforce_eager=True,
+                enforce_eager=False,
                 enable_prefix_caching=True,
-                max_logprobs=100,  # TODO(mfuntowicz) : Set from config?
+                max_logprobs=1,  # TODO(mfuntowicz) : Set from config?
             )
         )
 
     async def transcribe(
-        self,
-        ctx: Context,
-        request: TranscriptionRequest,
-        tokenizer: "PreTrainedTokenizer",
-        audio_chunks: Iterable[np.ndarray],
-        params: "SamplingParams"
+        self,
+        ctx: Context,
+        request: TranscriptionRequest,
+        tokenizer: "PreTrainedTokenizer",
+        audio_chunks: Iterable[np.ndarray],
+        params: "SamplingParams",
     ) -> (List[Segment], str):
         async def __agenerate__(request_id: str, prompt, params):
             """
@@ -267,8 +284,6 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
         # Wrap tokenizer results with LRU cache to avoid vocabulary lookup
         convert_tokens_to_ids = lru_cache(tokenizer.convert_tokens_to_ids)
 
-        # f"<|{timestamp_marker if is_verbose_response else 0:.2f}|>"
-
         coro_handles = []
         for audio_chunk_id, audio_chunk in enumerate(audio_chunks):
             # Generate suffixed request-id to submit and identify through vLLM scheduler
@@ -280,11 +295,17 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
             # Compute initial prompt for the segment
             is_verbose = request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
             language = convert_tokens_to_ids(f"<|{request.language}|>")
-            timestamp = convert_tokens_to_ids(f"<|0.00|>" if is_verbose else '<|notimestamps|>')
-            prompt = create_prompt(audio_chunk, WhisperHandler.WHISPER_SAMPLING_RATE, language, timestamp)
+            timestamp = convert_tokens_to_ids(
+                f"<|0.00|>" if is_verbose else "<|notimestamps|>"
+            )
+            prompt = create_prompt(
+                audio_chunk, WhisperHandler.WHISPER_SAMPLING_RATE, language, timestamp
+            )
 
             # Submit the task
-            coro_handles += [asyncio.create_task(__agenerate__(request_id, prompt, params))]
+            coro_handles += [
+                asyncio.create_task(__agenerate__(request_id, prompt, params))
+            ]
 
         # Wait for all the segment to complete
         text_chunks = await asyncio.gather(*coro_handles)
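The loop above fans out one task per audio chunk and then gathers them, letting vLLM's async engine batch the requests while results still come back in submission order. A self-contained illustration of the pattern; names and the fake coroutine are placeholders:

import asyncio

async def fake_generate(chunk_id: int) -> str:
    # Stand-in for __agenerate__: submit one chunk and await its transcription
    await asyncio.sleep(0.01)
    return f"text of chunk {chunk_id}"

async def transcribe_all(num_chunks: int) -> list:
    handles = [asyncio.create_task(fake_generate(i)) for i in range(num_chunks)]
    # gather preserves the order in which the tasks were created
    return await asyncio.gather(*handles)

print(asyncio.run(transcribe_all(3)))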
@@ -296,14 +317,14 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
         return segments, text
 
     async def __call__(
-        self, request: TranscriptionRequest, ctx: Context
+        self, request: TranscriptionRequest, ctx: Context
     ) -> TranscriptionResponse:
         with logger.contextualize(request_id=ctx.request_id):
             with memoryview(request) as audio:
 
                 # Check if we need to enable the verbose path
                 is_verbose = (
-                    request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
+                    request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
                 )
 
                 # Retrieve the tokenizer and model config asynchronously while we decode audio
@@ -329,7 +350,9 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
                 )
 
                 # Submit audio pieces to the batcher and gather them all
-                segments, text = await self.transcribe(ctx, request, await tokenizer, audio_chunks, params)
+                segments, text = await self.transcribe(
+                    ctx, request, await tokenizer, audio_chunks, params
+                )
 
                 match request.response_kind:
                     case TranscriptionResponseKind.VERBOSE_JSON:
 