Morgan Funtowicz committed
Commit c9543c7 · 1 Parent(s): 4ff9cbf

feat(whisper): honor language param

endpoint.py CHANGED (+26 -17)
@@ -64,34 +64,28 @@ def compression_ratio(text: str) -> float:
 def create_prompt(
     audio: np.ndarray,
     sampling_rate: int,
+    language: int,
     timestamp_marker: int,
-    is_verbose_response: bool,
 ):
     """
-
-    :param audio:
-    :param sampling_rate:
-    :param
-    :param
-    :return:
+    Generate the right prompt with the specific parameters to submit for inference over Whisper
+    :param audio: PCM data containing the audio signal representation
+    :param sampling_rate: Number of samples in one second of audio
+    :param language: Token id representing the language of the audio content
+    :param timestamp_marker: Token id representing the temporal position within the audio content for this segment
+    :return: Dictionary with all the prefilled values to call `generate`
     """
-    # TODO: We assume english for now
-    k_english_token = 50259
-    k_timestamp_marker = f"<|{timestamp_marker if is_verbose_response else 0:.2f}|>"
-    k_timestamp_marker_token = 50365
-
     return {
         "encoder_prompt": {
             "prompt": "",
             "multi_modal_data": {"audio": (audio, sampling_rate)},
         },
         "decoder_prompt": {
-            # <|startoftranscript|><|{request.language}|><|transcribe|>{timestamp_marker}
             "prompt_token_ids": [
                 50258,
-                k_english_token,
+                language,
                 50360,
-                k_timestamp_marker_token,
+                timestamp_marker,
             ]
         },
     }
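Note: the hard-coded ids in `prompt_token_ids` are Whisper special tokens. A minimal sketch to check what they decode to, assuming the `openai/whisper-large-v3` checkpoint (the checkpoint is not named in this diff, and these ids shift between Whisper vocabularies):

# Sketch only: the checkpoint choice is an assumption.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large-v3")

# 50258 / 50360 are the ids kept above; 50259 / 50365 were the old
# k_english_token / k_timestamp_marker_token constants.
print(tokenizer.convert_ids_to_tokens([50258, 50259, 50360, 50365]))
# expected: ['<|startoftranscript|>', '<|en|>', '<|transcribe|>', '<|0.00|>']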
@@ -258,11 +252,23 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
         params: "SamplingParams"
     ) -> (List[Segment], str):
         async def __agenerate__(request_id: str, prompt, params):
+            """
+            Helper method to unroll an asynchronous generator and return the last element
+            :param request_id: Unique identifier for this request
+            :param prompt: The prompt to submit for inference on vLLM through `generate(...)`
+            :param params: The parameters passed along with the prompt for inference on vLLM through `generate(...)`
+            :return: `CompletionOutput`
+            """
             # Submit for inference on the segment & keep track of the background task
             async for step in self._engine.generate(prompt, params, request_id):
                 pass
             return step
 
+        # Wrap tokenizer results with an LRU cache to avoid repeated vocabulary lookups
+        convert_tokens_to_ids = lru_cache(tokenizer.convert_tokens_to_ids)
+
+        # f"<|{timestamp_marker if is_verbose_response else 0:.2f}|>"
+
         coro_handles = []
         for audio_chunk_id, audio_chunk in enumerate(audio_chunks):
             # Generate suffixed request-id to submit and identify through vLLM scheduler
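The `__agenerate__` helper exploits the fact that after `async for` drains a generator, the loop variable still holds the last yielded item, so only the final engine output is returned. A self-contained sketch of that pattern with a toy generator (no vLLM involved):

import asyncio

async def fake_engine_generate():
    # Stand-in for the engine's streaming generator: yields growing partial outputs.
    for step in ("hel", "hell", "hello"):
        yield step

async def unroll(agen):
    # Drain the generator; `step` keeps the final item once the loop ends.
    async for step in agen:
        pass
    return step

print(asyncio.run(unroll(fake_engine_generate())))  # -> hello

One caveat worth noting: if the generator yields nothing, `step` is never bound and the helper raises `UnboundLocalError`.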
@@ -272,7 +278,10 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
             timestamp = audio_chunk_id * WhisperHandler.WHISPER_SEGMENT_DURATION_SEC
 
             # Compute initial prompt for the segment
-            prompt = create_prompt(audio_chunk, WhisperHandler.WHISPER_SAMPLING_RATE, timestamp, is_verbose_response)
+            is_verbose = request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
+            language = convert_tokens_to_ids(f"<|{request.language}|>")
+            timestamp = convert_tokens_to_ids(f"<|{timestamp:.2f}|>" if is_verbose else '<|notimestamps|>')
+            prompt = create_prompt(audio_chunk, WhisperHandler.WHISPER_SAMPLING_RATE, language, timestamp)
 
             # Submit the task
             coro_handles += [asyncio.create_task(__agenerate__(request_id, prompt, params))]
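Per chunk, the language and timestamp markers are rendered as special-token strings and converted to single ids through the cached lookup. A sketch of those lookups, again assuming a `transformers` tokenizer for `openai/whisper-large-v3` (an assumption; the diff does not show how `tokenizer` is constructed):

from functools import lru_cache
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large-v3")
convert_tokens_to_ids = lru_cache(tokenizer.convert_tokens_to_ids)

print(convert_tokens_to_ids("<|fr|>"))            # language id for a request.language of "fr"
print(convert_tokens_to_ids("<|30.00|>"))         # timestamp token for a 30.00 s segment offset
print(convert_tokens_to_ids("<|notimestamps|>"))  # fallback when the response is not verbose
print(convert_tokens_to_ids.cache_info())         # repeated lookups become cache hits

Since every chunk of the same request asks for the same `<|{request.language}|>` token, the `lru_cache` wrapper means only the first chunk pays for that vocabulary lookup.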
@@ -328,7 +337,7 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
             VerboseTranscription(
                 text=text,
                 duration=get_duration(y=waveform, sr=sampling),
-                language=
+                language=request.language,
                 segments=segments,
                 # word=None
             )
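Downstream of this hunk, `text` and `segments` come from awaiting the tasks stored in `coro_handles`. That collection step is outside this diff; a hedged sketch of what it could look like (names and response handling are illustrative, not from the commit):

import asyncio

async def collect(coro_handles):
    # Await the per-chunk tasks in submission order so segments stay chronological.
    outputs = await asyncio.gather(*coro_handles)
    # Each output would be a vLLM RequestOutput; keep the top candidate's text.
    return " ".join(out.outputs[0].text.strip() for out in outputs)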