Denis202 committed on
Commit 812387d · verified · 1 Parent(s): fe24b96

Update train_model.py

Files changed (1)
  1. train_model.py +96 -194
train_model.py CHANGED
@@ -1,203 +1,105 @@
  import torch
- from transformers import (
-     AutoModelForSequenceClassification,
-     AutoTokenizer,
-     TrainingArguments,
-     Trainer,
-     DataCollatorWithPadding
- )
- from datasets import Dataset
- import os
- import glob
  import json
- from typing import List, Dict
- import numpy as np
-
- class KiswahiliBERTTrainer:
-     def __init__(self, model_name="bert-base-multilingual-cased"):
-         """
-         BERT-based chatbot trainer for response selection
-         """
-         self.model_name = model_name
-         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-         # Initialize model
-         self.model = AutoModelForSequenceClassification.from_pretrained(
-             model_name,
-             num_labels=2  # Binary classification: good response vs bad response
-         )
-
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-         self.model.to(self.device)
-         self.responses = []  # Store predefined responses
-         print("✅ BERT model and tokenizer loaded successfully!")
-
-     def load_training_data(self, data_dir="./training_data"):
-         """Load training data and create response selection pairs"""
-         conversation_pairs = []
-         text_files = glob.glob(os.path.join(data_dir, "*.txt"))
-
-         if not text_files:
-             print(f"⚠️ No .txt files found in {data_dir}")
-             return []
-
-         for file_path in text_files:
-             print(f"📖 Loading data from: {file_path}")
-             try:
-                 with open(file_path, 'r', encoding='utf-8') as f:
-                     content = f.read().strip()
-
-                 # Split by conversation blocks
-                 blocks = [b.strip() for b in content.split('\n\n') if b.strip()]
-
-                 for block in blocks:
-                     lines = block.split('\n')
-                     user_input = None
-                     assistant_response = None
-
-                     for line in lines:
-                         if line.startswith('User:'):
-                             user_input = line.replace('User:', '').strip()
-                         elif line.startswith('Assistant:'):
-                             assistant_response = line.replace('Assistant:', '').strip()
-
-                     if user_input and assistant_response:
-                         # Store as positive example
-                         conversation_pairs.append({
-                             'user_input': user_input,
-                             'response': assistant_response,
-                             'label': 1  # Positive example
-                         })
-                         # Also store the response for later use
-                         if assistant_response not in self.responses:
-                             self.responses.append(assistant_response)
-
-             except Exception as e:
-                 print(f"❌ Error while reading {file_path}: {e}")
-
-         print(f"📊 Loaded {len(conversation_pairs)} conversation pairs")
-         print(f"📝 Found {len(self.responses)} unique responses")
-         return conversation_pairs

-     def create_training_pairs(self, conversation_pairs):
-         """Create training data with positive and negative examples"""
-         training_examples = []
-
-         for pair in conversation_pairs:
-             # Positive example
-             training_examples.append({
-                 'text': f"{pair['user_input']} [SEP] {pair['response']}",
-                 'label': 1
-             })
-
-             # Create negative examples (random wrong responses)
-             for _ in range(2):  # 2 negative examples per positive
-                 if len(self.responses) > 1:
-                     wrong_responses = [r for r in self.responses if r != pair['response']]
-                     if wrong_responses:
-                         wrong_response = np.random.choice(wrong_responses)
-                         training_examples.append({
-                             'text': f"{pair['user_input']} [SEP] {wrong_response}",
-                             'label': 0
-                         })
-
-         return training_examples

-     def prepare_dataset(self, training_examples):
-         """Tokenize the dataset for BERT"""
-         texts = [example['text'] for example in training_examples]
-         labels = [example['label'] for example in training_examples]
-
-         # Tokenize all texts
-         encodings = self.tokenizer(
-             texts,
              truncation=True,
-             padding=True,
-             max_length=256,
-             return_tensors="pt"
-         )
-
-         # Create dataset
-         class CustomDataset(torch.utils.data.Dataset):
-             def __init__(self, encodings, labels):
-                 self.encodings = encodings
-                 self.labels = labels
-
-             def __getitem__(self, idx):
-                 item = {key: val[idx] for key, val in self.encodings.items()}
-                 item['labels'] = torch.tensor(self.labels[idx])
-                 return item
-
-             def __len__(self):
-                 return len(self.labels)
-
-         return CustomDataset(encodings, labels)
-
-     def train(self, data_dir="./training_data", output_dir="./trained_bert_model"):
-         """Train BERT for response selection"""
-         conversation_pairs = self.load_training_data(data_dir=data_dir)
-
-         if not conversation_pairs:
-             print("❌ No training data! Please put .txt files in training_data/")
-             return
-
-         # Create training examples
-         training_examples = self.create_training_pairs(conversation_pairs)
-         dataset = self.prepare_dataset(training_examples)
-
-         # BERT-specific training arguments
-         training_args = TrainingArguments(
-             output_dir=output_dir,
-             num_train_epochs=5,
-             per_device_train_batch_size=8,
-             per_device_eval_batch_size=8,
-             warmup_steps=100,
-             weight_decay=0.01,
-             logging_dir='./logs',
-             logging_steps=10,
-             evaluation_strategy="no",
-             save_strategy="epoch",
-             load_best_model_at_end=False,
-             fp16=torch.cuda.is_available(),
-         )
-
-         data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)
-
-         trainer = Trainer(
-             model=self.model,
-             args=training_args,
-             train_dataset=dataset,
-             tokenizer=self.tokenizer,
-             data_collator=data_collator,
          )
-
-         print("🚀 BERT training is starting...")
-         trainer.train()
-
-         print("💾 Saving model...")
-         trainer.save_model(output_dir)
-         self.tokenizer.save_pretrained(output_dir)
-
-         # Save the response bank
-         response_data = {
-             'responses': self.responses,
-             'model_type': 'bert-response-selector'
          }
-         with open(os.path.join(output_dir, 'responses.json'), 'w', encoding='utf-8') as f:
-             json.dump(response_data, f, ensure_ascii=False, indent=2)
-
-         print(f"✅ BERT training complete! Model saved to {output_dir}")
-         print(f"📋 Response bank saved with {len(self.responses)} responses")
-
- # Alternative BERT models for Kiswahili
- BERT_MODEL_OPTIONS = {
-     "multilingual": "bert-base-multilingual-cased",
-     "afriberta": "castorini/afriberta_base",
-     "xlm-roberta": "xlm-roberta-base",
-     "swahili-bert": "ds4sd/ixbert-swahili-base",  # Swahili-specific BERT
- }

  if __name__ == "__main__":
-     # Example usage with a Swahili-optimized BERT model
-     trainer = KiswahiliBERTTrainer(model_name=BERT_MODEL_OPTIONS["multilingual"])
-     trainer.train()
 
  import torch
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ from torch.optim import AdamW  # transformers' AdamW is deprecated/removed in recent releases
+ import torch.nn.functional as F
  import json
+ import os
+ from sklearn.model_selection import train_test_split
+
+ # --- Config ---
+ MODEL_NAME = "bert-base-multilingual-cased"
+ TRAINING_FILE = "./training_data/greetings.txt"
+ SAVE_PATH = "./trained_bert_model"
+ EPOCHS = 3
+ BATCH_SIZE = 8
+ MAX_LEN = 64
+ LEARNING_RATE = 2e-5
+
+ # --- Load training data ---
+ def load_training_data(file_path):
+     inputs, responses = [], []
+     if not os.path.exists(file_path):
+         raise FileNotFoundError(f"{file_path} not found!")
+
+     with open(file_path, "r", encoding="utf-8") as f:
+         lines = [line.strip() for line in f if line.strip()]
+
+     # Pair each "User:" line with the "Assistant:" line that follows it;
+     # the len(lines) - 1 bound skips a trailing unpaired line instead of crashing
+     for i in range(0, len(lines) - 1, 2):
+         user_input = lines[i].replace("User:", "").strip()
+         assistant_response = lines[i + 1].replace("Assistant:", "").strip()
+         inputs.append(user_input)
+         responses.append(assistant_response)
+
+     return inputs, responses
+
+ # --- Dataset ---
+ class KiswahiliDataset(Dataset):
+     def __init__(self, inputs, responses, tokenizer):
+         self.inputs = inputs
+         self.responses = responses
+         self.tokenizer = tokenizer
+
+     def __len__(self):
+         return len(self.inputs)
+
+     def __getitem__(self, idx):
+         text = f"{self.inputs[idx]} [SEP] {self.responses[idx]}"
+         encoding = self.tokenizer(
+             text,
              truncation=True,
+             padding='max_length',
+             max_length=MAX_LEN,
+             return_tensors='pt'
          )
+         # Label 1 = positive example
+         label = torch.tensor(1)
+         return {
+             'input_ids': encoding['input_ids'].squeeze(),
+             'attention_mask': encoding['attention_mask'].squeeze(),
+             'labels': label
          }
+
+ # --- Main training ---
+ def main():
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+     inputs, responses = load_training_data(TRAINING_FILE)
+
+     dataset = KiswahiliDataset(inputs, responses, tokenizer)
+     train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
+
+     model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model.to(device)
+
+     optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
+
+     model.train()
+     for epoch in range(EPOCHS):
+         total_loss = 0
+         for batch in train_loader:
+             optimizer.zero_grad()
+             input_ids = batch['input_ids'].to(device)
+             attention_mask = batch['attention_mask'].to(device)
+             labels = batch['labels'].to(device)
+
+             outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+             loss = outputs.loss
+             total_loss += loss.item()
+             loss.backward()
+             optimizer.step()
+         print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss/len(train_loader):.4f}")
+
+     # Save model
+     if not os.path.exists(SAVE_PATH):
+         os.makedirs(SAVE_PATH)
+     model.save_pretrained(SAVE_PATH)
+     tokenizer.save_pretrained(SAVE_PATH)
+
+     # Save responses for chatbot
+     with open(os.path.join(SAVE_PATH, "responses.json"), "w", encoding="utf-8") as f:
+         json.dump({"responses": responses}, f, ensure_ascii=False, indent=4)
+
+     print(f"✅ Training complete. Model saved to {SAVE_PATH}")

  if __name__ == "__main__":
+     main()
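
Data format note: the rewrite trades the old blank-line-separated conversation blocks for a single flat file. load_training_data drops blank lines and then pairs each remaining line with the next, so TRAINING_FILE is expected to alternate User: and Assistant: lines. A minimal greetings.txt in that shape might look like the following (the utterances are illustrative, not taken from the repository):

    User: Habari yako?
    Assistant: Nzuri sana, asante! Karibu.

    User: Hujambo
    Assistant: Sijambo! Naweza kukusaidiaje?

Blank lines between pairs are harmless, since the loader filters them out before pairing.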
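
Inference is out of scope for this commit, but the two saved artifacts (the fine-tuned classifier in SAVE_PATH and responses.json) suggest the selection step: tokenize "query [SEP] candidate" for every stored response, score each pair with the classifier, and return the candidate with the highest class-1 probability. The sketch below is a hypothetical companion script built on that assumption; select_response and its scoring scheme are not part of the repository. Note that the new script trains on label-1 pairs only, so a discriminative ranking like this implicitly assumes negative pairs (such as those the removed create_training_pairs generated) are reintroduced at training time.

    import json
    import os

    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    MODEL_DIR = "./trained_bert_model"  # SAVE_PATH from the training script

    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    model.eval()

    with open(os.path.join(MODEL_DIR, "responses.json"), encoding="utf-8") as f:
        responses = json.load(f)["responses"]

    def select_response(user_input):
        # Score every "query [SEP] candidate" pair in one batch;
        # class 1 was the "good response" label during training.
        pairs = [f"{user_input} [SEP] {r}" for r in responses]
        enc = tokenizer(pairs, truncation=True, padding=True,
                        max_length=64, return_tensors="pt")
        with torch.no_grad():
            logits = model(**enc).logits
        scores = torch.softmax(logits, dim=-1)[:, 1]  # P(label == 1) per candidate
        return responses[int(scores.argmax())]

    print(select_response("Habari yako?"))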