Update app.py
#5
by
Amlan99
- opened
app.py
CHANGED
@@ -74,21 +74,24 @@ class DecodeResponse(BaseModel):
|
|
74 |
def embed(req: EmbedRequest):
|
75 |
text = req.text
|
76 |
|
77 |
-
#
|
78 |
-
|
|
|
|
|
79 |
with torch.no_grad():
|
80 |
outputs = model.encode_text(
|
81 |
texts=[text],
|
82 |
task=req.task,
|
83 |
-
prompt_name=
|
84 |
-
return_multivector=
|
85 |
truncate_dim=req.truncate_dim,
|
86 |
)
|
87 |
-
|
88 |
-
|
89 |
-
return {"embeddings": [pooled]}
|
90 |
|
|
|
91 |
# Case 2: Passage → sliding window, token-level embeddings
|
|
|
92 |
enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
|
93 |
input_ids = enc["input_ids"].squeeze(0).to(device)
|
94 |
total_tokens = input_ids.size(0)
|
@@ -106,8 +109,8 @@ def embed(req: EmbedRequest):
|
|
106 |
outputs = model.encode_text(
|
107 |
texts=[tokenizer.decode(window_ids[0])],
|
108 |
task=req.task,
|
109 |
-
prompt_name=
|
110 |
-
return_multivector=
|
111 |
truncate_dim=req.truncate_dim,
|
112 |
)
|
113 |
|
@@ -119,10 +122,9 @@ def embed(req: EmbedRequest):
|
|
119 |
embeddings.append(window_embeds)
|
120 |
position += max_len - stride
|
121 |
|
122 |
-
full_embeddings = torch.cat(embeddings, dim=0)
|
123 |
return {"embeddings": full_embeddings}
|
124 |
|
125 |
-
|
126 |
# -----------------------------
|
127 |
# Embedding Endpoint (image)
|
128 |
# -----------------------------
|
|
|
74 |
def embed(req: EmbedRequest):
|
75 |
text = req.text
|
76 |
|
77 |
+
# -----------------------------
|
78 |
+
# Case 1: Query → mean pool across token embeddings
|
79 |
+
# -----------------------------
|
80 |
+
if (req.prompt_name or "").lower() == "query":
|
81 |
with torch.no_grad():
|
82 |
outputs = model.encode_text(
|
83 |
texts=[text],
|
84 |
task=req.task,
|
85 |
+
prompt_name="query",
|
86 |
+
return_multivector=True, # always token-level
|
87 |
truncate_dim=req.truncate_dim,
|
88 |
)
|
89 |
+
pooled = outputs[0].mean(dim=0).cpu().tolist()
|
90 |
+
return {"embeddings": [pooled]} # wrap in batch dimension
|
|
|
91 |
|
92 |
+
# -----------------------------
|
93 |
# Case 2: Passage → sliding window, token-level embeddings
|
94 |
+
# -----------------------------
|
95 |
enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
|
96 |
input_ids = enc["input_ids"].squeeze(0).to(device)
|
97 |
total_tokens = input_ids.size(0)
|
|
|
109 |
outputs = model.encode_text(
|
110 |
texts=[tokenizer.decode(window_ids[0])],
|
111 |
task=req.task,
|
112 |
+
prompt_name="passage",
|
113 |
+
return_multivector=True, # always token-level
|
114 |
truncate_dim=req.truncate_dim,
|
115 |
)
|
116 |
|
|
|
122 |
embeddings.append(window_embeds)
|
123 |
position += max_len - stride
|
124 |
|
125 |
+
full_embeddings = torch.cat(embeddings, dim=0).tolist()
|
126 |
return {"embeddings": full_embeddings}
|
127 |
|
|
|
128 |
# -----------------------------
|
129 |
# Embedding Endpoint (image)
|
130 |
# -----------------------------
|