Arastun commited on
Commit
f01b23a
·
verified ·
1 Parent(s): f3b18f8

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +69 -0
handler.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ import torch
3
+
4
+ class EndpointHandler:
5
+ def __init__(self, path=""):
6
+ self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side='left')
7
+ self.model = AutoModelForCausalLM.from_pretrained(
8
+ path,
9
+ torch_dtype=torch.float16,
10
+ device_map="auto"
11
+ )
12
+ self.model.eval()
13
+
14
+ # Qwen3-Reranker uses a specific token to extract the relevance score
15
+ self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
16
+ self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
17
+
18
+ def format_input(self, query, document, instruction=None):
19
+ if instruction is None:
20
+ instruction = "Given a web search query, retrieve relevant passages that answer the query"
21
+ prefix = f"<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct. Note only output a single token from [yes, no] after thinking.\n<|im_end|>\n<|im_start|>user\n<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}\n<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n"
22
+ return prefix
23
+
24
+ def __call__(self, data: dict) -> dict:
25
+ """
26
+ Expected input format:
27
+ {
28
+ "query": "what is the capital of France",
29
+ "documents": ["Paris is the capital...", "London is the capital..."],
30
+ "instruction": "optional custom instruction"
31
+ }
32
+ """
33
+ inputs = data.get("inputs", data) # unwrap HF gateway nesting
34
+ query = inputs.get("query", "")
35
+ documents = inputs.get("documents", [])
36
+ instruction = inputs.get("instruction", None)
37
+
38
+ if not query or not documents:
39
+ return {"error": "Must provide 'query' and 'documents'"}
40
+
41
+ prompts = [self.format_input(query, doc, instruction) for doc in documents]
42
+
43
+ inputs = self.tokenizer(
44
+ prompts,
45
+ return_tensors="pt",
46
+ padding=True,
47
+ truncation=True,
48
+ max_length=4096
49
+ ).to(self.model.device)
50
+
51
+ with torch.no_grad():
52
+ outputs = self.model(**inputs)
53
+ # Get logits for the final token position
54
+ logits = outputs.logits[:, -1, :]
55
+ # Score is the softmax probability of "yes" vs "no"
56
+ true_logits = logits[:, self.token_true_id]
57
+ false_logits = logits[:, self.token_false_id]
58
+ scores = torch.softmax(
59
+ torch.stack([false_logits, true_logits], dim=1), dim=1
60
+ )[:, 1].tolist()
61
+
62
+ return {
63
+ "scores": scores,
64
+ "ranking": sorted(
65
+ range(len(documents)),
66
+ key=lambda i: scores[i],
67
+ reverse=True
68
+ )
69
+ }