Spaces:
Sleeping
Sleeping
Update cohereAPI.py
Browse files- cohereAPI.py +32 -14
cohereAPI.py
CHANGED
|
@@ -11,7 +11,7 @@ def get_client(api_key):
|
|
| 11 |
_client = cohere.ClientV2(api_key)
|
| 12 |
return _client
|
| 13 |
|
| 14 |
-
def send_message_stream(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025"):
|
| 15 |
"""Stream response from Cohere API"""
|
| 16 |
# Get or create the Cohere client
|
| 17 |
co = get_client(api_key)
|
|
@@ -21,11 +21,19 @@ def send_message_stream(system_message, user_message, conversation_history, api_
|
|
| 21 |
messages.extend(conversation_history)
|
| 22 |
messages.append({"role": "user", "content": user_message})
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
# Send streaming request to Cohere
|
| 25 |
-
stream = co.chat_stream(
|
| 26 |
-
model=model_name,
|
| 27 |
-
messages=messages
|
| 28 |
-
)
|
| 29 |
|
| 30 |
# Collect full response for history
|
| 31 |
full_response = ""
|
|
@@ -37,10 +45,10 @@ def send_message_stream(system_message, user_message, conversation_history, api_
|
|
| 37 |
full_response += text_chunk
|
| 38 |
yield text_chunk
|
| 39 |
|
| 40 |
-
async def send_message_stream_async(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025"):
|
| 41 |
"""Async wrapper for streaming response from Cohere API"""
|
| 42 |
def _sync_stream():
|
| 43 |
-
return send_message_stream(system_message, user_message, conversation_history, api_key, model_name)
|
| 44 |
|
| 45 |
# Run the synchronous generator in a thread
|
| 46 |
loop = asyncio.get_event_loop()
|
|
@@ -72,7 +80,7 @@ async def send_message_stream_async(system_message, user_message, conversation_h
|
|
| 72 |
yield chunk
|
| 73 |
|
| 74 |
|
| 75 |
-
def send_message(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025"):
|
| 76 |
"""Non-streaming version for backward compatibility"""
|
| 77 |
# Get or create the Cohere client
|
| 78 |
co = get_client(api_key)
|
|
@@ -82,18 +90,26 @@ def send_message(system_message, user_message, conversation_history, api_key, mo
|
|
| 82 |
messages.extend(conversation_history)
|
| 83 |
messages.append({"role": "user", "content": user_message})
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
# Send request to Cohere synchronously
|
| 86 |
-
response = co.chat(
|
| 87 |
-
model=model_name,
|
| 88 |
-
messages=messages
|
| 89 |
-
)
|
| 90 |
|
| 91 |
# Get the response
|
| 92 |
response_content = response.message.content[0].text
|
| 93 |
|
| 94 |
return response_content
|
| 95 |
|
| 96 |
-
async def send_message_async(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025"):
|
| 97 |
"""Async version using asyncio.to_thread"""
|
| 98 |
return await asyncio.to_thread(
|
| 99 |
send_message,
|
|
@@ -101,5 +117,7 @@ async def send_message_async(system_message, user_message, conversation_history,
|
|
| 101 |
user_message,
|
| 102 |
conversation_history,
|
| 103 |
api_key,
|
| 104 |
-
model_name
|
|
|
|
|
|
|
| 105 |
)
|
|
|
|
| 11 |
_client = cohere.ClientV2(api_key)
|
| 12 |
return _client
|
| 13 |
|
| 14 |
+
def send_message_stream(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025", temperature=0.7, max_tokens=None):
|
| 15 |
"""Stream response from Cohere API"""
|
| 16 |
# Get or create the Cohere client
|
| 17 |
co = get_client(api_key)
|
|
|
|
| 21 |
messages.extend(conversation_history)
|
| 22 |
messages.append({"role": "user", "content": user_message})
|
| 23 |
|
| 24 |
+
# Prepare chat parameters
|
| 25 |
+
chat_params = {
|
| 26 |
+
"model": model_name,
|
| 27 |
+
"messages": messages,
|
| 28 |
+
"temperature": temperature
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
# Add max_tokens if specified
|
| 32 |
+
if max_tokens:
|
| 33 |
+
chat_params["max_tokens"] = int(max_tokens)
|
| 34 |
+
|
| 35 |
# Send streaming request to Cohere
|
| 36 |
+
stream = co.chat_stream(**chat_params)
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
# Collect full response for history
|
| 39 |
full_response = ""
|
|
|
|
| 45 |
full_response += text_chunk
|
| 46 |
yield text_chunk
|
| 47 |
|
| 48 |
+
async def send_message_stream_async(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025", temperature=0.7, max_tokens=None):
|
| 49 |
"""Async wrapper for streaming response from Cohere API"""
|
| 50 |
def _sync_stream():
|
| 51 |
+
return send_message_stream(system_message, user_message, conversation_history, api_key, model_name, temperature, max_tokens)
|
| 52 |
|
| 53 |
# Run the synchronous generator in a thread
|
| 54 |
loop = asyncio.get_event_loop()
|
|
|
|
| 80 |
yield chunk
|
| 81 |
|
| 82 |
|
| 83 |
+
def send_message(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025", temperature=0.7, max_tokens=None):
|
| 84 |
"""Non-streaming version for backward compatibility"""
|
| 85 |
# Get or create the Cohere client
|
| 86 |
co = get_client(api_key)
|
|
|
|
| 90 |
messages.extend(conversation_history)
|
| 91 |
messages.append({"role": "user", "content": user_message})
|
| 92 |
|
| 93 |
+
# Prepare chat parameters
|
| 94 |
+
chat_params = {
|
| 95 |
+
"model": model_name,
|
| 96 |
+
"messages": messages,
|
| 97 |
+
"temperature": temperature
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
# Add max_tokens if specified
|
| 101 |
+
if max_tokens:
|
| 102 |
+
chat_params["max_tokens"] = int(max_tokens)
|
| 103 |
+
|
| 104 |
# Send request to Cohere synchronously
|
| 105 |
+
response = co.chat(**chat_params)
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
# Get the response
|
| 108 |
response_content = response.message.content[0].text
|
| 109 |
|
| 110 |
return response_content
|
| 111 |
|
| 112 |
+
async def send_message_async(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025", temperature=0.7, max_tokens=None):
|
| 113 |
"""Async version using asyncio.to_thread"""
|
| 114 |
return await asyncio.to_thread(
|
| 115 |
send_message,
|
|
|
|
| 117 |
user_message,
|
| 118 |
conversation_history,
|
| 119 |
api_key,
|
| 120 |
+
model_name,
|
| 121 |
+
temperature,
|
| 122 |
+
max_tokens
|
| 123 |
)
|