Fix Gradio API to accept JSON directly for backend compatibility
Browse files
app.py
CHANGED
|
@@ -47,18 +47,11 @@ def predict_single_text(text):
|
|
| 47 |
embedding = get_embedding(text.strip())
|
| 48 |
return f"Embedding (first 10 values): {embedding[:10]}...\nFull embedding has {len(embedding)} dimensions."
|
| 49 |
|
| 50 |
-
def predict_api(
|
| 51 |
-
"""Handle API calls from backend - expects
|
| 52 |
try:
|
| 53 |
-
import json
|
| 54 |
-
data = json.loads(json_str)
|
| 55 |
-
|
| 56 |
-
if not data or 'data' not in data:
|
| 57 |
-
return json.dumps({'error': 'Missing data field'})
|
| 58 |
-
|
| 59 |
-
texts = data['data']
|
| 60 |
if not isinstance(texts, list):
|
| 61 |
-
return
|
| 62 |
|
| 63 |
# Generate embeddings for each text
|
| 64 |
embeddings = []
|
|
@@ -67,17 +60,17 @@ def predict_api(json_str):
|
|
| 67 |
embedding = get_embedding(text)
|
| 68 |
embeddings.append(embedding)
|
| 69 |
else:
|
| 70 |
-
return
|
| 71 |
|
| 72 |
-
return
|
| 73 |
except Exception as e:
|
| 74 |
-
return
|
| 75 |
|
| 76 |
# Create API interface (this will create /api/predict endpoint)
|
| 77 |
api_interface = gr.Interface(
|
| 78 |
fn=predict_api,
|
| 79 |
-
inputs=gr.
|
| 80 |
-
outputs=gr.
|
| 81 |
api_name="predict"
|
| 82 |
)
|
| 83 |
|
|
|
|
| 47 |
embedding = get_embedding(text.strip())
|
| 48 |
return f"Embedding (first 10 values): {embedding[:10]}...\nFull embedding has {len(embedding)} dimensions."
|
| 49 |
|
| 50 |
+
def predict_api(texts):
|
| 51 |
+
"""Handle API calls from backend - expects list of texts directly"""
|
| 52 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
if not isinstance(texts, list):
|
| 54 |
+
return {'error': 'Input must be a list of texts'}
|
| 55 |
|
| 56 |
# Generate embeddings for each text
|
| 57 |
embeddings = []
|
|
|
|
| 60 |
embedding = get_embedding(text)
|
| 61 |
embeddings.append(embedding)
|
| 62 |
else:
|
| 63 |
+
return {'error': 'All items must be strings'}
|
| 64 |
|
| 65 |
+
return {'data': embeddings}
|
| 66 |
except Exception as e:
|
| 67 |
+
return {'error': str(e)}
|
| 68 |
|
| 69 |
# Create API interface (this will create /api/predict endpoint)
|
| 70 |
api_interface = gr.Interface(
|
| 71 |
fn=predict_api,
|
| 72 |
+
inputs=gr.JSON(), # Expects JSON input directly
|
| 73 |
+
outputs=gr.JSON(), # Returns JSON output directly
|
| 74 |
api_name="predict"
|
| 75 |
)
|
| 76 |
|