import os import pickle import numpy as np from sentence_transformers import SentenceTransformer from groq import Groq import gradio as gr # Gradioをインポート # --- 設定 --- # Groq APIキーをここに設定してください。 # 環境変数から読み込むことを推奨します: os.getenv("GROQ_API_KEY") GROQ_API_KEY = os.getenv("GROQ_API_KEY") VECTOR_DATA_DIR = "vector_data" # ベクトルデータが保存されているディレクトリ #SENTENCE_TRANSFORMER_MODEL = "cl-tohoku/bert-base-japanese-whole-word-masking" # ベクトル化に使用するモデル SENTENCE_TRANSFORMER_MODEL = "tohoku-nlp/bert-base-japanese-whole-word-masking" # ベクトル化に使用するモデル GROQ_MODEL_NAME = "llama3-8b-8192" # Groqで使用するモデル名 (例: "llama3-8b-8192", "mixtral-8x7b-32768") TOP_K_CHUNKS = 3 # 質問に対する関連チャンクを上位何個取得するか # --- グローバル変数 (モデルとデータベース) --- model = None vector_database = [] # [{"chunk": "テキスト", "vector": [float, ...]}, ...] def load_vector_database(vector_data_dir: str) -> list: """ 指定されたディレクトリからすべてのチャンクとベクトルデータを読み込みます。 ファイル名が '_chunks_with_vectors.pkl' で終わるものを対象とします。 Args: vector_data_dir (str): ベクトルデータが保存されているディレクトリのパス。 Returns: list: チャンクとベクトルデータのリスト。 """ db = [] if not os.path.isdir(vector_data_dir): print(f"エラー: ベクトルデータディレクトリ '{vector_data_dir}' が見つかりません。") return db print(f"ディレクトリ '{vector_data_dir}' からベクトルデータベースをロード中...") for file_name in os.listdir(vector_data_dir): if file_name.endswith('_chunks_with_vectors.pkl'): file_path = os.path.join(vector_data_dir, file_name) try: with open(file_path, 'rb') as f: data = pickle.load(f) db.extend(data) # リストを結合 print(f" '{file_name}' をロードしました。現在のチャンク数: {len(db)}") except Exception as e: print(f" エラー: '{file_name}' のロード中に問題が発生しました: {e}") continue print(f"ベクトルデータベースのロードが完了しました。合計 {len(db)} 個のチャンクをロードしました。") return db def initialize_models(): """ Sentence Transformerモデルをロードします。 """ global model print(f"Sentence Transformerモデル '{SENTENCE_TRANSFORMER_MODEL}' をロード中...") try: model = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL) print("Sentence Transformerモデルのロードが完了しました。") except Exception as e: print(f"Sentence Transformerモデルのロードに失敗しました: {e}") print("必要なライブラリ (sentence-transformers, protobuf, fugashi, unidic-lite, ipadic) が") print("すべてインストールされているか確認してください。") # Gradioアプリは終了させず、エラーメッセージを表示する model = None # モデルがロードできなかったことを示す return f"エラー: Sentence Transformerモデルのロードに失敗しました: {e}" def get_most_similar_chunks(query_vector: np.ndarray, db: list, top_k: int = 3) -> list: """ クエリベクトルに最も類似したチャンクをデータベースから取得します。 Args: query_vector (np.ndarray): 質問のベクトル。 db (list): チャンクとベクトルを含むデータベース。 top_k (int): 取得する類似チャンクの数。 Returns: list: 最も類似したチャンクのテキストのリスト。 """ if not db: return [] similarities = [] for item in db: chunk_vector = np.array(item["vector"]) # リストからNumPy配列に変換 # コサイン類似度を計算 # A・B / (||A|| * ||B||) similarity = np.dot(query_vector, chunk_vector) / (np.linalg.norm(query_vector) * np.linalg.norm(chunk_vector)) similarities.append((similarity, item["chunk"])) # 類似度で降順にソートし、上位K個を取得 similarities.sort(key=lambda x: x[0], reverse=True) # 類似度0以上のチャンクのみを返す(関連性のないものは含めない) relevant_chunks = [chunk for sim, chunk in similarities if sim > 0][:top_k] return relevant_chunks def generate_groq_response(query: str, context_chunks: list) -> str: """ Groq APIを使用して回答を生成します。 Args: query (str): ユーザーの質問。 context_chunks (list): 質問に関連するチャンクのテキストリスト。 Returns: str: Groqが生成した回答。 """ if not GROQ_API_KEY or GROQ_API_KEY == "YOUR_GROQ_API_KEY" or GROQ_API_KEY == "": return "エラー: Groq APIキーが設定されていません。環境変数 'GROQ_API_KEY' を設定してください。" try: client = Groq(api_key=GROQ_API_KEY) # プロンプトの構築 (RAG形式) system_prompt = "あなたは役立つAIアシスタントです。提供された情報に基づいてユーザーの質問に答えてください。提供された情報で答えられない場合は、その旨を伝えてください。" context_text = "" if context_chunks: context_text = "\n\n関連情報:\n" + "\n".join([f"- {c}" for c in context_chunks]) user_message = f"質問: {query}{context_text}\n\n上記の質問に、提供された関連情報のみに基づいて答えてください。もし情報が不足している場合は、「提供された情報からは答えられません。」と明確に述べてください。" chat_completion = client.chat.completions.create( messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_message}, ], model=GROQ_MODEL_NAME, # 設定で指定されたモデル名を使用 temperature=0.7, # 創造性を調整 (0.0-1.0) max_tokens=500, # 生成する回答の最大トークン数 ) return chat_completion.choices[0].message.content except Exception as e: return f"Groq API呼び出し中にエラーが発生しました: {e}" def rag_chatbot(user_query: str) -> str: """ Gradioインターフェースのメイン関数。 ユーザーの質問を受け取り、RAGプロセスを実行して回答を返します。 """ if model is None: return "エラー: Sentence Transformerモデルがロードされていません。アプリケーションの初期化を確認してください。" if not vector_database: return "警告: ベクトルデータベースが空です。質問に適切に答えられない可能性があります。txt2vec_multi.py を実行して、先にベクトルデータを作成してください。" if not user_query.strip(): return "質問を入力してください。" try: # 質問をベクトル化 query_vector = model.encode(user_query) # 類似チャンクを検索 relevant_chunks = get_most_similar_chunks(query_vector, vector_database, TOP_K_CHUNKS) # Groqで回答を生成 groq_response = generate_groq_response(user_query, relevant_chunks) return groq_response except Exception as e: return f"エラーが発生しました: {e}" # --- アプリケーションの初期化 --- # Gradioアプリが起動する前にモデルとデータベースをロード print("アプリケーションを初期化中...") initialize_models() vector_database = load_vector_database(VECTOR_DATA_DIR) print("アプリケーションの初期化が完了しました。") # --- Gradioインターフェースの定義と起動 --- iface = gr.Interface( fn=rag_chatbot, inputs=gr.Textbox(lines=2, label="質問を入力してください", placeholder="例:この文書の主要なテーマは何ですか?"), outputs=gr.Textbox(label="ボットの回答"), title="四国電力 RAG チャットボット", description="四国電力について質問を入力すると回答します。", allow_flagging="never" # フラグ機能を無効にする ) if __name__ == "__main__": # Gradioアプリを起動 iface.launch(share=True) # 公開リンクを作成