drosshopper commited on
Commit
06204cb
·
verified ·
1 Parent(s): 664eae1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -2,16 +2,18 @@ import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
5
- # 🧠 モデルロード(初回キャッシュで高速化)
6
  @st.cache_resource
7
  def load_model():
8
- tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b")
 
 
9
  model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b")
10
  return tokenizer, model
11
 
12
  tokenizer, model = load_model()
13
 
14
- # 📝 入力データ(4ケース)
15
  cases = [
16
  {
17
  "user_input": """1歳時馬体重:430kg
@@ -60,31 +62,31 @@ cases = [
60
  ]
61
 
62
  # ✨ コメント生成関数
63
- def generate_comment(prompt):
64
- full_prompt = f"""以下は、一口馬主の募集馬についてのAI総括コメントです。
65
 
66
  ユーザーが入力した項目:
67
- {prompt["user_input"]}
68
 
69
  予測した結果:
70
- {prompt["result_summary"]}
71
 
72
  この馬について得られた結果を総括したうえで、ポジティブな言葉で締めくくってください。
73
  コメント:"""
74
 
75
- input_ids = tokenizer.encode(full_prompt, return_tensors="pt")
76
  output = model.generate(input_ids, max_new_tokens=120, do_sample=True, temperature=0.7)
77
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
78
  return decoded.split("コメント:")[-1].strip()
79
 
80
- # 🌐 Streamlit UI
81
- st.title("🐴 募集馬AIコメント生成(デモ)")
82
 
83
  if st.button("🎯 コメントを一括生成"):
84
  for i, case in enumerate(cases, 1):
85
  with st.spinner(f"Case {i} を生成中..."):
86
  comment = generate_comment(case)
87
- st.markdown(f"### 🐴 Case {i}")
88
  st.markdown(f"**📥 入力**\n```\n{case['user_input']}\n```\n")
89
  st.markdown(f"**📊 予測結果**\n```\n{case['result_summary']}\n```\n")
90
  st.markdown(f"**📝 コメント**\n> {comment}")
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
5
+ # 🔃 モデル・トークナイザーの読み込み(slow tokenizer指定)
6
  @st.cache_resource
7
  def load_model():
8
+ tokenizer = AutoTokenizer.from_pretrained(
9
+ "rinna/japanese-gpt-neox-3.6b", use_fast=False
10
+ )
11
  model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b")
12
  return tokenizer, model
13
 
14
  tokenizer, model = load_model()
15
 
16
+ # 📋 サンプルデータ(4頭分)
17
  cases = [
18
  {
19
  "user_input": """1歳時馬体重:430kg
 
62
  ]
63
 
64
  # ✨ コメント生成関数
65
+ def generate_comment(case):
66
+ prompt = f"""以下は、一口馬主の募集馬についてのAI総括コメントです。
67
 
68
  ユーザーが入力した項目:
69
+ {case["user_input"]}
70
 
71
  予測した結果:
72
+ {case["result_summary"]}
73
 
74
  この馬について得られた結果を総括したうえで、ポジティブな言葉で締めくくってください。
75
  コメント:"""
76
 
77
+ input_ids = tokenizer.encode(prompt, return_tensors="pt")
78
  output = model.generate(input_ids, max_new_tokens=120, do_sample=True, temperature=0.7)
79
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
80
  return decoded.split("コメント:")[-1].strip()
81
 
82
+ # 🖼️ UI部分
83
+ st.title("🐴 AIによる募集馬コメント生成デモ")
84
 
85
  if st.button("🎯 コメントを一括生成"):
86
  for i, case in enumerate(cases, 1):
87
  with st.spinner(f"Case {i} を生成中..."):
88
  comment = generate_comment(case)
89
+ st.markdown(f"### 🐎 Case {i}")
90
  st.markdown(f"**📥 入力**\n```\n{case['user_input']}\n```\n")
91
  st.markdown(f"**📊 予測結果**\n```\n{case['result_summary']}\n```\n")
92
  st.markdown(f"**📝 コメント**\n> {comment}")