Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -2,16 +2,18 @@ import streamlit as st
|
|
2 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
3 |
import torch
|
4 |
|
5 |
-
#
|
6 |
@st.cache_resource
|
7 |
def load_model():
|
8 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
|
9 |
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b")
|
10 |
return tokenizer, model
|
11 |
|
12 |
tokenizer, model = load_model()
|
13 |
|
14 |
-
#
|
15 |
cases = [
|
16 |
{
|
17 |
"user_input": """1歳時馬体重:430kg
|
@@ -60,31 +62,31 @@ cases = [
|
|
60 |
]
|
61 |
|
62 |
# ✨ コメント生成関数
|
63 |
-
def generate_comment(
|
64 |
-
|
65 |
|
66 |
ユーザーが入力した項目:
|
67 |
-
{
|
68 |
|
69 |
予測した結果:
|
70 |
-
{
|
71 |
|
72 |
この馬について得られた結果を総括したうえで、ポジティブな言葉で締めくくってください。
|
73 |
コメント:"""
|
74 |
|
75 |
-
input_ids = tokenizer.encode(
|
76 |
output = model.generate(input_ids, max_new_tokens=120, do_sample=True, temperature=0.7)
|
77 |
decoded = tokenizer.decode(output[0], skip_special_tokens=True)
|
78 |
return decoded.split("コメント:")[-1].strip()
|
79 |
|
80 |
-
#
|
81 |
-
st.title("🐴
|
82 |
|
83 |
if st.button("🎯 コメントを一括生成"):
|
84 |
for i, case in enumerate(cases, 1):
|
85 |
with st.spinner(f"Case {i} を生成中..."):
|
86 |
comment = generate_comment(case)
|
87 |
-
st.markdown(f"###
|
88 |
st.markdown(f"**📥 入力**\n```\n{case['user_input']}\n```\n")
|
89 |
st.markdown(f"**📊 予測結果**\n```\n{case['result_summary']}\n```\n")
|
90 |
st.markdown(f"**📝 コメント**\n> {comment}")
|
|
|
2 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
3 |
import torch
|
4 |
|
5 |
+
# 🔃 モデル・トークナイザーの読み込み(slow tokenizer指定)
|
6 |
@st.cache_resource
|
7 |
def load_model():
|
8 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
9 |
+
"rinna/japanese-gpt-neox-3.6b", use_fast=False
|
10 |
+
)
|
11 |
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b")
|
12 |
return tokenizer, model
|
13 |
|
14 |
tokenizer, model = load_model()
|
15 |
|
16 |
+
# 📋 サンプルデータ(4頭分)
|
17 |
cases = [
|
18 |
{
|
19 |
"user_input": """1歳時馬体重:430kg
|
|
|
62 |
]
|
63 |
|
64 |
# ✨ コメント生成関数
|
65 |
+
def generate_comment(case):
|
66 |
+
prompt = f"""以下は、一口馬主の募集馬についてのAI総括コメントです。
|
67 |
|
68 |
ユーザーが入力した項目:
|
69 |
+
{case["user_input"]}
|
70 |
|
71 |
予測した結果:
|
72 |
+
{case["result_summary"]}
|
73 |
|
74 |
この馬について得られた結果を総括したうえで、ポジティブな言葉で締めくくってください。
|
75 |
コメント:"""
|
76 |
|
77 |
+
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
78 |
output = model.generate(input_ids, max_new_tokens=120, do_sample=True, temperature=0.7)
|
79 |
decoded = tokenizer.decode(output[0], skip_special_tokens=True)
|
80 |
return decoded.split("コメント:")[-1].strip()
|
81 |
|
82 |
+
# 🖼️ UI部分
|
83 |
+
st.title("🐴 AIによる募集馬コメント生成デモ")
|
84 |
|
85 |
if st.button("🎯 コメントを一括生成"):
|
86 |
for i, case in enumerate(cases, 1):
|
87 |
with st.spinner(f"Case {i} を生成中..."):
|
88 |
comment = generate_comment(case)
|
89 |
+
st.markdown(f"### 🐎 Case {i}")
|
90 |
st.markdown(f"**📥 入力**\n```\n{case['user_input']}\n```\n")
|
91 |
st.markdown(f"**📊 予測結果**\n```\n{case['result_summary']}\n```\n")
|
92 |
st.markdown(f"**📝 コメント**\n> {comment}")
|