File size: 2,422 Bytes
23ec8a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import gradio as gr
import os
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
import torch
import subprocess
def finetune(model_name, hf_token, upload_repo):
os.environ["HF_TOKEN"] = hf_token
# トークナイザとモデル準備
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=hf_token)
# データセット読み込み(日本語チャット)
dataset = load_dataset("rinna/llm-japanese-dataset-v1", split="train")
# 前処理
def tokenize_fn(example):
return tokenizer(example["text"], truncation=True, max_length=512)
tokenized_dataset = dataset.map(tokenize_fn, batched=True, remove_columns=dataset.column_names)
# データコラレータ
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
# トレーニング設定
training_args = TrainingArguments(
output_dir="./finetuned_model",
per_device_train_batch_size=2,
num_train_epochs=1,
save_total_limit=1,
logging_steps=10,
push_to_hub=True,
hub_model_id=upload_repo,
hub_token=hf_token
)
# Trainerセットアップ
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset,
data_collator=data_collator
)
# 学習実行
trainer.train()
# モデルをHugging Face Hubへアップロード
trainer.push_to_hub()
return f"ファインチューニング完了!モデルは https://huggingface.co/{upload_repo} にアップロードされました。"
# Gradioインターフェース
with gr.Blocks() as demo:
gr.Markdown("# 日本語チャットモデル 簡易ファインチューニング")
model_name = gr.Textbox(label="元モデル名(例:rinna/japanese-gpt-neox-3.6b)")
hf_token = gr.Textbox(label="Hugging Face トークン", type="password")
upload_repo = gr.Textbox(label="アップロード先リポジトリ名(例:yourname/finetuned-chat-jp)")
start_btn = gr.Button("ファインチューニング開始")
output = gr.Textbox(label="実行結果")
start_btn.click(finetune, inputs=[model_name, hf_token, upload_repo], outputs=output)
demo.launch() |