Update handler.py
Browse files — handler.py (+47 lines added, −25 lines removed)
handler.py
CHANGED
@@ -1,43 +1,65 @@
|
|
1 |
-
import json
|
2 |
-
import os
|
3 |
from typing import Dict, List, Any
|
4 |
from llama_cpp import Llama
|
5 |
-
import gemma_tools as gem
|
6 |
|
7 |
MAX_TOKENS = 8192
|
8 |
|
|
|
9 |
class EndpointHandler():
|
10 |
-
def __init__(self
|
11 |
-
#
|
12 |
-
|
|
|
|
|
|
|
13 |
|
14 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
16 |
fmat = "<startofturn>system\n{system_prompt} <endofturn>\n<startofturn>user\n{prompt} <endofturn>\n<startofturn>model"
|
17 |
-
|
18 |
-
if
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
23 |
try:
|
24 |
-
fmat = fmat.format(system_prompt=args
|
|
|
|
|
25 |
except Exception as e:
|
26 |
-
|
|
|
27 |
"status": "error",
|
28 |
-
"reason": "
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
32 |
try:
|
33 |
max_length = int(max_length)
|
|
|
34 |
except Exception as e:
|
35 |
-
|
|
|
36 |
"status": "error",
|
37 |
-
"reason": "max_length was passed as something that was
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
41 |
|
42 |
-
return
|
|
|
|
|
|
|
43 |
|
|
|
|
|
|
|
1 |
from typing import Dict, List, Any
|
2 |
from llama_cpp import Llama
|
|
|
3 |
|
4 |
# Hard ceiling on generated tokens; also used as the model context window.
MAX_TOKENS = 8192


class EndpointHandler():
    """Inference-endpoint handler for the ComicBot GGUF model.

    The model is loaded once in ``__init__`` and reused for every request;
    ``__call__`` implements the request/response contract expected by the
    hosting runtime (a ``data`` dict in, a list with one result dict out).
    """

    def __init__(self) -> None:
        # Initialize the model with the ComicBot configuration.
        print("Initializing Llama model with ComicBot settings...")
        self.model = Llama.from_pretrained(
            "njwright92/ComicBot_v.2-gguf",
            filename="comic_mistral-v5.2.q5_0.gguf",
            # Context size kept in sync with the MAX_TOKENS generation cap.
            n_ctx=MAX_TOKENS,
        )
        print("Model initialization complete.")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one generation request.

        Args:
            data: Payload with an ``args`` dict (``system_prompt``,
                ``inputs``, optional ``temperature``/``top_p``/``top_k``)
                and an optional top-level ``max_length``.

        Returns:
            A single-element list: ``{"status": "success", "response": ...}``
            on success, or an error dict (``status``/``message`` or
            ``status``/``reason``/``detail`` — two shapes kept for
            backward compatibility with existing callers) on bad input.
        """
        # Extract arguments from the data payload.
        print("Extracting arguments from the data payload...")
        args = data.get("args", {})
        print(f"Arguments extracted: {args}")

        # Prompt template the fine-tuned model was trained on.
        fmat = "<startofturn>system\n{system_prompt} <endofturn>\n<startofturn>user\n{prompt} <endofturn>\n<startofturn>model"

        # Reject requests that carry no arguments at all.
        if not args:
            print("No arguments found in the data payload.")
            return [{
                "status": "error",
                "message": "No arguments found in the data payload."
            }]

        try:
            fmat = fmat.format(system_prompt=args.get(
                "system_prompt", ""), prompt=args.get("inputs", ""))
            print(f"Formatted prompt: {fmat}")
        except Exception as e:
            print(f"Error in formatting the prompt: {str(e)}")
            return [{
                "status": "error",
                "reason": "Invalid format",
                "detail": str(e)
            }]

        max_length = data.get("max_length", 512)
        try:
            # int() raises ValueError/TypeError on non-numeric input.
            max_length = int(max_length)
        except (ValueError, TypeError) as e:
            print(f"Error converting max_length to int: {str(e)}")
            return [{
                "status": "error",
                "reason": "max_length was passed as something that was not a plain old int",
                "detail": str(e)
            }]
        # Clamp to a sane range: at least 1 token, never beyond the
        # model's context size (previously MAX_TOKENS was never enforced).
        max_length = max(1, min(max_length, MAX_TOKENS))
        print(f"Max length set to: {max_length}")

        print("Generating response from the model...")
        res = self.model(fmat, temperature=args.get("temperature", 1.0), top_p=args.get(
            "top_p", 0.9), top_k=args.get("top_k", 40), max_tokens=max_length)
        print(f"Model response: {res}")

        return [{
            "status": "success",
            "response": res
        }]