File size: 2,422 Bytes
23ec8a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
import os
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
import torch
import subprocess

def finetune(model_name, hf_token, upload_repo):
    os.environ["HF_TOKEN"] = hf_token

    # トークナイザとモデル準備
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=hf_token)

    # データセット読み込み(日本語チャット)
    dataset = load_dataset("rinna/llm-japanese-dataset-v1", split="train")

    # 前処理
    def tokenize_fn(example):
        return tokenizer(example["text"], truncation=True, max_length=512)

    tokenized_dataset = dataset.map(tokenize_fn, batched=True, remove_columns=dataset.column_names)

    # データコラレータ
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # トレーニング設定
    training_args = TrainingArguments(
        output_dir="./finetuned_model",
        per_device_train_batch_size=2,
        num_train_epochs=1,
        save_total_limit=1,
        logging_steps=10,
        push_to_hub=True,
        hub_model_id=upload_repo,
        hub_token=hf_token
    )

    # Trainerセットアップ
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator
    )

    # 学習実行
    trainer.train()

    # モデルをHugging Face Hubへアップロード
    trainer.push_to_hub()

    return f"ファインチューニング完了!モデルは https://huggingface.co/{upload_repo} にアップロードされました。"

# Gradioインターフェース
with gr.Blocks() as demo:
    gr.Markdown("# 日本語チャットモデル 簡易ファインチューニング")

    model_name = gr.Textbox(label="元モデル名(例:rinna/japanese-gpt-neox-3.6b)")
    hf_token = gr.Textbox(label="Hugging Face トークン", type="password")
    upload_repo = gr.Textbox(label="アップロード先リポジトリ名(例:yourname/finetuned-chat-jp)")

    start_btn = gr.Button("ファインチューニング開始")
    output = gr.Textbox(label="実行結果")

    start_btn.click(finetune, inputs=[model_name, hf_token, upload_repo], outputs=output)

demo.launch()