import os
from io import BytesIO
from typing import Any, Dict

import torch
import torchaudio
from pyannote.audio import Pipeline
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler that runs pyannote speaker diarization on raw audio bytes."""

    def __init__(self, path: str = ""):
        """Load the diarization pipeline.

        Args:
            path: Directory containing the model repository. The pipeline
                config is expected at ``<path>/config.yaml``.
        """
        # BUGFIX: the original ignored ``path`` and always loaded "config.yaml"
        # relative to the current working directory. Resolve it against the
        # model directory instead; os.path.join("", "config.yaml") is still
        # "config.yaml", so the default behavior is unchanged.
        self.pipeline = Pipeline.from_pretrained(os.path.join(path, "config.yaml"))

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run speaker diarization on the request payload.

        Args:
            data: Request body. ``data["inputs"]`` holds the audio file as
                raw bytes; an optional ``data["parameters"]`` dict is
                forwarded as keyword arguments to the pipeline.

        Returns:
            ``{"diarization": [...]}`` where each entry is a dict with
            string-valued ``label``, ``start`` and ``stop`` keys.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # Decode the audio entirely in memory; torchaudio.load returns a
        # (waveform, sample_rate) pair, which is exactly the in-memory input
        # format pyannote pipelines accept.
        waveform, sample_rate = torchaudio.load(BytesIO(inputs))
        pyannote_input = {"waveform": waveform, "sample_rate": sample_rate}

        if parameters is not None:
            diarization = self.pipeline(pyannote_input, **parameters)
        else:
            diarization = self.pipeline(pyannote_input)

        # Flatten the pyannote Annotation into JSON-serializable records.
        processed_diarization = [
            {"label": str(label), "start": str(segment.start), "stop": str(segment.end)}
            for segment, _, label in diarization.itertracks(yield_label=True)
        ]

        return {"diarization": processed_diarization}
|
|