Morgan Funtowicz committed
Commit c9543c7 · 1 Parent(s): 4ff9cbf

feat(whisper): honor language param

Files changed (1)
  1. endpoint.py +26 -17
endpoint.py CHANGED
```diff
@@ -64,34 +64,28 @@ def compression_ratio(text: str) -> float:
 def create_prompt(
     audio: np.ndarray,
     sampling_rate: int,
+    language: int,
     timestamp_marker: int,
-    is_verbose_response: bool,
 ):
     """
-
-    :param audio:
-    :param sampling_rate:
-    :param timestamp_marker:
-    :param is_verbose_response:
-    :return:
+    Generate the right prompt with the specific parameters to submit for inference over Whisper
+    :param audio: PCM data containing the audio signal representation
+    :param sampling_rate: Number of samples in one second of audio
+    :param language: Token id representing the language of the audio content
+    :param timestamp_marker: Token id representing the temporal position within the audio content for this segment
+    :return: Dictionary with all the prefilled values to call `generate`
     """
-    # TODO: We assume english for now
-    k_english_token = 50259
-    k_timestamp_marker = f"<|{timestamp_marker if is_verbose_response else 0:.2f}|>"
-    k_timestamp_marker_token = 50365
-
     return {
         "encoder_prompt": {
             "prompt": "",
             "multi_modal_data": {"audio": (audio, sampling_rate)},
         },
         "decoder_prompt": {
-            # <|startoftranscript|><|{request.language}|><|transcribe|>{timestamp_marker}
             "prompt_token_ids": [
                 50258,
-                k_english_token,
+                language,
                 50360,
-                k_timestamp_marker_token,
+                timestamp_marker,
             ]
         },
     }
```
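Note: the ids kept in the decoder prompt are positions in the multilingual Whisper vocabulary, per the removed comment showing the intended prefix `<|startoftranscript|><|{language}|><|transcribe|>{timestamp_marker}`. A minimal sketch, assuming the Hugging Face `openai/whisper-large-v3` tokenizer, of resolving those ids from token strings rather than hard-coding them:

```python
# Sketch (assumes the openai/whisper-large-v3 tokenizer; special-token ids are
# vocabulary-specific, so resolving them at runtime beats hard-coded constants).
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")

sot = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")  # 50258 here
lang = tokenizer.convert_tokens_to_ids("<|en|>")    # 50259, the removed k_english_token
task = tokenizer.convert_tokens_to_ids("<|transcribe|>")        # 50360 here
ts = tokenizer.convert_tokens_to_ids("<|0.00|>")    # 50365, the removed k_timestamp_marker_token

print([sot, lang, task, ts])  # matches the decoder prompt_token_ids layout above
```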
```diff
@@ -258,11 +252,23 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
         params: "SamplingParams"
     ) -> (List[Segment], str):
         async def __agenerate__(request_id: str, prompt, params):
+            """
+            Helper method to unroll an asynchronous generator and return its last element
+            :param request_id: Unique identifier for this request
+            :param prompt: The prompt to submit for inference on vLLM through `generate(...)`
+            :param params: The parameters passed along with the prompt for inference on vLLM through `generate(...)`
+            :return: `CompletionOutput`
+            """
             # Submit for inference on the segment & keep track of the background task
             async for step in self._engine.generate(prompt, params, request_id):
                 pass
             return step

+        # Wrap tokenizer lookups with an LRU cache to avoid repeated vocabulary lookups
+        convert_tokens_to_ids = lru_cache(tokenizer.convert_tokens_to_ids)
+
+        # f"<|{timestamp_marker if is_verbose_response else 0:.2f}|>"
+
         coro_handles = []
         for audio_chunk_id, audio_chunk in enumerate(audio_chunks):
             # Generate suffixed request-id to submit and identify through vLLM scheduler
```
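`functools.lru_cache` can wrap a callable directly (Python 3.8+), so each distinct token string hits the tokenizer's vocabulary only once even though the same language and timestamp tokens are resolved for every audio chunk. A self-contained sketch of the pattern, with illustrative stand-in names:

```python
# Sketch of the caching pattern used above; VOCAB and the lookup function are
# toy stand-ins for a real tokenizer vocabulary.
from functools import lru_cache

VOCAB = {"<|en|>": 50259, "<|notimestamps|>": 50364}

def convert_tokens_to_ids(token: str) -> int:
    print(f"vocabulary lookup: {token}")  # printed only on a cache miss
    return VOCAB[token]

cached = lru_cache(convert_tokens_to_ids)

cached("<|en|>")  # miss: performs the lookup
cached("<|en|>")  # hit: served from the cache
print(cached.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)
```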
```diff
@@ -272,7 +278,10 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
             timestamp = audio_chunk_id * WhisperHandler.WHISPER_SEGMENT_DURATION_SEC

             # Compute initial prompt for the segment
-            prompt = create_prompt(audio_chunk, WhisperHandler.WHISPER_SAMPLING_RATE, timestamp, request)
+            is_verbose = request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
+            language = convert_tokens_to_ids(f"<|{request.language}|>")
+            timestamp = convert_tokens_to_ids(f"<|{timestamp:.2f}|>" if is_verbose else '<|notimestamps|>')
+            prompt = create_prompt(audio_chunk, WhisperHandler.WHISPER_SAMPLING_RATE, language, timestamp)

             # Submit the task
             coro_handles += [asyncio.create_task(__agenerate__(request_id, prompt, params))]
```
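The marker now depends on the response kind: verbose responses seed decoding with the chunk's start offset so segment timestamps come back, while plain responses use `<|notimestamps|>`. A small sketch of that selection, following Whisper's token-string conventions:

```python
# Sketch of the per-chunk marker choice. Whisper's vocabulary defines timestamp
# tokens <|0.00|> ... <|30.00|> in 0.02 s steps, plus <|notimestamps|>.
SEGMENT_DURATION_SEC = 30  # mirrors WHISPER_SEGMENT_DURATION_SEC

def marker_for_chunk(chunk_id: int, verbose: bool) -> str:
    offset = chunk_id * SEGMENT_DURATION_SEC
    return f"<|{offset:.2f}|>" if verbose else "<|notimestamps|>"

assert marker_for_chunk(0, verbose=True) == "<|0.00|>"
assert marker_for_chunk(0, verbose=False) == "<|notimestamps|>"
```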
```diff
@@ -328,7 +337,7 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
             VerboseTranscription(
                 text=text,
                 duration=get_duration(y=waveform, sr=sampling),
-                language="en",
+                language=request.language,
                 segments=segments,
                 # word=None
             )
```
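Putting the pieces together, a hypothetical call for the first English chunk of a verbose request (ids as in the constants removed above, `create_prompt` from the first hunk assumed in scope):

```python
# Hypothetical end-to-end usage of the updated create_prompt: a silent 30 s
# mono chunk at Whisper's 16 kHz input rate, English, first timestamp token.
import numpy as np

audio_chunk = np.zeros(16_000 * 30, dtype=np.float32)  # stand-in PCM data
language = 50259   # <|en|>
timestamp = 50365  # <|0.00|>

prompt = create_prompt(audio_chunk, 16_000, language, timestamp)
assert prompt["decoder_prompt"]["prompt_token_ids"] == [50258, 50259, 50360, 50365]
```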
 