Files changed (1)
  1. app.py +13 -11
app.py CHANGED
@@ -74,21 +74,24 @@ class DecodeResponse(BaseModel):
 def embed(req: EmbedRequest):
     text = req.text
 
-    # Case 1: Query → pooled mean of multivectors
-    if not req.return_token_embeddings:
+    # -----------------------------
+    # Case 1: Query → mean pool across token embeddings
+    # -----------------------------
+    if (req.prompt_name or "").lower() == "query":
         with torch.no_grad():
             outputs = model.encode_text(
                 texts=[text],
                 task=req.task,
-                prompt_name=req.prompt_name or "query",
-                return_multivector=req.return_token_embeddings,
+                prompt_name="query",
+                return_multivector=True,  # always token-level
                 truncate_dim=req.truncate_dim,
             )
-        # outputs[0] = (num_vectors, hidden_dim)
-        pooled = outputs[0].mean(dim=0).cpu()
-        return {"embeddings": [pooled]}
+        pooled = outputs[0].mean(dim=0).cpu().tolist()
+        return {"embeddings": [pooled]}  # wrap in batch dimension
 
+    # -----------------------------
     # Case 2: Passage → sliding window, token-level embeddings
+    # -----------------------------
     enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
     input_ids = enc["input_ids"].squeeze(0).to(device)
     total_tokens = input_ids.size(0)
@@ -106,8 +109,8 @@ def embed(req: EmbedRequest):
         outputs = model.encode_text(
             texts=[tokenizer.decode(window_ids[0])],
             task=req.task,
-            prompt_name=req.prompt_name or "passage",
-            return_multivector=req.return_token_embeddings,
+            prompt_name="passage",
+            return_multivector=True,  # always token-level
             truncate_dim=req.truncate_dim,
         )
 
@@ -119,10 +122,9 @@ def embed(req: EmbedRequest):
         embeddings.append(window_embeds)
         position += max_len - stride
 
-    full_embeddings = torch.cat(embeddings, dim=0)
+    full_embeddings = torch.cat(embeddings, dim=0).tolist()
     return {"embeddings": full_embeddings}
 
-
 # -----------------------------
 # Embedding Endpoint (image)
 # -----------------------------
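The hunks show only the edges of the sliding-window loop: the window setup (old lines 95-105) and the per-window extraction of window_embeds (old lines 114-118) are elided. For reference, here is a minimal, hypothetical sketch of an overlapping-window walk built around the same `position += max_len - stride` update; `iter_windows` and the `max_len`/`stride` values are illustrative assumptions, not code from this PR.

# Hypothetical sketch of the elided window setup; max_len/stride values assumed.
import torch

def iter_windows(input_ids: torch.Tensor, max_len: int = 512, stride: int = 64):
    """Yield (1, <=max_len) windows; consecutive windows overlap by `stride` tokens."""
    total_tokens = input_ids.size(0)
    position = 0
    while position < total_tokens:
        # window covering tokens [position, position + max_len)
        window_ids = input_ids[position : position + max_len].unsqueeze(0)
        yield window_ids
        position += max_len - stride  # same update as in the hunk above

sizes = [w.size(1) for w in iter_windows(torch.arange(1200))]
print(sizes)  # [512, 512, 304]

Because consecutive windows share `stride` tokens, the elided extraction step presumably trims the overlap from `window_embeds` before `embeddings.append(...)`; otherwise `torch.cat(embeddings, dim=0)` would emit duplicate vectors for the overlapping tokens.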
 
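For completeness, a hedged sketch of how a client would exercise the two branches after this change. The /embed route, host, and the "retrieval" task value are assumptions; only the request fields visible in the hunks (text, task, prompt_name, truncate_dim) come from the diff. Both branches now end in `.tolist()`, presumably so the response body is plain JSON rather than unserializable tensors.

# Hypothetical client calls; URL and task name are assumptions, not from the PR.
import requests

BASE = "http://localhost:8000"  # assumed host/port

# Case 1: prompt_name "query" -> one mean-pooled vector, wrapped in a batch list
r = requests.post(f"{BASE}/embed", json={
    "text": "what is late interaction retrieval?",
    "task": "retrieval",
    "prompt_name": "query",
})
assert len(r.json()["embeddings"]) == 1

# Case 2: any other prompt_name falls through to the passage path -> one vector
# per token, concatenated across sliding windows
r = requests.post(f"{BASE}/embed", json={
    "text": "a long passage " * 500,
    "task": "retrieval",
    "prompt_name": "passage",
})
token_vectors = r.json()["embeddings"]  # one row per token

One consequence visible in the diff: branch selection now depends on prompt_name alone and return_multivector is hard-coded to True, so a return_token_embeddings field on EmbedRequest, if it still exists, no longer affects the output.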