vola2004 commited on
Commit
6f59ad0
·
verified ·
1 Parent(s): ea870e9

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +137 -63
  2. predictor.py +310 -98
app.py CHANGED
@@ -1,76 +1,150 @@
1
- import gradio as gr
2
- import torch
3
  import os
 
 
 
 
4
  from predictor import load_models, predict_text
5
 
6
- # モデルの読み込み
7
- print("=== モデル読み込み開始 ===")
8
- try:
9
- load_models()
10
- print("✅ モデルの読み込みが完了しました")
11
- except Exception as e:
12
- print(f"❌ モデルの読み込みに失敗しました: {e}")
13
- print("⚠️ モデルファイルが存在しない可能性があります")
14
 
15
- def predict_age_gender(text):
16
- """年代・性別予測関数"""
17
- if not text.strip():
18
- return "テキストを入力してください。", "", ""
19
-
 
 
 
20
  try:
21
- result = predict_text(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # 年代予測結果を整形
24
- age_results = []
25
- for age, percentage in result['age_percentages'].items():
26
- age_results.append(f"{age}: {percentage}%")
27
- age_text = "\n".join(age_results)
28
 
29
- # 性別予測結果を整形
30
- gender_results = []
31
- for gender, percentage in result['gender_percentages'].items():
32
- gender_results.append(f"{gender}: {percentage}%")
33
- gender_text = "\n".join(gender_results)
 
34
 
35
- # 最も高い確率の年代を特定
36
- max_age = max(result['age_percentages'].items(), key=lambda x: x[1])
37
- max_gender = max(result['gender_percentages'].items(), key=lambda x: x[1])
38
 
39
- summary = f"推定結果: {max_age[0]} ({max_age[1]}%), {max_gender[0]} ({max_gender[1]}%)"
 
40
 
41
- return summary, age_text, gender_text
 
 
 
42
 
43
  except Exception as e:
44
- return f"エラーが発生しました: {str(e)}", "", ""
45
-
46
- # Gradioインターフェース
47
- interface = gr.Interface(
48
- fn=predict_age_gender,
49
- inputs=gr.Textbox(
50
- label="日本語テキストを入力してください",
51
- placeholder="例: 今日はとても良い天気ですね。友達と一緒に散歩をしました。",
52
- lines=3
53
- ),
54
- outputs=[
55
- gr.Textbox(label="推定結果サマリー"),
56
- gr.Textbox(label="年代予測詳細"),
57
- gr.Textbox(label="性別予測詳細")
58
- ],
59
- title="🧠 年代・性別推定システム",
60
- description="日本語テキストから年代と性別を推定するAIシステムです。",
61
- examples=[
62
- "今日はとても良い天気ですね。",
63
- "友達と一緒に散歩をしました。",
64
- "新しいスマートフォンを買いました。",
65
- "仕事が忙しくて疲れました。"
66
- ],
67
- theme=gr.themes.Soft()
68
- )
69
 
70
- # アプリケーション起動
71
- if __name__ == "__main__":
72
- interface.launch(
73
- server_name="0.0.0.0",
74
- server_port=7860,
75
- share=False # Hugging Face SpacesではFalse
76
- )
 
1
+ from flask import Flask, render_template, request, jsonify
2
+ import sys
3
  import os
4
+ import gzip
5
+ import shutil
6
+ import warnings
7
+ from huggingface_hub import hf_hub_download
8
  from predictor import load_models, predict_text
9
 
10
+ # 警告を抑制
11
+ warnings.filterwarnings("ignore", category=UserWarning)
12
+ warnings.filterwarnings("ignore", message="Distant resource does not have an ETag")
 
 
 
 
 
13
 
14
+ # ETagチェックを無効にする環境変数を設定
15
+ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
16
+ os.environ['TRANSFORMERS_OFFLINE'] = '0'
17
+
18
+ app = Flask(__name__)
19
+
20
+ def download_and_extract_models():
21
+ """Hugging Face Hubからモデルをダウンロードして展開"""
22
  try:
23
+ print("=== モデルダウンロード開始 ===")
24
+
25
+ # リポジトリ名(あなたのユーザー名に変更してください)
26
+ repo_name = "vola2004/bert-japanese-age-gender" # ここを変更してください
27
+
28
+ # モデルファイルをダウンロード
29
+ model_files = [
30
+ "bert_age_model.bin",
31
+ "bert_gender_model.bin"
32
+ ]
33
+
34
+ for model_file in model_files:
35
+ if not os.path.exists(model_file):
36
+ print(f"📥 {model_file} をダウンロード中...")
37
+ try:
38
+ hf_hub_download(
39
+ repo_id=repo_name,
40
+ filename=model_file,
41
+ local_dir=".",
42
+ force_download=True,
43
+ resume_download=False
44
+ )
45
+ print(f"✅ {model_file} のダウンロードが完了しました")
46
+ except Exception as download_error:
47
+ print(f"❌ {model_file} のダウンロードに失敗: {download_error}")
48
+ return False
49
+ else:
50
+ print(f"✅ {model_file} は既に存在します")
51
+
52
+ return True
53
+
54
+ except Exception as e:
55
+ print(f"❌ ダウンロードエラー: {e}")
56
+ return False
57
+
58
+ # モデルのダウンロード、展開、読み込み
59
+ print("=== アプリケーション初期化 ===")
60
+
61
+ # モデルをダウンロード・展開
62
+ if download_and_extract_models():
63
+ # モデルの読み込み
64
+ print("=== モデル読み込み開始 ===")
65
+ try:
66
+ load_models()
67
+ print("✅ モデルの読み込みが完了しました")
68
+ except Exception as e:
69
+ print(f"❌ モデルの読み込みに失敗しました: {e}")
70
+ print("⚠️ モデルファイルが存在しない可能性があります")
71
+ print("⚠️ 予測機能は利用できません")
72
+ else:
73
+ print("❌ モデルのダウンロード・展開に失敗しました")
74
+ print("⚠️ 予測機能は利用できません")
75
+
76
+ @app.route('/', methods=['GET', 'POST'])
77
+ def index():
78
+ if request.method == 'POST':
79
+ # POSTリクエストの場合は予測処理を実行
80
+ try:
81
+ text = request.form.get('text', '')
82
+
83
+ print(f"=== 予測リクエスト受信 ===", flush=True)
84
+ print(f"入力テキスト: '{text}'", flush=True)
85
+ sys.stdout.flush()
86
+
87
+ if not text.strip():
88
+ print("⚠️ 空のテキストが入力されました", flush=True)
89
+ sys.stdout.flush()
90
+ return render_template('index.html',
91
+ input_text=text,
92
+ error='テキストを入力してください。')
93
+
94
+ # 予測実行
95
+ result = predict_text(text)
96
+
97
+ print(f"✅ 予測完了: {result}", flush=True)
98
+ sys.stdout.flush()
99
+
100
+ return render_template('index.html',
101
+ input_text=text,
102
+ result=result)
103
+
104
+ except Exception as e:
105
+ print(f"❌ 予測エラー: {e}", flush=True)
106
+ sys.stdout.flush()
107
+ return render_template('index.html',
108
+ input_text=text,
109
+ error=str(e))
110
+ else:
111
+ # GETリクエストの場合は通常のページ表示
112
+ return render_template('index.html')
113
+
114
+ @app.route('/predict', methods=['POST'])
115
+ def predict():
116
+ try:
117
+ data = request.get_json()
118
+ text = data.get('text', '')
119
 
120
+ print(f"=== 予測リクエスト受信 ===", flush=True)
121
+ print(f"入力テキスト: '{text}'", flush=True)
122
+ sys.stdout.flush()
 
 
123
 
124
+ if not text.strip():
125
+ print("⚠️ 空のテキストが入力されました", flush=True)
126
+ sys.stdout.flush()
127
+ return jsonify({
128
+ 'error': 'テキストを入力してください。'
129
+ })
130
 
131
+ # 予測実行
132
+ result = predict_text(text)
 
133
 
134
+ print(f" 予測完了: {result}", flush=True)
135
+ sys.stdout.flush()
136
 
137
+ return jsonify({
138
+ 'age_percentages': result['age_percentages'],
139
+ 'gender_percentages': result['gender_percentages']
140
+ })
141
 
142
  except Exception as e:
143
+ print(f" 予測エラー: {e}", flush=True)
144
+ sys.stdout.flush()
145
+ return jsonify({'error': str(e)})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
+ if __name__ == '__main__':
148
+ # 環境変数でデバッグモードを制御(デフォルトはFalse)
149
+ debug_mode = os.environ.get('FLASK_DEBUG', 'False').lower() == 'true'
150
+ app.run(debug=debug_mode, host='0.0.0.0', port=5002)
 
 
 
predictor.py CHANGED
@@ -2,11 +2,44 @@ import torch
2
  import torch.nn.functional as F
3
  import os
4
  from transformers import BertJapaneseTokenizer
 
 
 
 
5
 
6
  # 年代モデルと性別モデルの定義をインポート
7
- from SupervisedLearning import BertForAgeClassification, PRE_TRAINED_MODEL_NAME, DEVICE, NUM_AGE_CLASSIFIERS, AGE_CATEGORIES
8
  from GenderLearning import BertForGenderClassification, NUM_GENDER_LABELS
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  # モデルファイルのパス
11
  AGE_MODEL_PATH = 'bert_age_model.bin'
12
  GENDER_MODEL_PATH = 'bert_gender_model.bin'
@@ -19,131 +52,310 @@ GENDER_CATEGORIES_JP = ["男性", "女性"]
19
  TOKENIZER = None
20
  AGE_MODEL = None
21
  GENDER_MODEL = None
 
22
 
23
- def load_models():
24
- """アプリケーション起動時にモデルを一度だけ読み込む"""
25
- global TOKENIZER, AGE_MODEL, GENDER_MODEL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # モデルファイルの存在確認
28
  if not os.path.exists(AGE_MODEL_PATH):
29
  raise FileNotFoundError(f"エラー: 年代学習済みモデル '{AGE_MODEL_PATH}' が見つかりません。")
30
 
31
- # 性別モデルはまだ学習されていない可能性があるので、警告のみ表示
32
- if not os.path.exists(GENDER_MODEL_PATH):
33
- print(f"警告: 性別学習済みモデル '{GENDER_MODEL_PATH}' が���つかりません。")
34
- print("性別予測は利用できません。年代予測のみ実行されます。")
35
-
36
- print("--- モデルの読み込みを開始します ---")
37
- TOKENIZER = BertJapaneseTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
38
 
39
- # 年代モデルの読み込み
40
- print(" 年代モデルを読み込み中...")
41
- AGE_MODEL = BertForAgeClassification(PRE_TRAINED_MODEL_NAME, NUM_AGE_CLASSIFIERS)
 
 
 
 
 
42
  try:
43
- if torch.__version__.startswith('1.'):
44
- AGE_MODEL.load_state_dict(torch.load(AGE_MODEL_PATH, map_location=DEVICE))
45
- else:
46
- AGE_MODEL.load_state_dict(torch.load(AGE_MODEL_PATH, map_location=DEVICE, weights_only=True))
 
 
 
47
  except Exception as e:
48
- print(f"年代モデルの読み込み中にエラーが発生しました: {e}")
49
- raise
50
- AGE_MODEL.to(DEVICE)
51
- AGE_MODEL.eval()
52
-
53
- # 性別モデルの読み込み(存在する場合のみ)
54
- if os.path.exists(GENDER_MODEL_PATH):
55
- print(" 性別モデルを読み込み中...")
56
- GENDER_MODEL = BertForGenderClassification(PRE_TRAINED_MODEL_NAME, NUM_GENDER_LABELS)
57
  try:
58
- if torch.__version__.startswith('1.'):
59
- GENDER_MODEL.load_state_dict(torch.load(GENDER_MODEL_PATH, map_location=DEVICE))
60
- else:
61
- GENDER_MODEL.load_state_dict(torch.load(GENDER_MODEL_PATH, map_location=DEVICE, weights_only=True))
62
- except Exception as e:
63
- print(f"性別モデルの読み込み中にエラーが発生しました: {e}")
64
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  GENDER_MODEL.to(DEVICE)
66
  GENDER_MODEL.eval()
67
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  GENDER_MODEL = None
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- print("--- モデルの読み込みが完了しました ---")
 
 
 
 
71
 
 
 
 
 
 
72
 
73
- def predict_text(text: str):
74
- """
75
- 入力されたテキストから「年代」と「性別」の各ラベルのパーセンテージを返す関数
76
- """
77
- if AGE_MODEL is None or TOKENIZER is None:
78
- load_models()
79
 
80
- print(f"DEBUG: 入力テキスト: '{text}'")
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  # テキストの前処理
83
- encoding = TOKENIZER.encode_plus(
 
 
 
 
84
  text,
85
  add_special_tokens=True,
86
  max_length=128,
87
- return_token_type_ids=False,
88
  padding='max_length',
89
  truncation=True,
90
- return_attention_mask=True,
91
- return_tensors='pt',
92
  )
93
- input_ids = encoding['input_ids'].to(DEVICE)
94
- attention_mask = encoding['attention_mask'].to(DEVICE)
95
 
96
- print(f"DEBUG: input_ids shape: {input_ids.shape}")
97
- print(f"DEBUG: attention_mask shape: {attention_mask.shape}")
98
-
99
- # 年代の予測
 
100
  with torch.no_grad():
101
- _, age_logits = AGE_MODEL(input_ids=input_ids, attention_mask=attention_mask)
102
-
103
- print(f"DEBUG: age_logits shape: {age_logits.shape}")
104
- print(f"DEBUG: age_logits values: {age_logits}")
105
-
106
- # 各年代の二値分類の確率(シグモイド関数)
107
- age_probs = torch.sigmoid(age_logits)[0] # shape: (6,)
108
- print(f"DEBUG: age_probs shape: {age_probs.shape}")
109
- print(f"DEBUG: age_probs values: {age_probs}")
110
-
111
- # 年代の確率を辞書形式で保存
112
- age_percentages = {}
113
- for i, age in enumerate(AGE_CATEGORIES):
114
- percentage = float(f"{age_probs[i].item() * 100:.2f}") # 小数第2位まで
115
- age_percentages[age] = percentage
116
- print(f"DEBUG: {age}: {age_probs[i].item()} -> {percentage}%")
117
-
118
- # 性別の予測(モデルが存在する場合のみ)
119
- if GENDER_MODEL is not None:
120
- with torch.no_grad():
121
- _, gender_logits = GENDER_MODEL(input_ids=input_ids, attention_mask=attention_mask)
122
-
123
- print(f"DEBUG: gender_logits shape: {gender_logits.shape}")
124
- print(f"DEBUG: gender_logits values: {gender_logits}")
125
 
126
- # 性別の確率(Softmax関数)
127
- gender_probs = F.softmax(gender_logits, dim=1)[0] # shape: (2,)
128
- print(f"DEBUG: gender_probs shape: {gender_probs.shape}")
129
- print(f"DEBUG: gender_probs values: {gender_probs}")
 
 
130
 
131
- # 性別の確率を辞書形式で保存
132
- gender_percentages = {}
133
- for i, gender_jp in enumerate(GENDER_CATEGORIES_JP):
134
- percentage = float(f"{gender_probs[i].item() * 100:.2f}") # 小数第2位まで
135
- gender_percentages[gender_jp] = percentage
136
- print(f"DEBUG: {gender_jp}: {gender_probs[i].item()} -> {percentage}%")
137
- else:
138
- # 性別モデルが存在しない場合はデフォルト値を設定
139
- gender_percentages = {"男性": 50.0, "女性": 50.0}
140
- print("DEBUG: 性別モデルが存在しないため、デフォルト値を設定しました")
141
-
142
- # 結果を返す
143
- results = {
144
- "age_percentages": age_percentages,
145
- "gender_percentages": gender_percentages
 
 
146
  }
147
 
148
- print(f"DEBUG: 最終結果: {results}")
149
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch.nn.functional as F
3
  import os
4
  from transformers import BertJapaneseTokenizer
5
+ import warnings
6
+ warnings.filterwarnings("ignore")
7
+ warnings.filterwarnings("ignore", message="Distant resource does not have an ETag")
8
+ warnings.filterwarnings("ignore", category=UserWarning)
9
 
10
  # 年代モデルと性別モデルの定義をインポート
11
+ from SupervisedLearning import BertForAgeClassification, DEVICE, NUM_AGE_CLASSIFIERS, AGE_CATEGORIES
12
  from GenderLearning import BertForGenderClassification, NUM_GENDER_LABELS
13
 
14
+ # 統合モデル用のクラス定義
15
+ class BertForClassification(torch.nn.Module):
16
+ """統合分類モデル(年代と性別を同時に分類)"""
17
+ def __init__(self, model_name, num_classes):
18
+ super().__init__()
19
+ from transformers import BertModel
20
+
21
+ if model_name is None:
22
+ self.bert = BertModel.from_pretrained('cl-tohoku/bert-large-japanese', use_safetensors=True)
23
+ else:
24
+ self.bert = BertModel.from_pretrained(model_name, use_safetensors=True)
25
+
26
+ self.dropout = torch.nn.Dropout(0.3)
27
+ self.classifier = torch.nn.Linear(self.bert.config.hidden_size, num_classes)
28
+
29
+ def forward(self, input_ids, attention_mask, labels=None):
30
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
31
+ pooled_output = outputs.pooler_output
32
+ pooled_output = self.dropout(pooled_output)
33
+
34
+ logits = self.classifier(pooled_output)
35
+
36
+ loss = None
37
+ if labels is not None:
38
+ loss_fct = torch.nn.CrossEntropyLoss()
39
+ loss = loss_fct(logits, labels)
40
+
41
+ return loss, logits
42
+
43
  # モデルファイルのパス
44
  AGE_MODEL_PATH = 'bert_age_model.bin'
45
  GENDER_MODEL_PATH = 'bert_gender_model.bin'
 
52
  TOKENIZER = None
53
  AGE_MODEL = None
54
  GENDER_MODEL = None
55
+ MODELS_LOADED = False # モデル読み込み状態を追跡
56
 
57
+ def load_shared_tokenizer():
58
+ """共有トークナイザーを読み込む"""
59
+ global TOKENIZER
60
+
61
+ if TOKENIZER is not None:
62
+ return # 既に読み込み済み
63
+
64
+ print("🔧 共有トークナイザーの読み込みを開始します...")
65
+
66
+ # 共有トークナイザー読み込み戦略(Largeモデル用)
67
+ tokenizer_loaded = False
68
+ tokenizer_models = ['cl-tohoku/bert-large-japanese', 'cl-tohoku/bert-base-japanese-v3']
69
+
70
+ for model_name in tokenizer_models:
71
+ if tokenizer_loaded:
72
+ break
73
+
74
+ print(f"共有トークナイザー読み込み試行: {model_name}")
75
+
76
+ # 戦略1: オンラインモード
77
+ try:
78
+ print(f"オンラインモードでトークナイザーを読み込み中... ({model_name})")
79
+ TOKENIZER = BertJapaneseTokenizer.from_pretrained(
80
+ model_name,
81
+ use_fast=False,
82
+ force_download=False,
83
+ resume_download=True
84
+ )
85
+ print(f"✅ オンラインモードでトークナイザーの読み込みが完了しました ({model_name})")
86
+ tokenizer_loaded = True
87
+
88
+ except Exception as e:
89
+ print(f"オンラインモード失敗 ({model_name}): {e}")
90
+
91
+ # 戦略2: オフラインモード
92
+ try:
93
+ print(f"オフラインモードでトークナイザーを読み込み中... ({model_name})")
94
+ os.environ['TRANSFORMERS_OFFLINE'] = '1'
95
+ TOKENIZER = BertJapaneseTokenizer.from_pretrained(
96
+ model_name,
97
+ local_files_only=True,
98
+ use_fast=False
99
+ )
100
+ print(f"✅ オフラインモードでトークナイザーの読み込みが完了しました ({model_name})")
101
+ tokenizer_loaded = True
102
+
103
+ except Exception as e2:
104
+ print(f"オフラインモード失敗 ({model_name}): {e2}")
105
+
106
+ if not tokenizer_loaded:
107
+ raise Exception("共有トークナイザーの読み込みに失敗しました")
108
+
109
+ def load_age_model():
110
+ """年代予測用モデルを読み込む"""
111
+ global TOKENIZER, AGE_MODEL
112
 
113
  # モデルファイルの存在確認
114
  if not os.path.exists(AGE_MODEL_PATH):
115
  raise FileNotFoundError(f"エラー: 年代学習済みモデル '{AGE_MODEL_PATH}' が見つかりません。")
116
 
117
+ print("--- 年代モデルの読み込みを開始します ---")
 
 
 
 
 
 
118
 
119
+ # 共有トークナイザーの読み込み
120
+ load_shared_tokenizer()
121
+
122
+ # 年代モデルの読み込み(bert-large-japanese を使用)
123
+ print("📊 年代モデルの読み込みを開始します...")
124
+ age_model_loaded = False
125
+
126
+ # 戦略1: bert-large-japanese で試行(学習時と同じ)
127
  try:
128
+ print("年代モデル: bert-large-japanese で試行中...")
129
+ AGE_MODEL = BertForAgeClassification('cl-tohoku/bert-large-japanese', NUM_AGE_CLASSIFIERS)
130
+ AGE_MODEL.load_state_dict(torch.load(AGE_MODEL_PATH, map_location=DEVICE))
131
+ AGE_MODEL.to(DEVICE)
132
+ AGE_MODEL.eval()
133
+ print("✅ 年代モデル(bert-large-japanese)の読み込みが完了しました")
134
+ age_model_loaded = True
135
  except Exception as e:
136
+ print(f"年代モデル(bert-large-japanese)失敗: {e}")
137
+
138
+ # 戦略2: bert-base-japanese-v3 で試行
 
 
 
 
 
 
139
  try:
140
+ print("年代モデル: bert-base-japanese-v3 で試行中...")
141
+ AGE_MODEL = BertForAgeClassification('cl-tohoku/bert-base-japanese-v3', NUM_AGE_CLASSIFIERS)
142
+ AGE_MODEL.load_state_dict(torch.load(AGE_MODEL_PATH, map_location=DEVICE))
143
+ AGE_MODEL.to(DEVICE)
144
+ AGE_MODEL.eval()
145
+ print(" 年代モデル(bert-base-japanese-v3)の読み込みが完了しました")
146
+ age_model_loaded = True
147
+ except Exception as e2:
148
+ print(f"年代モデル(bert-base-japanese-v3)も失敗: {e2}")
149
+
150
+ if not age_model_loaded:
151
+ raise Exception("年代モデルの読み込みに失敗しました")
152
+
153
+ print("🎉 年代モデルの読み込みが完了しました!")
154
+
155
+ def load_gender_model():
156
+ """性別予測用モデルを読み込む"""
157
+ global TOKENIZER, GENDER_MODEL
158
+
159
+ # 統合モデルファイルの存在確認
160
+ classification_model_path = 'bert_classification_model.bin'
161
+ if not os.path.exists(classification_model_path):
162
+ print(f"警告: 統合学習済みモデル '{classification_model_path}' が見つかりません。")
163
+ print("性別予測は利用できません。")
164
+ return
165
+
166
+ print("--- 性別モデルの読み込みを開始します(統合モデル使用) ---")
167
+
168
+ # 共有トークナイザーの読み込み
169
+ load_shared_tokenizer()
170
+
171
+ # 統合モデルの読み込み(bert-large-japanese を使用)
172
+ print("👥 性別モデルの読み込みを開始します(統合モデル)...")
173
+ gender_model_loaded = False
174
+
175
+ # 戦略1: bert-large-japanese で試行(統合モデルはLarge)
176
+ try:
177
+ print("性別モデル: bert-large-japanese で試行中(統合モデル)...")
178
+ GENDER_MODEL = BertForClassification('cl-tohoku/bert-large-japanese', 12) # 12クラス分類
179
+ GENDER_MODEL.load_state_dict(torch.load(classification_model_path, map_location=DEVICE))
180
  GENDER_MODEL.to(DEVICE)
181
  GENDER_MODEL.eval()
182
+ print("✅ 性別モデル(統合モデル)の読み込みが完了しました")
183
+ gender_model_loaded = True
184
+ except Exception as e:
185
+ print(f"性別モデル(統合モデル)失敗: {e}")
186
+
187
+ # 戦略2: bert-base-japanese-v3 で試行
188
+ try:
189
+ print("性別モデル: bert-base-japanese-v3 で試行中(統合モデル)...")
190
+ GENDER_MODEL = BertForClassification('cl-tohoku/bert-base-japanese-v3', 12) # 12クラス分類
191
+ GENDER_MODEL.load_state_dict(torch.load(classification_model_path, map_location=DEVICE))
192
+ GENDER_MODEL.to(DEVICE)
193
+ GENDER_MODEL.eval()
194
+ print("✅ 性別モデル(統合モデル)の読み込みが完了しました")
195
+ gender_model_loaded = True
196
+ except Exception as e2:
197
+ print(f"性別モデル(統合モデル)も失敗: {e2}")
198
+
199
+ if not gender_model_loaded:
200
+ print("❌ 性別モデルの読み込みに失敗しました")
201
+ print("⚠️ 性別予測は利用できません")
202
  GENDER_MODEL = None
203
+ else:
204
+ print("🎉 性別モデル(統合モデル)の読み込みが完了しました!")
205
+
206
+ def load_models():
207
+ """アプリケーション起動時にモデルを一度だけ読み込む(後方互換性のため)"""
208
+ global MODELS_LOADED
209
+
210
+ # 既に読み込み済みの場合はスキップ
211
+ if MODELS_LOADED:
212
+ print("✅ モデルは既に読み込み済みです。スキップします。")
213
+ return
214
 
215
+ # 年代モデルと性別モデルを個別に読み込み
216
+ load_age_model()
217
+ load_gender_model()
218
+
219
+ MODELS_LOADED = True # 読み込み完了フラグを設定
220
 
221
+ def predict_age(text):
222
+ """テキストから年代を予測(統合関数を使用)"""
223
+ # 統合予測を実行
224
+ full_result = predict_text(text)
225
+ return full_result['age_percentages']
226
 
227
+ def predict_gender(text):
228
+ """テキストから性別を予測(統合関数を使用)"""
229
+ # 統合予測を実行
230
+ full_result = predict_text(text)
231
+ return full_result['gender_percentages']
 
232
 
233
+ def predict_text(text):
234
+ """テキストから年代と性別を統合予測"""
235
+ global TOKENIZER, AGE_MODEL, GENDER_MODEL
236
+
237
+ # モデルが読み込まれていない場合は読み込み
238
+ if AGE_MODEL is None:
239
+ load_age_model()
240
+ if GENDER_MODEL is None:
241
+ load_gender_model()
242
+
243
+ if TOKENIZER is None or AGE_MODEL is None or GENDER_MODEL is None:
244
+ raise Exception("モデルが読み込まれていません。")
245
 
246
  # テキストの前処理
247
+ if not text or not text.strip():
248
+ raise ValueError("テキストが空です。")
249
+
250
+ # トークン化(一度だけ実行)
251
+ inputs = TOKENIZER.encode_plus(
252
  text,
253
  add_special_tokens=True,
254
  max_length=128,
 
255
  padding='max_length',
256
  truncation=True,
257
+ return_tensors='pt'
 
258
  )
 
 
259
 
260
+ # デバイスに移動
261
+ input_ids = inputs['input_ids'].to(DEVICE)
262
+ attention_mask = inputs['attention_mask'].to(DEVICE)
263
+
264
+ # 年代と性別を同時に予測
265
  with torch.no_grad():
266
+ # 年代予測
267
+ age_output = AGE_MODEL(input_ids, attention_mask)
268
+ if isinstance(age_output, tuple):
269
+ age_logits = age_output[1]
270
+ else:
271
+ age_logits = age_output
272
+ age_probs = torch.sigmoid(age_logits)
273
+ age_probs = age_probs.cpu().numpy().flatten()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
+ # 性別予測(統合モデルから性別部分を抽出)
276
+ classification_output = GENDER_MODEL(input_ids, attention_mask)
277
+ if isinstance(classification_output, tuple):
278
+ all_logits = classification_output[1]
279
+ else:
280
+ all_logits = classification_output
281
 
282
+ # 統合モデルの12クラス出力から性別部分(最後の2クラス)を抽出
283
+ gender_logits = all_logits[:, -2:]
284
+ gender_probs = torch.softmax(gender_logits, dim=1)
285
+ gender_probs = gender_probs.cpu().numpy().flatten()
286
+
287
+ # 結果の整形
288
+ age_result = {}
289
+ for i, category in enumerate(AGE_CATEGORIES):
290
+ age_result[category] = float(age_probs[i] * 100)
291
+
292
+ gender_result = {}
293
+ for i, category in enumerate(GENDER_CATEGORIES_JP):
294
+ gender_result[category] = float(gender_probs[i] * 100)
295
+
296
+ result = {
297
+ 'age_percentages': age_result,
298
+ 'gender_percentages': gender_result
299
  }
300
 
301
+ return result
302
+
303
+ def get_top_predictions(result, top_k=3):
304
+ """予測結果から上位k個を取得"""
305
+ # 年代の上位予測
306
+ age_sorted = sorted(result['age_percentages'].items(), key=lambda x: x[1], reverse=True)
307
+ top_ages = age_sorted[:top_k]
308
+
309
+ # 性別の上位予測
310
+ gender_sorted = sorted(result['gender_percentages'].items(), key=lambda x: x[1], reverse=True)
311
+ top_genders = gender_sorted[:top_k]
312
+
313
+ return {
314
+ 'top_ages': top_ages,
315
+ 'top_genders': top_genders
316
+ }
317
+
318
+ def format_prediction_result(result):
319
+ """予測結果を読みやすい形式に整形"""
320
+ formatted = "=== 予測結果 ===\n"
321
+
322
+ # 年代予測結果
323
+ formatted += "\n📊 年代予測:\n"
324
+ for age, percentage in sorted(result['age_percentages'].items(), key=lambda x: x[1], reverse=True):
325
+ formatted += f" {age}: {percentage:.1f}%\n"
326
+
327
+ # 性別予測結果
328
+ formatted += "\n👥 性別予測:\n"
329
+ for gender, percentage in sorted(result['gender_percentages'].items(), key=lambda x: x[1], reverse=True):
330
+ formatted += f" {gender}: {percentage:.1f}%\n"
331
+
332
+ return formatted
333
+
334
+ # テスト用の関数
335
+ def test_prediction():
336
+ """予測機能のテスト(統合予測)"""
337
+ test_text = "こんにちは、今日は良い天気ですね。"
338
+ print(f"テストテキスト: {test_text}")
339
+
340
+ try:
341
+ result = predict_text(test_text)
342
+ print(format_prediction_result(result))
343
+
344
+ # 個別予測もテスト
345
+ print("\n=== 個別予測テスト ===")
346
+ age_result = predict_age(test_text)
347
+ print(f"年代予測: {age_result}")
348
+
349
+ gender_result = predict_gender(test_text)
350
+ print(f"性別予測: {gender_result}")
351
+
352
+ except Exception as e:
353
+ print(f"予測エラー: {e}")
354
+
355
+ if __name__ == "__main__":
356
+ # モデルの読み込みテスト
357
+ try:
358
+ load_models()
359
+ test_prediction()
360
+ except Exception as e:
361
+ print(f"エラー: {e}")