import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.optim import AdamW  # AdamW was removed from transformers; use the torch optimizer
import torch.nn.functional as F
import json
import os

# --- Config ---
MODEL_NAME = "bert-base-multilingual-cased"
TRAINING_FILE = "./training_data/greetings.txt"
SAVE_PATH = "./trained_bert_model"
EPOCHS = 3
BATCH_SIZE = 8
MAX_LEN = 64
LEARNING_RATE = 2e-5

# --- Load training data ---
def load_training_data(file_path):
    inputs, responses = [], []
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"{file_path} not found!")

    with open(file_path, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in f if line.strip()]
    
    # Lines alternate "User: ..." / "Assistant: ...". Stop before a trailing
    # unpaired line so lines[i + 1] cannot raise an IndexError.
    for i in range(0, len(lines) - 1, 2):
        user_input = lines[i].replace("User:", "").strip()
        assistant_response = lines[i + 1].replace("Assistant:", "").strip()
        inputs.append(user_input)
        responses.append(assistant_response)
    
    return inputs, responses
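
# Expected layout of greetings.txt (assumed from the parsing above:
# alternating "User:"/"Assistant:" lines, blank lines skipped; the
# Kiswahili greetings below are illustrative, not from the real file):
#
#   User: Habari yako?
#   Assistant: Nzuri sana, asante!
#   User: Hujambo!
#   Assistant: Sijambo, karibu.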

# --- Dataset ---
class KiswahiliDataset(Dataset):
    def __init__(self, inputs, responses, tokenizer):
        self.inputs = inputs
        self.responses = responses
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        # Join the pair with a literal "[SEP]" marker; BERT tokenizers treat
        # it as the special separator token when encoding the raw string.
        text = f"{self.inputs[idx]} [SEP] {self.responses[idx]}"
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=MAX_LEN,
            return_tensors='pt'
        )
        # Label 1 = positive example: every pair in the file is a genuine
        # (input, response) match, so the classifier is only ever trained on
        # label 1; no negative (mismatched) pairs are constructed here.
        label = torch.tensor(1)
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'labels': label
        }
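
# Quick shape check (a minimal sketch; the sample pair is made up):
#   tok = AutoTokenizer.from_pretrained(MODEL_NAME)
#   item = KiswahiliDataset(["Habari?"], ["Nzuri, asante!"], tok)[0]
#   item['input_ids'].shape  -> torch.Size([64])  (== MAX_LEN, padding='max_length')
#   item['labels']           -> tensor(1)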

# --- Main training ---
def main():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    inputs, responses = load_training_data(TRAINING_FILE)

    dataset = KiswahiliDataset(inputs, responses, tokenizer)
    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    model.train()
    for epoch in range(EPOCHS):
        total_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss/len(train_loader):.4f}")

    # Save model
    os.makedirs(SAVE_PATH, exist_ok=True)
    model.save_pretrained(SAVE_PATH)
    tokenizer.save_pretrained(SAVE_PATH)

    # Save responses for chatbot
    with open(os.path.join(SAVE_PATH, "responses.json"), "w", encoding="utf-8") as f:
        json.dump({"responses": responses}, f, ensure_ascii=False, indent=4)

    print(f"✅ Training complete. Model saved to {SAVE_PATH}")

if __name__ == "__main__":
    main()