import os
import time

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, RandomSampler
from torch.optim import AdamW
from transformers import BertJapaneseTokenizer, BertModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd
from tqdm import tqdm

from DataNLP import load_preprocessed_data

# --- Resolve all relative paths against this script's directory ---
script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(script_dir)

# --- Configuration ---
PRE_TRAINED_MODEL_NAME = 'cl-tohoku/bert-large-japanese'
MAX_LEN = 128            # max token length per text
BATCH_SIZE = 32
EPOCHS = 10
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
NUM_GENDER_LABELS = 2    # male, female


class GenderDataset(Dataset):
    """Torch Dataset wrapping raw texts and integer gender labels.

    Each item is tokenized on the fly with the given BERT tokenizer and
    padded/truncated to ``max_len``.
    """

    def __init__(self, texts, gender_labels, tokenizer, max_len):
        self.texts = texts
        self.gender_labels = gender_labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        text = str(self.texts[item])
        gender_label = self.gender_labels[item]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # encode_plus returns shape (1, max_len); flatten to (max_len,)
        # so the DataLoader can batch items normally.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'gender_labels': torch.tensor(int(gender_label), dtype=torch.long),
        }


class BertForGenderClassification(nn.Module):
    """BERT encoder with a single linear head for gender classification."""

    def __init__(self, model_name, num_gender_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name, use_safetensors=True)
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        self.gender_classifier = nn.Linear(self.bert.config.hidden_size, num_gender_labels)

    def forward(self, input_ids, attention_mask, gender_labels=None):
        """Return ``(loss, logits)``; ``loss`` is None when no labels are given."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        gender_logits = self.gender_classifier(pooled_output)
        loss = None
        if gender_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(gender_logits, gender_labels)
        return loss, gender_logits


def train_epoch(model, data_loader, optimizer, device):
    """Run one training epoch and return the mean batch loss."""
    model.train()
    total_loss = 0
    for batch in tqdm(data_loader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        gender_labels = batch['gender_labels'].to(device)
        optimizer.zero_grad()
        # Labels are always passed here, so forward() always returns a
        # loss tensor (the previous isinstance guard was ineffective: it
        # skipped backward() but still called loss.item() unconditionally).
        loss, _ = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        gender_labels=gender_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(data_loader)


def eval_model(model, data_loader, device):
    """Evaluate the model and return validation accuracy for gender."""
    model.eval()
    gender_preds, gender_true_labels = [], []
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            _, gender_logits = model(input_ids=input_ids, attention_mask=attention_mask)
            gender_preds.extend(torch.argmax(gender_logits, dim=1).cpu().numpy())
            gender_true_labels.extend(batch['gender_labels'].cpu().numpy())
    return accuracy_score(gender_true_labels, gender_preds)


def sample_balanced_data(df, max_per_gender=20000):
    """Balance the DataFrame by gender.

    Caps each gender class at ``max_per_gender`` rows, then shuffles the
    concatenated result. Sampling is seeded for reproducibility.
    """
    gender_sampled_dfs = []
    for gender_label in df['性別_label'].unique():
        subset = df[df['性別_label'] == gender_label]
        if len(subset) > max_per_gender:
            subset = subset.sample(max_per_gender, random_state=42)
        gender_sampled_dfs.append(subset)
    return pd.concat(gender_sampled_dfs).sample(frac=1, random_state=42).reset_index(drop=True)


def main():
    print("--- 1. データ読み込み ---")
    df, _ = load_preprocessed_data()

    # Balance by gender and cap the dataset size to keep training tractable.
    df = sample_balanced_data(df, max_per_gender=5000)

    # Show the label distribution so imbalance problems are visible up front.
    print("\n性別ラベルの分布:")
    print(df['性別_label'].value_counts().sort_index())
    print(f"\n合計データ数: {len(df)} 件")

    # Stratify on the gender label so train/validation keep the same
    # class ratio (a plain random split could skew the small validation set).
    train_df, val_df = train_test_split(
        df, test_size=0.2, random_state=42, stratify=df['性別_label'])

    print(f"\n--- 2. トークナイザとデータローダーの準備 ---")
    tokenizer = BertJapaneseTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
    train_dataset = GenderDataset(
        train_df['text'].values, train_df['性別_label'].values, tokenizer, MAX_LEN)
    train_sampler = RandomSampler(train_dataset)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
    val_dataset = GenderDataset(
        val_df['text'].values, val_df['性別_label'].values, tokenizer, MAX_LEN)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    print("\n--- 3. モデルのセットアップ ---")
    model = BertForGenderClassification(PRE_TRAINED_MODEL_NAME, NUM_GENDER_LABELS)
    model.to(DEVICE)

    # Fine-tune the whole network with per-group learning rates:
    # a small LR for the pretrained encoder, a larger one for the fresh head.
    optimizer = AdamW([
        {'params': model.bert.parameters(), 'lr': 2e-5},
        {'params': model.gender_classifier.parameters(), 'lr': 5e-4},
    ])

    print("\n--- 4. 学習開始 ---")
    print(f"デバイス: {DEVICE}")
    print(f"訓練データ数: {len(train_df)} 件")
    print(f"検証データ数: {len(val_df)} 件")
    print(f"バッチサイズ: {BATCH_SIZE}")
    print(f"エポック数: {EPOCHS}")
    print(f"推定学習時間: 約35時間")

    start_time = time.time()
    for epoch in range(EPOCHS):
        epoch_start_time = time.time()
        print(f"\n{'='*60}")
        print(f"Epoch {epoch + 1}/{EPOCHS} 開始")
        print(f"{'='*60}")

        train_loss = train_epoch(model, train_loader, optimizer, DEVICE)
        print(f"Train Loss (Gender): {train_loss:.4f}")

        gender_acc = eval_model(model, val_loader, DEVICE)
        print(f"Gender Validation Accuracy: {gender_acc:.4f} ({gender_acc*100:.2f}%)")

        # Report elapsed time and a linear estimate of the remaining time.
        epoch_time = time.time() - epoch_start_time
        elapsed_time = time.time() - start_time
        remaining_epochs = EPOCHS - (epoch + 1)
        estimated_remaining_time = (elapsed_time / (epoch + 1)) * remaining_epochs
        print(f"\nエポック所要時間: {epoch_time/60:.1f}分")
        print(f"経過時間: {elapsed_time/3600:.1f}時間")
        print(f"推定残り時間: {estimated_remaining_time/3600:.1f}時間")
        print(f"{'='*60}")

    print("\n--- 5. 学習完了 ---")
    torch.save(model.state_dict(), 'bert_gender_model.bin')
    print("モデルを 'bert_gender_model.bin' に保存しました。")


if __name__ == '__main__':
    main()