Commit 7f5eb0f
Parent(s): a1d11a0

Add hf endpoint handler.py (#24)

- Add hf endpoint handler.py (8d5a103212eb614fc3f958ea454e6118fd5f811f)

Co-authored-by: Olivier Dehaene <[email protected]>
handler.py  ADDED  (+33 -0)
@@ -0,0 +1,33 @@
+import torch
+
+from typing import Any, Dict, List
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+class EndpointHandler:
+    def __init__(self, path=""):
+        # load model and tokenizer from path
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            path, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True
+        )
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
+        # process input
+        inputs = data.pop("inputs", data)
+        parameters = data.pop("parameters", None)
+
+        # preprocess
+        inputs = self.tokenizer(inputs, return_tensors="pt").to(self.device)
+
+        # pass inputs with all kwargs in data
+        if parameters is not None:
+            outputs = self.model.generate(**inputs, **parameters)
+        else:
+            outputs = self.model.generate(**inputs)
+
+        # postprocess the prediction
+        prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        return [{"generated_text": prediction}]
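For context, this file follows the Hugging Face Inference Endpoints custom-handler contract: the service instantiates EndpointHandler with the repository path and invokes it once per request with a dict containing "inputs" and optional "parameters" (forwarded as kwargs to model.generate). A minimal local smoke test of the committed handler might look like the sketch below; the "./model-dir" path, the prompt text, and the max_new_tokens value are illustrative assumptions, not part of this commit.

    # a minimal sketch, assuming this repo (containing handler.py and the model
    # weights) is cloned to ./model-dir, with transformers + accelerate installed
    # and a CUDA-capable GPU available for the float16 weights
    from handler import EndpointHandler

    handler = EndpointHandler(path="./model-dir")  # hypothetical local path
    payload = {
        "inputs": "Hello, my name is",
        "parameters": {"max_new_tokens": 20},  # forwarded to model.generate()
    }
    result = handler(payload)  # -> [{"generated_text": "..."}]
    print(result[0]["generated_text"])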
