import gradio as gr
import os

import cohereAPI

# Model configurations
COHERE_MODELS = [
    "command-a-03-2025",
    "command-r7b-12-2024",
    "command-r-plus-08-2024",
    "command-r-08-2024",
    "command-light",
    "command-light-nightly",
    "command",
    "command-nightly"
]

COHERE_LABS_MODELS = [
    "command-a-translate-08-2025",
    "command-a-reasoning-08-2025"
]

def update_model_choices(provider):
    """Update model dropdown choices based on the selected provider."""
    if provider == "Cohere":
        return gr.Dropdown(choices=COHERE_MODELS, value=COHERE_MODELS[0])
    elif provider == "Cohere Labs":
        return gr.Dropdown(choices=COHERE_LABS_MODELS, value=COHERE_LABS_MODELS[0])
    else:
        return gr.Dropdown(choices=[], value=None)


def show_model_change_info(model_name):
    """Show an info notification when the model is changed."""
    if model_name:
        gr.Info(f"Continuing the conversation with {model_name}")
    return model_name

async def respond(message, history, model_name="command-a-03-2025", temperature=0.7, max_tokens=None):
    """Generate a streaming response using the Cohere API."""
    # Convert Gradio history format to the API message format
    conversation_history = []
    if history:
        for entry in history:
            if isinstance(entry, dict):
                # Messages format: keep only the role and content keys
                if "role" in entry and "content" in entry:
                    conversation_history.append({
                        "role": entry["role"],
                        "content": entry["content"]
                    })
            elif isinstance(entry, (list, tuple)) and len(entry) == 2:
                # Old tuple format: [user_msg, assistant_msg]
                user_msg, assistant_msg = entry
                if user_msg:
                    conversation_history.append({"role": "user", "content": str(user_msg)})
                if assistant_msg:
                    conversation_history.append({"role": "assistant", "content": str(assistant_msg)})
            else:
                # Skip entries in any other format
                continue

    # Get the API key from the environment
    api_key = os.getenv('COHERE_API_KEY')
    if not api_key:
        yield "Error: COHERE_API_KEY environment variable not set"
        return

    # System message for the chatbot
    system_message = """You are a helpful AI assistant. Provide concise but complete responses.
Be direct and to the point while ensuring you fully address the user's question or request.
Do not repeat the user's question in your response. Do not exceed 50 words."""

    try:
        # Accumulate streamed chunks and yield the running text so the UI updates live
        partial_message = ""
        async for chunk in cohereAPI.send_message_stream_async(
            system_message=system_message,
            user_message=message,
            conversation_history=conversation_history,
            api_key=api_key,
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens
        ):
            partial_message += chunk
            yield partial_message
    except Exception as e:
        yield f"Error: {str(e)}"

with gr.Blocks() as demo:
    gr.Markdown("""## Modular TTS-Chatbot
Status: In Development

The goal of this project is to enable voice chat with any supported LLM that does not yet have native speech capabilities, similar to Gemini or GPT-4o.
""")

    # State components that track the current settings
    temperature_state = gr.State(value=0.7)
    max_tokens_state = gr.State(value=8192)  # matches the Max Tokens textbox default below
    model_state = gr.State(value=COHERE_MODELS[0])

    with gr.Row():
        with gr.Column(scale=2):
            # Wrapper that feeds the current settings (held in State) into the streaming respond() generator
            async def chat_wrapper(message, history, model_val, temp_val, tokens_val):
                # Fall back to defaults when the state values are unset
                current_model = model_val if model_val else COHERE_MODELS[0]
                current_temp = temp_val if temp_val is not None else 0.7
                current_max_tokens = tokens_val

                # Stream the response
                async for chunk in respond(message, history, current_model, current_temp, current_max_tokens):
                    yield chunk

            # Chat interface wired to the wrapper, with the settings states as additional inputs
            chat_interface = gr.ChatInterface(
                fn=chat_wrapper,
                type="messages",
                save_history=True,
                additional_inputs=[model_state, temperature_state, max_tokens_state]
            )

    with gr.Accordion("Chat Settings", elem_id="chat_settings_group"):
        with gr.Row():
            with gr.Column(scale=3):
                provider = gr.Dropdown(
                    info="Provider",
                    choices=["Cohere", "Cohere Labs"],
                    value="Cohere",
                    elem_id="provider_dropdown",
                    interactive=True,
                    show_label=False
                )
                model = gr.Dropdown(
                    info="Model",
                    choices=COHERE_MODELS,
                    value=COHERE_MODELS[0],
                    elem_id="model_dropdown",
                    interactive=True,
                    show_label=False
                )

                # Update the model dropdown when the provider changes
                provider.change(
                    fn=update_model_choices,
                    inputs=[provider],
                    outputs=[model]
                )

                # Show an info notification when the model changes
                model.change(
                    fn=show_model_change_info,
                    inputs=[model],
                    outputs=[model]
                )

                # Keep model_state in sync with the dropdown
                model.change(
                    fn=lambda x: x,
                    inputs=[model],
                    outputs=[model_state]
                )

            with gr.Column(scale=1):
                temperature = gr.Slider(
                    label="Temperature",
                    info="Higher values make output more creative",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.01,
                    elem_id="temperature_slider",
                    interactive=True,
                )
                max_tokens = gr.Textbox(
                    label="Max Tokens",
                    info="Higher values allow longer responses. Leave empty for the default.",
                    value="8192",
                    elem_id="max_tokens_input",
                    interactive=True,
                    show_label=True,
                )

                # Keep temperature_state in sync with the slider
                temperature.change(
                    fn=lambda x: x,
                    inputs=[temperature],
                    outputs=[temperature_state]
                )

                # Keep max_tokens_state in sync with the textbox (None when empty or non-numeric)
                max_tokens.change(
                    fn=lambda x: int(x) if x and str(x).strip().isdigit() else None,
                    inputs=[max_tokens],
                    outputs=[max_tokens_state]
                )

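
# To run this demo, the COHERE_API_KEY environment variable must be set
# (on Hugging Face Spaces it can be added as a repository secret).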
if __name__ == "__main__":
    demo.launch()