import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import logging
import json
import os
import numpy as np
import re

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class KiswahiliChatbot:
    """Retrieval chatbot that picks the best stored reply using a fine-tuned classifier.

    For every candidate in the response bank, the cleaned user input and the
    candidate are joined with a literal ``[SEP]`` marker and scored by a
    sequence-classification model. The softmax probability of class index 1
    is used as the match score. The highest-scoring candidate is returned if
    it reaches ``threshold``; otherwise a fallback apology is returned.
    """

    # Removes every character that is not a word character, whitespace, or '?'.
    # Compiled once because _clean_text runs on every chat turn.
    _STRIP_PUNCT_RE = re.compile(r'[^\w\s?]')

    _FALLBACK_REPLY = "Samahani, sielewi. Unaweza kuuliza kwa njia nyingine?"

    def __init__(self, model_path="./trained_bert_model", device=None, threshold=0.6):
        """Load the tokenizer, model and response bank from ``model_path``.

        Args:
            model_path: Directory holding the saved tokenizer/model and,
                optionally, a ``responses.json`` response bank.
            device: Torch device string; defaults to "cuda" when available,
                otherwise "cpu".
            threshold: Minimum match probability required to return a stored
                response instead of the fallback message.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Using device: %s", self.device)

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"{model_path} not found. Please train the model first.")
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        logger.info("✅ Trained model loaded!")

        self.model.to(self.device)
        self.model.eval()  # inference mode (e.g. disables dropout)

        self.threshold = threshold  # minimum probability to accept a response

        self.responses = self._load_response_bank()
        logger.info("📋 Loaded %d responses", len(self.responses))

    def _load_response_bank(self, response_file=None):
        """Return the list stored under the "responses" key of the response-bank JSON.

        Args:
            response_file: Path to the JSON file. Defaults to
                ``responses.json`` inside the directory the model was loaded
                from (``self.model_path``).

        Returns:
            The ``"responses"`` value, or ``[]`` when the file is missing
            (a warning is logged) or has no such key.
        """
        if response_file is None:
            response_file = os.path.join(self.model_path, "responses.json")
        if not os.path.exists(response_file):
            # Not fatal: the bot still runs, but every reply will be the fallback.
            logger.warning("Response bank %s not found; no responses loaded.", response_file)
            return []
        with open(response_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return data.get("responses", [])

    def _clean_text(self, text: str) -> str:
        """Drop punctuation other than '?', collapse whitespace, and lowercase."""
        text = self._STRIP_PUNCT_RE.sub('', text)
        return ' '.join(text.split()).lower()

    def _score(self, user_input_clean: str, response: str) -> float:
        """Return the model's class-1 probability for one (input, response) pair.

        The original code comments treat class 1 as "this is the correct
        response"; that label meaning depends on how the model was trained.
        """
        combined_text = f"{user_input_clean} [SEP] {response}"
        inputs = self.tokenizer(
            combined_text,
            return_tensors="pt",
            truncation=True,
            max_length=256,
            padding=True,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.softmax(outputs.logits, dim=1)
        return probs[0][1].item()

    @staticmethod
    def _format_response(text: str) -> str:
        """Strip *text*, ensure terminal punctuation, and capitalize its first character.

        Returns "" for blank input instead of raising IndexError.
        """
        text = text.strip()
        if not text:
            return ""
        if not text.endswith(('.', '!', '?')):
            text += '.'
        return text[0].upper() + text[1:]

    def chat(self, user_input: str) -> str:
        """Return the best-matching stored response for ``user_input``.

        Returns:
            - "Tafadhali andika ujumbe." when the cleaned input is empty
              (including ``None`` input);
            - the fallback apology when no response scores at least
              ``self.threshold``, when the response bank is empty, or when
              the winning response is blank;
            - otherwise the winning response, stripped, punctuated and
              capitalized.
        """
        user_input_clean = self._clean_text(user_input or "")
        if not user_input_clean:
            return "Tafadhali andika ujumbe."

        best_response = None
        best_score = 0.0
        for response in self.responses:
            score = self._score(user_input_clean, response)
            if score > best_score:  # strict '>': on ties the earlier response wins
                best_score = score
                best_response = response

        # best_response is None when the bank is empty (or nothing scored > 0);
        # guarding it avoids None.strip() when threshold <= 0.
        if best_response is None or best_score < self.threshold:
            return self._FALLBACK_REPLY
        return self._format_response(best_response) or self._FALLBACK_REPLY