import gradio as gr
import torch
import importlib.util
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import os
# Download and import model components from HF Hub
model_repo = "elapt1c/hrom-testing"

# 1. Import trainer module components
trainer_file = hf_hub_download(repo_id=model_repo, filename="HROM_Trainer.py")
spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
trainer_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(trainer_module)
HROM = trainer_module.HROM
CONFIG = trainer_module.CONFIG
SafetyManager = trainer_module.SafetyManager
# 2. Load tokenizer
tokenizer_file = hf_hub_download(repo_id=model_repo, filename="hrom_tokenizer.json")
tokenizer = Tokenizer.from_file(tokenizer_file)

# 3. Load model checkpoint
checkpoint_file = hf_hub_download(repo_id=model_repo, filename="hrom.pt")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model():
    model = HROM().to(device)
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model

model = load_model()
safety = SafetyManager(model, tokenizer)
max_response_length = 200
def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
    """Greedily generate up to max_length tokens, stopping at </s> or on a safety-filter hit."""
    device = next(model.parameters()).device
    generated_ids = input_ids.copy()
    for _ in range(max_length):
        input_tensor = torch.tensor([generated_ids], device=device)
        with torch.no_grad():
            logits = model(input_tensor)
        # Greedy decoding: pick the highest-scoring token at the last position
        next_token = logits.argmax(-1)[:, -1].item()
        if next_token == tokenizer.token_to_id("</s>"):
            break
        # Stop early if the safety filter rejects the text generated so far
        current_text = tokenizer.decode(generated_ids + [next_token])
        if not safety_manager.content_filter(current_text):
            break
        generated_ids.append(next_token)
    # Return only the newly generated tokens
    return generated_ids[len(input_ids):]
def process_message(user_input, chat_history, token_history):
    # Process user input
    user_turn = f"<user> {user_input} </s>"
    user_tokens = tokenizer.encode(user_turn).ids
    token_history.extend(user_tokens)

    # Prepare input sequence
    input_sequence = [tokenizer.token_to_id("<s>")] + token_history

    # Truncate if needed
    max_input_len = CONFIG["max_seq_len"] - max_response_length
    if len(input_sequence) > max_input_len:
        input_sequence = input_sequence[-max_input_len:]
        token_history = input_sequence[1:]

    # Generate response
    response_ids = generate_response(model, tokenizer, input_sequence, safety, max_response_length)

    # Process assistant response
    assistant_text = "I couldn't generate a proper response."
    if response_ids:
        if response_ids[0] == tokenizer.token_to_id("<assistant>"):
            try:
                end_idx = response_ids.index(tokenizer.token_to_id("</s>"))
                assistant_text = tokenizer.decode(response_ids[1:end_idx])
                token_history.extend(response_ids[:end_idx + 1])
            except ValueError:
                assistant_text = tokenizer.decode(response_ids[1:])
                token_history.extend(response_ids)
        else:
            assistant_text = tokenizer.decode(response_ids)
            token_history.extend(response_ids)

    chat_history.append((user_input, assistant_text))
    return chat_history, token_history
def clear_history():
    return [], []
with gr.Blocks() as demo:
    gr.Markdown("# HROM-V1 Chatbot")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Message")
    token_state = gr.State([])

    msg.submit(
        process_message,
        [msg, chatbot, token_state],
        [chatbot, token_state],
        queue=False
    ).then(
        lambda: "", None, msg
    )

    clear_btn = gr.Button("Clear Chat History")
    clear_btn.click(
        clear_history,
        outputs=[chatbot, token_state],
        queue=False
    )

demo.launch()