diff --git a/.env.example b/.env.example
deleted file mode 100644
index 1a6701268f8f2670b54b2a39d761679ae246b509..0000000000000000000000000000000000000000
--- a/.env.example
+++ /dev/null
@@ -1,6 +0,0 @@
-SYNTHESIZER_MODEL=
-SYNTHESIZER_BASE_URL=
-SYNTHESIZER_API_KEY=
-TRAINEE_MODEL=
-TRAINEE_BASE_URL=
-TRAINEE_API_KEY=
diff --git a/.gitattributes b/.gitattributes
deleted file mode 100644
index a6344aac8c09253b3b630fb776ae94478aa0275b..0000000000000000000000000000000000000000
--- a/.gitattributes
+++ /dev/null
@@ -1,35 +0,0 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
deleted file mode 100644
index 678cdc50b0dbb52dba8ba9306f6db3efa556df13..0000000000000000000000000000000000000000
--- a/.gitignore
+++ /dev/null
@@ -1,179 +0,0 @@
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-*.py,cover
-.hypothesis/
-.pytest_cache/
-cover/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-.pybuilder/
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-# For a library or package, you might want to ignore these files since the code is
-# intended to run in multiple environments; otherwise, check them in:
-# .python-version
-
-# pipenv
-# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
-# However, in case of collaboration, if having platform-specific dependencies or dependencies
-# having no cross-platform support, pipenv may install dependencies that don't work, or not
-# install all needed dependencies.
-#Pipfile.lock
-
-# UV
-# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
-# This is especially recommended for binary packages to ensure reproducibility, and is more
-# commonly ignored for libraries.
-#uv.lock
-
-# poetry
-# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
-# This is especially recommended for binary packages to ensure reproducibility, and is more
-# commonly ignored for libraries.
-# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
-#poetry.lock
-
-# pdm
-# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
-#pdm.lock
-# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
-# in version control.
-# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
-.pdm.toml
-.pdm-python
-.pdm-build/
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
-__pypackages__/
-
-# Celery stuff
-celerybeat-schedule
-celerybeat.pid
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
-
-# pytype static type analyzer
-.pytype/
-
-# Cython debug symbols
-cython_debug/
-
-# PyCharm
-# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
-# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
-# and can be added to the global gitignore or merged into this file. For a more nuclear
-# option (not recommended) you can uncomment the following to ignore the entire idea folder.
-.idea/
-
-# Ruff stuff:
-.ruff_cache/
-
-# PyPI configuration file
-.pypirc
-
-cache
-*.pyc
-*.html
-.gradio
diff --git a/README.md b/README.md
deleted file mode 100644
index a5b0bc76da4083dce4f5710c7172498526f1ddb7..0000000000000000000000000000000000000000
--- a/README.md
+++ /dev/null
@@ -1,14 +0,0 @@
----
-title: GraphGen
-emoji: 🐠
-colorFrom: gray
-colorTo: blue
-sdk: gradio
-sdk_version: 5.32.1
-app_file: app.py
-pinned: false
-license: apache-2.0
-short_description: A framework for synthetic data generation based on KG.
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
index 450a48828bd383547f5a584e00d7ea0037913cb7..10a914cefb4727756c1d93ec98fb20e3fbf0c081 100644
--- a/app.py
+++ b/app.py
@@ -1,20 +1,19 @@
+import json
import os
import sys
-import json
import tempfile
-import pandas as pd
import gradio as gr
-
-from gradio_i18n import Translate, gettext as _
-
-from webui.base import GraphGenParams
-from webui.test_api import test_api_connection
-from webui.cache_utils import setup_workspace, cleanup_workspace
-from webui.count_tokens import count_tokens
+import pandas as pd
+from base import GraphGenParams
+from cache_utils import cleanup_workspace, setup_workspace
+from count_tokens import count_tokens
+from gradio_i18n import Translate
+from gradio_i18n import gettext as _
+from test_api import test_api_connection
# pylint: disable=wrong-import-position
-root_dir = os.path.dirname(os.path.abspath(__file__))
+root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(root_dir)
from graphgen.graphgen import GraphGen
@@ -22,7 +21,6 @@ from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy
from graphgen.models.llm.limitter import RPM, TPM
from graphgen.utils import set_logger
-
css = """
.center-row {
display: flex;
@@ -37,9 +35,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
set_logger(log_file, if_stream=False)
- graph_gen = GraphGen(
- working_dir=working_dir
- )
+ graph_gen = GraphGen(working_dir=working_dir)
# Set up LLM clients
graph_gen.synthesizer_llm_client = OpenAIModel(
@@ -47,8 +43,8 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
base_url=env.get("SYNTHESIZER_BASE_URL", ""),
api_key=env.get("SYNTHESIZER_API_KEY", ""),
request_limit=True,
- rpm= RPM(env.get("RPM", 1000)),
- tpm= TPM(env.get("TPM", 50000)),
+ rpm=RPM(env.get("RPM", 1000)),
+ tpm=TPM(env.get("TPM", 50000)),
)
graph_gen.trainee_llm_client = OpenAIModel(
@@ -56,16 +52,15 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
base_url=env.get("TRAINEE_BASE_URL", ""),
api_key=env.get("TRAINEE_API_KEY", ""),
request_limit=True,
- rpm= RPM(env.get("RPM", 1000)),
- tpm= TPM(env.get("TPM", 50000)),
+ rpm=RPM(env.get("RPM", 1000)),
+ tpm=TPM(env.get("TPM", 50000)),
)
- graph_gen.tokenizer_instance = Tokenizer(
- config.get("tokenizer", "cl100k_base"))
+ graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
strategy_config = config.get("traverse_strategy", {})
graph_gen.traverse_strategy = TraverseStrategy(
- qa_form=config.get("qa_form"),
+ qa_form=strategy_config.get("qa_form"),
expand_method=strategy_config.get("expand_method"),
bidirectional=strategy_config.get("bidirectional"),
max_extra_edges=strategy_config.get("max_extra_edges"),
@@ -73,11 +68,12 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
max_depth=strategy_config.get("max_depth"),
edge_sampling=strategy_config.get("edge_sampling"),
isolated_node_strategy=strategy_config.get("isolated_node_strategy"),
- loss_strategy=str(strategy_config.get("loss_strategy"))
+ loss_strategy=str(strategy_config.get("loss_strategy")),
)
return graph_gen
+
# pylint: disable=too-many-statements
def run_graphgen(params, progress=gr.Progress()):
def sum_tokens(client):
@@ -87,10 +83,9 @@ def run_graphgen(params, progress=gr.Progress()):
"if_trainee_model": params.if_trainee_model,
"input_file": params.input_file,
"tokenizer": params.tokenizer,
- "qa_form": params.qa_form,
- "web_search": False,
"quiz_samples": params.quiz_samples,
"traverse_strategy": {
+ "qa_form": params.qa_form,
"bidirectional": params.bidirectional,
"expand_method": params.expand_method,
"max_extra_edges": params.max_extra_edges,
@@ -98,7 +93,7 @@ def run_graphgen(params, progress=gr.Progress()):
"max_depth": params.max_depth,
"edge_sampling": params.edge_sampling,
"isolated_node_strategy": params.isolated_node_strategy,
- "loss_strategy": params.loss_strategy
+ "loss_strategy": params.loss_strategy,
},
"chunk_size": params.chunk_size,
}
@@ -115,11 +110,15 @@ def run_graphgen(params, progress=gr.Progress()):
}
# Test API connection
- test_api_connection(env["SYNTHESIZER_BASE_URL"],
- env["SYNTHESIZER_API_KEY"], env["SYNTHESIZER_MODEL"])
- if config['if_trainee_model']:
- test_api_connection(env["TRAINEE_BASE_URL"],
- env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"])
+ test_api_connection(
+ env["SYNTHESIZER_BASE_URL"],
+ env["SYNTHESIZER_API_KEY"],
+ env["SYNTHESIZER_MODEL"],
+ )
+ if config["if_trainee_model"]:
+ test_api_connection(
+ env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]
+ )
# Initialize GraphGen
graph_gen = init_graph_gen(config, env)
@@ -129,7 +128,7 @@ def run_graphgen(params, progress=gr.Progress()):
try:
# Load input data
- file = config['input_file']
+ file = config["input_file"]
if isinstance(file, list):
file = file[0]
@@ -137,24 +136,22 @@ def run_graphgen(params, progress=gr.Progress()):
if file.endswith(".jsonl"):
data_type = "raw"
- with open(file, "r", encoding='utf-8') as f:
+ with open(file, "r", encoding="utf-8") as f:
data.extend(json.loads(line) for line in f)
elif file.endswith(".json"):
data_type = "chunked"
- with open(file, "r", encoding='utf-8') as f:
+ with open(file, "r", encoding="utf-8") as f:
data.extend(json.load(f))
elif file.endswith(".txt"):
# 读取文件后根据chunk_size转成raw格式的数据
data_type = "raw"
content = ""
- with open(file, "r", encoding='utf-8') as f:
+ with open(file, "r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
content += line.strip() + " "
size = int(config.get("chunk_size", 512))
- chunks = [
- content[i:i + size] for i in range(0, len(content), size)
- ]
+ chunks = [content[i : i + size] for i in range(0, len(content), size)]
data.extend([{"content": chunk} for chunk in chunks])
else:
raise ValueError(f"Unsupported file type: {file}")
@@ -162,9 +159,9 @@ def run_graphgen(params, progress=gr.Progress()):
# Process the data
graph_gen.insert(data, data_type)
- if config['if_trainee_model']:
+ if config["if_trainee_model"]:
# Generate quiz
- graph_gen.quiz(max_samples=config['quiz_samples'])
+ graph_gen.quiz(max_samples=config["quiz_samples"])
# Judge statements
graph_gen.judge()
@@ -174,47 +171,44 @@ def run_graphgen(params, progress=gr.Progress()):
graph_gen.judge(skip=True)
# Traverse graph
- graph_gen.traverse()
+ graph_gen.traverse(traverse_strategy=graph_gen.traverse_strategy)
# Save output
output_data = graph_gen.qa_storage.data
with tempfile.NamedTemporaryFile(
- mode="w",
- suffix=".jsonl",
- delete=False,
- encoding="utf-8") as tmpfile:
+ mode="w", suffix=".jsonl", delete=False, encoding="utf-8"
+ ) as tmpfile:
json.dump(output_data, tmpfile, ensure_ascii=False)
output_file = tmpfile.name
synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
- trainee_tokens = sum_tokens(graph_gen.trainee_llm_client) if config['if_trainee_model'] else 0
+ trainee_tokens = (
+ sum_tokens(graph_gen.trainee_llm_client)
+ if config["if_trainee_model"]
+ else 0
+ )
total_tokens = synthesizer_tokens + trainee_tokens
data_frame = params.token_counter
try:
_update_data = [
- [
- data_frame.iloc[0, 0],
- data_frame.iloc[0, 1],
- str(total_tokens)
- ]
+ [data_frame.iloc[0, 0], data_frame.iloc[0, 1], str(total_tokens)]
]
- new_df = pd.DataFrame(
- _update_data,
- columns=data_frame.columns
- )
+ new_df = pd.DataFrame(_update_data, columns=data_frame.columns)
data_frame = new_df
except Exception as e:
raise gr.Error(f"DataFrame operation error: {str(e)}")
- return output_file, gr.DataFrame(label='Token Stats',
- headers=["Source Text Token Count", "Expected Token Usage", "Token Used"],
- datatype="str",
- interactive=False,
- value=data_frame,
- visible=True,
- wrap=True)
+ return output_file, gr.DataFrame(
+ label="Token Stats",
+ headers=["Source Text Token Count", "Expected Token Usage", "Token Used"],
+ datatype="str",
+ interactive=False,
+ value=data_frame,
+ visible=True,
+ wrap=True,
+ )
except Exception as e: # pylint: disable=broad-except
raise gr.Error(f"Error occurred: {str(e)}")
@@ -223,16 +217,18 @@ def run_graphgen(params, progress=gr.Progress()):
# Clean up workspace
cleanup_workspace(graph_gen.working_dir)
-with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
- css=css) as demo):
+
+with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
# Header
- gr.Image(value="https://github.com/open-sciencelab/GraphGen/blob/main/resources/images/logo.png?raw=true",
- label="GraphGen Banner",
- elem_id="banner",
- interactive=False,
- container=False,
- show_download_button=False,
- show_fullscreen_button=False)
+ gr.Image(
+ value=os.path.join(root_dir, "resources", "images", "logo.png"),
+ label="GraphGen Banner",
+ elem_id="banner",
+ interactive=False,
+ container=False,
+ show_download_button=False,
+ show_fullscreen_button=False,
+ )
lang_btn = gr.Radio(
choices=[
("English", "en"),
@@ -245,7 +241,8 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
elem_classes=["center-row"],
)
- gr.HTML("""
+ gr.HTML(
+ """
- """)
+ """
+ )
with Translate(
- os.path.join(root_dir, 'webui', 'translation.json'),
- lang_btn,
- placeholder_langs=["en", "zh"],
- persistant=
- False, # True to save the language setting in the browser. Requires gradio >= 5.6.0
+ os.path.join(root_dir, "webui", "translation.json"),
+ lang_btn,
+ placeholder_langs=["en", "zh"],
+ persistant=False, # True to save the language setting in the browser. Requires gradio >= 5.6.0
):
lang_btn.render()
gr.Markdown(
- value = "# " + _("Title") + "\n\n" + \
- "### [GraphGen](https://github.com/open-sciencelab/GraphGen) " + _("Intro")
+ value="# "
+ + _("Title")
+ + "\n\n"
+ + "### [GraphGen](https://github.com/open-sciencelab/GraphGen) "
+ + _("Intro")
)
- if_trainee_model = gr.Checkbox(label=_("Use Trainee Model"),
- value=False,
- interactive=True)
+ if_trainee_model = gr.Checkbox(
+ label=_("Use Trainee Model"), value=False, interactive=True
+ )
with gr.Accordion(label=_("Model Config"), open=False):
- synthesizer_url = gr.Textbox(label="Synthesizer URL",
- value="https://api.siliconflow.cn/v1",
- info=_("Synthesizer URL Info"),
- interactive=True)
- synthesizer_model = gr.Textbox(label="Synthesizer Model",
- value="Qwen/Qwen2.5-7B-Instruct",
- info=_("Synthesizer Model Info"),
- interactive=True)
- trainee_url = gr.Textbox(label="Trainee URL",
- value="https://api.siliconflow.cn/v1",
- info=_("Trainee URL Info"),
- interactive=True,
- visible=if_trainee_model.value is True)
+ synthesizer_url = gr.Textbox(
+ label="Synthesizer URL",
+ value="https://api.siliconflow.cn/v1",
+ info=_("Synthesizer URL Info"),
+ interactive=True,
+ )
+ synthesizer_model = gr.Textbox(
+ label="Synthesizer Model",
+ value="Qwen/Qwen2.5-7B-Instruct",
+ info=_("Synthesizer Model Info"),
+ interactive=True,
+ )
+ trainee_url = gr.Textbox(
+ label="Trainee URL",
+ value="https://api.siliconflow.cn/v1",
+ info=_("Trainee URL Info"),
+ interactive=True,
+ visible=if_trainee_model.value is True,
+ )
trainee_model = gr.Textbox(
label="Trainee Model",
value="Qwen/Qwen2.5-7B-Instruct",
info=_("Trainee Model Info"),
interactive=True,
- visible=if_trainee_model.value is True)
+ visible=if_trainee_model.value is True,
+ )
trainee_api_key = gr.Textbox(
- label=_("SiliconCloud Token for Trainee Model"),
- type="password",
- value="",
- info="https://cloud.siliconflow.cn/account/ak",
- visible=if_trainee_model.value is True)
-
+ label=_("SiliconFlow Token for Trainee Model"),
+ type="password",
+ value="",
+ info="https://cloud.siliconflow.cn/account/ak",
+ visible=if_trainee_model.value is True,
+ )
with gr.Accordion(label=_("Generation Config"), open=False):
- chunk_size = gr.Slider(label="Chunk Size",
- minimum=256,
- maximum=4096,
- value=512,
- step=256,
- interactive=True)
- tokenizer = gr.Textbox(label="Tokenizer",
- value="cl100k_base",
- interactive=True)
- qa_form = gr.Radio(choices=["atomic", "multi_hop", "aggregated"],
- label="QA Form",
- value="aggregated",
- interactive=True)
- quiz_samples = gr.Number(label="Quiz Samples",
- value=2,
- minimum=1,
- interactive=True,
- visible=if_trainee_model.value is True)
- bidirectional = gr.Checkbox(label="Bidirectional",
- value=True,
- interactive=True)
-
- expand_method = gr.Radio(choices=["max_width", "max_tokens"],
- label="Expand Method",
- value="max_tokens",
- interactive=True)
+ chunk_size = gr.Slider(
+ label="Chunk Size",
+ minimum=256,
+ maximum=4096,
+ value=512,
+ step=256,
+ interactive=True,
+ )
+ tokenizer = gr.Textbox(
+ label="Tokenizer", value="cl100k_base", interactive=True
+ )
+ qa_form = gr.Radio(
+ choices=["atomic", "multi_hop", "aggregated"],
+ label="QA Form",
+ value="aggregated",
+ interactive=True,
+ )
+ quiz_samples = gr.Number(
+ label="Quiz Samples",
+ value=2,
+ minimum=1,
+ interactive=True,
+ visible=if_trainee_model.value is True,
+ )
+ bidirectional = gr.Checkbox(
+ label="Bidirectional", value=True, interactive=True
+ )
+
+ expand_method = gr.Radio(
+ choices=["max_width", "max_tokens"],
+ label="Expand Method",
+ value="max_tokens",
+ interactive=True,
+ )
max_extra_edges = gr.Slider(
minimum=1,
maximum=10,
@@ -341,44 +356,54 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
label="Max Extra Edges",
step=1,
interactive=True,
- visible=expand_method.value == "max_width")
- max_tokens = gr.Slider(minimum=64,
- maximum=1024,
- value=256,
- label="Max Tokens",
- step=64,
- interactive=True,
- visible=(expand_method.value
- != "max_width"))
-
- max_depth = gr.Slider(minimum=1,
- maximum=5,
- value=2,
- label="Max Depth",
- step=1,
- interactive=True)
+ visible=expand_method.value == "max_width",
+ )
+ max_tokens = gr.Slider(
+ minimum=64,
+ maximum=1024,
+ value=256,
+ label="Max Tokens",
+ step=64,
+ interactive=True,
+ visible=(expand_method.value != "max_width"),
+ )
+
+ max_depth = gr.Slider(
+ minimum=1,
+ maximum=5,
+ value=2,
+ label="Max Depth",
+ step=1,
+ interactive=True,
+ )
edge_sampling = gr.Radio(
choices=["max_loss", "min_loss", "random"],
label="Edge Sampling",
value="max_loss",
interactive=True,
- visible=if_trainee_model.value is True)
- isolated_node_strategy = gr.Radio(choices=["add", "ignore"],
- label="Isolated Node Strategy",
- value="ignore",
- interactive=True)
- loss_strategy = gr.Radio(choices=["only_edge", "both"],
- label="Loss Strategy",
- value="only_edge",
- interactive=True)
+ visible=if_trainee_model.value is True,
+ )
+ isolated_node_strategy = gr.Radio(
+ choices=["add", "ignore"],
+ label="Isolated Node Strategy",
+ value="ignore",
+ interactive=True,
+ )
+ loss_strategy = gr.Radio(
+ choices=["only_edge", "both"],
+ label="Loss Strategy",
+ value="only_edge",
+ interactive=True,
+ )
with gr.Row(equal_height=True):
with gr.Column(scale=3):
api_key = gr.Textbox(
- label=_("SiliconCloud Token"),
+ label=_("SiliconFlow Token"),
type="password",
value="",
- info="https://cloud.siliconflow.cn/account/ak")
+ info="https://cloud.siliconflow.cn/account/ak",
+ )
with gr.Column(scale=1):
test_connection_btn = gr.Button(_("Test Connection"))
@@ -392,7 +417,8 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
value=1000,
step=100,
interactive=True,
- visible=True)
+ visible=True,
+ )
with gr.Column():
tpm = gr.Slider(
label="TPM",
@@ -401,8 +427,8 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
value=50000,
step=1000,
interactive=True,
- visible=True)
-
+ visible=True,
+ )
with gr.Blocks():
with gr.Row(equal_height=True):
@@ -413,15 +439,17 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
file_types=[".txt", ".json", ".jsonl"],
interactive=True,
)
- examples_dir = os.path.join(root_dir, 'webui', 'examples')
- gr.Examples(examples=[
- [os.path.join(examples_dir, "txt_demo.txt")],
- [os.path.join(examples_dir, "raw_demo.jsonl")],
- [os.path.join(examples_dir, "chunked_demo.json")],
- ],
- inputs=upload_file,
- label=_("Example Files"),
- examples_per_page=3)
+ examples_dir = os.path.join(root_dir, "webui", "examples")
+ gr.Examples(
+ examples=[
+ [os.path.join(examples_dir, "txt_demo.txt")],
+ [os.path.join(examples_dir, "raw_demo.jsonl")],
+ [os.path.join(examples_dir, "chunked_demo.json")],
+ ],
+ inputs=upload_file,
+ label=_("Example Files"),
+ examples_per_page=3,
+ )
with gr.Column(scale=1):
output = gr.File(
label="Output(See Github FAQ)",
@@ -430,12 +458,18 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
)
with gr.Blocks():
- token_counter = gr.DataFrame(label='Token Stats',
- headers=["Source Text Token Count", "Estimated Token Usage", "Token Used"],
- datatype="str",
- interactive=False,
- visible=False,
- wrap=True)
+ token_counter = gr.DataFrame(
+ label="Token Stats",
+ headers=[
+ "Source Text Token Count",
+ "Estimated Token Usage",
+ "Token Used",
+ ],
+ datatype="str",
+ interactive=False,
+ visible=False,
+ wrap=True,
+ )
submit_btn = gr.Button(_("Run GraphGen"))
@@ -443,23 +477,36 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
test_connection_btn.click(
test_api_connection,
inputs=[synthesizer_url, api_key, synthesizer_model],
- outputs=[])
+ outputs=[],
+ )
if if_trainee_model.value:
- test_connection_btn.click(test_api_connection,
- inputs=[trainee_url, api_key, trainee_model],
- outputs=[])
+ test_connection_btn.click(
+ test_api_connection,
+ inputs=[trainee_url, api_key, trainee_model],
+ outputs=[],
+ )
- expand_method.change(lambda method:
- (gr.update(visible=method == "max_width"),
- gr.update(visible=method != "max_width")),
- inputs=expand_method,
- outputs=[max_extra_edges, max_tokens])
+ expand_method.change(
+ lambda method: (
+ gr.update(visible=method == "max_width"),
+ gr.update(visible=method != "max_width"),
+ ),
+ inputs=expand_method,
+ outputs=[max_extra_edges, max_tokens],
+ )
if_trainee_model.change(
lambda use_trainee: [gr.update(visible=use_trainee)] * 5,
inputs=if_trainee_model,
- outputs=[trainee_url, trainee_model, quiz_samples, edge_sampling, trainee_api_key])
+ outputs=[
+ trainee_url,
+ trainee_model,
+ quiz_samples,
+ edge_sampling,
+ trainee_api_key,
+ ],
+ )
upload_file.change(
lambda x: (gr.update(visible=True)),
@@ -479,41 +526,61 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
)
submit_btn.click(
- lambda *args: run_graphgen(GraphGenParams(
- if_trainee_model=args[0],
- input_file=args[1],
- tokenizer=args[2],
- qa_form=args[3],
- bidirectional=args[4],
- expand_method=args[5],
- max_extra_edges=args[6],
- max_tokens=args[7],
- max_depth=args[8],
- edge_sampling=args[9],
- isolated_node_strategy=args[10],
- loss_strategy=args[11],
- synthesizer_url=args[12],
- synthesizer_model=args[13],
- trainee_model=args[14],
- api_key=args[15],
- chunk_size=args[16],
- rpm=args[17],
- tpm=args[18],
- quiz_samples=args[19],
- trainee_url=args[20],
- trainee_api_key=args[21],
- token_counter=args[22],
- )),
+ lambda *args: run_graphgen(
+ GraphGenParams(
+ if_trainee_model=args[0],
+ input_file=args[1],
+ tokenizer=args[2],
+ qa_form=args[3],
+ bidirectional=args[4],
+ expand_method=args[5],
+ max_extra_edges=args[6],
+ max_tokens=args[7],
+ max_depth=args[8],
+ edge_sampling=args[9],
+ isolated_node_strategy=args[10],
+ loss_strategy=args[11],
+ synthesizer_url=args[12],
+ synthesizer_model=args[13],
+ trainee_model=args[14],
+ api_key=args[15],
+ chunk_size=args[16],
+ rpm=args[17],
+ tpm=args[18],
+ quiz_samples=args[19],
+ trainee_url=args[20],
+ trainee_api_key=args[21],
+ token_counter=args[22],
+ )
+ ),
inputs=[
- if_trainee_model, upload_file, tokenizer, qa_form,
- bidirectional, expand_method, max_extra_edges, max_tokens,
- max_depth, edge_sampling, isolated_node_strategy,
- loss_strategy, synthesizer_url, synthesizer_model, trainee_model,
- api_key, chunk_size, rpm, tpm, quiz_samples, trainee_url, trainee_api_key, token_counter
+ if_trainee_model,
+ upload_file,
+ tokenizer,
+ qa_form,
+ bidirectional,
+ expand_method,
+ max_extra_edges,
+ max_tokens,
+ max_depth,
+ edge_sampling,
+ isolated_node_strategy,
+ loss_strategy,
+ synthesizer_url,
+ synthesizer_model,
+ trainee_model,
+ api_key,
+ chunk_size,
+ rpm,
+ tpm,
+ quiz_samples,
+ trainee_url,
+ trainee_api_key,
+ token_counter,
],
outputs=[output, token_counter],
)
if __name__ == "__main__":
demo.queue(api_open=False, default_concurrency_limit=2)
- demo.launch(server_name='0.0.0.0')
+ demo.launch(server_name="0.0.0.0")
diff --git a/graphgen/configs/README.md b/graphgen/configs/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..afa815cd876cbba3db687d09f210288a24efc2b6
--- /dev/null
+++ b/graphgen/configs/README.md
@@ -0,0 +1 @@
+# Configs for GraphGen
diff --git a/graphgen/configs/aggregated_config.yaml b/graphgen/configs/aggregated_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e13a660653926416473660cb9ca1d9dd6bf43c4b
--- /dev/null
+++ b/graphgen/configs/aggregated_config.yaml
@@ -0,0 +1,21 @@
+input_data_type: raw # raw, chunked
+input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
+output_data_type: aggregated # atomic, aggregated, multi_hop, cot
+output_data_format: ChatML # Alpaca, Sharegpt, ChatML
+tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
+search: # web search configuration
+ enabled: false # whether to enable web search
+ search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
+quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
+ enabled: true
+ quiz_samples: 2 # number of quiz samples to generate
+ re_judge: false # whether to re-judge the existing quiz samples
+traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
+ bidirectional: true # whether to traverse the graph in both directions
+ edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
+ expand_method: max_width # expand method, support: max_width, max_tokens
+ isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
+ max_depth: 5 # maximum depth for graph traversal
+ max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
+ max_tokens: 256 # restricts input length (if expand_method="max_tokens")
+ loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
diff --git a/graphgen/configs/atomic_config.yaml b/graphgen/configs/atomic_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8e2c081f2b8e012df5d6029d2f742fd5335010bd
--- /dev/null
+++ b/graphgen/configs/atomic_config.yaml
@@ -0,0 +1,21 @@
+input_data_type: raw # raw, chunked
+input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
+output_data_type: atomic # atomic, aggregated, multi_hop, cot
+output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
+tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
+search: # web search configuration
+ enabled: false # whether to enable web search
+ search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
+quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
+ enabled: true
+ quiz_samples: 2 # number of quiz samples to generate
+ re_judge: false # whether to re-judge the existing quiz samples
+traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
+ bidirectional: true # whether to traverse the graph in both directions
+ edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
+ expand_method: max_width # expand method, support: max_width, max_tokens
+ isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
+ max_depth: 3 # maximum depth for graph traversal
+ max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
+ max_tokens: 256 # restricts input length (if expand_method="max_tokens")
+ loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
diff --git a/graphgen/configs/config.yaml.example b/graphgen/configs/config.yaml.example
deleted file mode 100644
index eeb804af544f01167941f3d8f240131e3fc25312..0000000000000000000000000000000000000000
--- a/graphgen/configs/config.yaml.example
+++ /dev/null
@@ -1,16 +0,0 @@
-data_type: raw
-input_file: resources/examples/raw_demo.jsonl
-tokenizer: cl100k_base
-quiz_samples: 2
-traverse_strategy:
- qa_form: atomic
- bidirectional: true
- edge_sampling: max_loss
- expand_method: max_tokens
- isolated_node_strategy: add
- max_depth: 2
- max_extra_edges: 5
- max_tokens: 256
- loss_strategy: only_edge
-web_search: false
-re_judge: false
diff --git a/graphgen/configs/cot_config.yaml b/graphgen/configs/cot_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1073e97dd2566050a73565f202a33c9594269633
--- /dev/null
+++ b/graphgen/configs/cot_config.yaml
@@ -0,0 +1,13 @@
+input_data_type: raw # raw, chunked
+input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
+output_data_type: cot # atomic, aggregated, multi_hop, cot
+output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML
+tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
+search: # web search configuration
+ enabled: false # whether to enable web search
+ search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
+method_params:
+ method: leiden
+ max_size: 20 # Maximum size of communities
+ use_lcc: false
+ random_seed: 42
diff --git a/graphgen/configs/graphgen_config.yaml b/graphgen/configs/graphgen_config.yaml
deleted file mode 100644
index 4ddb66c7ac6db7c69356c1aa4a77a9423cd35b92..0000000000000000000000000000000000000000
--- a/graphgen/configs/graphgen_config.yaml
+++ /dev/null
@@ -1,16 +0,0 @@
-data_type: raw
-input_file: resources/examples/raw_demo.jsonl
-tokenizer: cl100k_base
-quiz_samples: 2
-traverse_strategy:
- qa_form: aggregated
- bidirectional: true
- edge_sampling: max_loss
- expand_method: max_width
- isolated_node_strategy: ignore
- max_depth: 1
- max_extra_edges: 2
- max_tokens: 256
- loss_strategy: only_edge
-web_search: false
-re_judge: false
diff --git a/graphgen/configs/multi_hop_config.yaml b/graphgen/configs/multi_hop_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..530edcd1308b5e051d12a6dffd2fb993031f834f
--- /dev/null
+++ b/graphgen/configs/multi_hop_config.yaml
@@ -0,0 +1,21 @@
+input_data_type: raw # raw, chunked
+input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
+output_data_type: multi_hop # atomic, aggregated, multi_hop, cot
+output_data_format: ChatML # Alpaca, Sharegpt, ChatML
+tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
+search: # web search configuration
+ enabled: false # whether to enable web search
+ search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
+quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
+ enabled: true
+ quiz_samples: 2 # number of quiz samples to generate
+ re_judge: false # whether to re-judge the existing quiz samples
+traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
+ bidirectional: true # whether to traverse the graph in both directions
+ edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
+ expand_method: max_width # expand method, support: max_width, max_tokens
+ isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
+ max_depth: 1 # maximum depth for graph traversal
+ max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
+ max_tokens: 256 # restricts input length (if expand_method="max_tokens")
+ loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
diff --git a/graphgen/generate.py b/graphgen/generate.py
index 14693471c4997346db428cb09ab9b83f57453e0d..eec168d6137cdc2ebaed8bab3ffef74adf70ddf3 100644
--- a/graphgen/generate.py
+++ b/graphgen/generate.py
@@ -1,101 +1,103 @@
+import argparse
import os
-import json
import time
-import argparse
from importlib.resources import files
+
import yaml
from dotenv import load_dotenv
from .graphgen import GraphGen
-from .models import OpenAIModel, Tokenizer, TraverseStrategy
-from .utils import set_logger
+from .utils import logger, set_logger
sys_path = os.path.abspath(os.path.dirname(__file__))
load_dotenv()
+
def set_working_dir(folder):
os.makedirs(folder, exist_ok=True)
os.makedirs(os.path.join(folder, "data", "graphgen"), exist_ok=True)
os.makedirs(os.path.join(folder, "logs"), exist_ok=True)
+
def save_config(config_path, global_config):
if not os.path.exists(os.path.dirname(config_path)):
os.makedirs(os.path.dirname(config_path))
- with open(config_path, "w", encoding='utf-8') as config_file:
- yaml.dump(global_config, config_file, default_flow_style=False, allow_unicode=True)
+ with open(config_path, "w", encoding="utf-8") as config_file:
+ yaml.dump(
+ global_config, config_file, default_flow_style=False, allow_unicode=True
+ )
+
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('--config_file',
- help='Config parameters for GraphGen.',
- # default=os.path.join(sys_path, "configs", "graphgen_config.yaml"),
- default=files('graphgen').joinpath("configs", "graphgen_config.yaml"),
- type=str)
- parser.add_argument('--output_dir',
- help='Output directory for GraphGen.',
- default=sys_path,
- required=True,
- type=str)
+ parser.add_argument(
+ "--config_file",
+ help="Config parameters for GraphGen.",
+ default=files("graphgen").joinpath("configs", "aggregated_config.yaml"),
+ type=str,
+ )
+ parser.add_argument(
+ "--output_dir",
+ help="Output directory for GraphGen.",
+ default=sys_path,
+ required=True,
+ type=str,
+ )
args = parser.parse_args()
working_dir = args.output_dir
set_working_dir(working_dir)
- unique_id = int(time.time())
- set_logger(os.path.join(working_dir, "logs", f"graphgen_{unique_id}.log"), if_stream=False)
- with open(args.config_file, "r", encoding='utf-8') as f:
+ with open(args.config_file, "r", encoding="utf-8") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
- input_file = config['input_file']
-
- if config['data_type'] == 'raw':
- with open(input_file, "r", encoding='utf-8') as f:
- data = [json.loads(line) for line in f]
- elif config['data_type'] == 'chunked':
- with open(input_file, "r", encoding='utf-8') as f:
- data = json.load(f)
- else:
- raise ValueError(f"Invalid data type: {config['data_type']}")
-
- synthesizer_llm_client = OpenAIModel(
- model_name=os.getenv("SYNTHESIZER_MODEL"),
- api_key=os.getenv("SYNTHESIZER_API_KEY"),
- base_url=os.getenv("SYNTHESIZER_BASE_URL")
- )
- trainee_llm_client = OpenAIModel(
- model_name=os.getenv("TRAINEE_MODEL"),
- api_key=os.getenv("TRAINEE_API_KEY"),
- base_url=os.getenv("TRAINEE_BASE_URL")
- )
-
- traverse_strategy = TraverseStrategy(
- **config['traverse_strategy']
+ output_data_type = config["output_data_type"]
+ unique_id = int(time.time())
+ set_logger(
+ os.path.join(
+ working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log"
+ ),
+ if_stream=True,
)
-
- graph_gen = GraphGen(
- working_dir=working_dir,
- unique_id=unique_id,
- synthesizer_llm_client=synthesizer_llm_client,
- trainee_llm_client=trainee_llm_client,
- if_web_search=config['web_search'],
- tokenizer_instance=Tokenizer(
- model_name=config['tokenizer']
+ logger.info(
+ "GraphGen with unique ID %s logging to %s",
+ unique_id,
+ os.path.join(
+ working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log"
),
- traverse_strategy=traverse_strategy
)
- graph_gen.insert(data, config['data_type'])
-
- graph_gen.quiz(max_samples=config['quiz_samples'])
-
- graph_gen.judge(re_judge=config["re_judge"])
+ graph_gen = GraphGen(working_dir=working_dir, unique_id=unique_id, config=config)
+
+ graph_gen.insert()
+
+ if config["search"]["enabled"]:
+ graph_gen.search()
+
+ # Use pipeline according to the output data type
+ if output_data_type in ["atomic", "aggregated", "multi_hop"]:
+ if "quiz_and_judge_strategy" in config and config[
+ "quiz_and_judge_strategy"
+ ].get("enabled", False):
+ graph_gen.quiz()
+ graph_gen.judge()
+ else:
+ logger.warning(
+ "Quiz and Judge strategy is disabled. Edge sampling falls back to random."
+ )
+ graph_gen.traverse_strategy.edge_sampling = "random"
+ graph_gen.traverse()
+ elif output_data_type == "cot":
+ graph_gen.generate_reasoning(method_params=config["method_params"])
+ else:
+ raise ValueError(f"Unsupported output data type: {output_data_type}")
- graph_gen.traverse()
+ output_path = os.path.join(working_dir, "data", "graphgen", str(unique_id))
+ save_config(os.path.join(output_path, f"config-{unique_id}.yaml"), config)
+ logger.info("GraphGen completed successfully. Data saved to %s", output_path)
- path = os.path.join(working_dir, "data", "graphgen", str(unique_id), f"config-{unique_id}.yaml")
- save_config(path, config)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py
index 265d32a998592c5b1ad1af60cdb689801b621bb0..7b7b302ac5c2034b631689c3313389aa32160572 100644
--- a/graphgen/graphgen.py
+++ b/graphgen/graphgen.py
@@ -1,10 +1,8 @@
-# Adapt from https://github.com/HKUDS/LightRAG
-
import asyncio
import os
import time
from dataclasses import dataclass, field
-from typing import List, Union, cast
+from typing import Dict, List, Union, cast
import gradio as gr
from tqdm.asyncio import tqdm as tqdm_async
@@ -12,85 +10,124 @@ from tqdm.asyncio import tqdm as tqdm_async
from .models import (
Chunk,
JsonKVStorage,
+ JsonListStorage,
NetworkXStorage,
OpenAIModel,
Tokenizer,
TraverseStrategy,
- WikiSearch,
)
from .models.storage.base_storage import StorageNameSpace
from .operators import (
extract_kg,
+ generate_cot,
judge_statement,
quiz,
- search_wikipedia,
- skip_judge_statement,
+ search_all,
traverse_graph_atomically,
traverse_graph_by_edge,
traverse_graph_for_multi_hop,
)
-from .utils import compute_content_hash, create_event_loop, logger
+from .utils import (
+ compute_content_hash,
+ create_event_loop,
+ format_generation_results,
+ logger,
+ read_file,
+)
sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+
@dataclass
class GraphGen:
unique_id: int = int(time.time())
working_dir: str = os.path.join(sys_path, "cache")
-
- # text chunking
- chunk_size: int = 1024
- chunk_overlap_size: int = 100
+ config: Dict = field(default_factory=dict)
# llm
+ tokenizer_instance: Tokenizer = None
synthesizer_llm_client: OpenAIModel = None
trainee_llm_client: OpenAIModel = None
- tokenizer_instance: Tokenizer = None
- # web search
- if_web_search: bool = False
- wiki_client: WikiSearch = field(default_factory=WikiSearch)
+ # text chunking
+ # TODO: make it configurable
+ chunk_size: int = 1024
+ chunk_overlap_size: int = 100
+
+ # search
+ search_config: dict = field(
+ default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}
+ )
- # traverse strategy
- traverse_strategy: TraverseStrategy = field(default_factory=TraverseStrategy)
+ # traversal
+ traverse_strategy: TraverseStrategy = None
# webui
progress_bar: gr.Progress = None
def __post_init__(self):
+ self.tokenizer_instance: Tokenizer = Tokenizer(
+ model_name=self.config["tokenizer"]
+ )
+ self.synthesizer_llm_client: OpenAIModel = OpenAIModel(
+ model_name=os.getenv("SYNTHESIZER_MODEL"),
+ api_key=os.getenv("SYNTHESIZER_API_KEY"),
+ base_url=os.getenv("SYNTHESIZER_BASE_URL"),
+ tokenizer_instance=self.tokenizer_instance,
+ )
+ self.trainee_llm_client: OpenAIModel = OpenAIModel(
+ model_name=os.getenv("TRAINEE_MODEL"),
+ api_key=os.getenv("TRAINEE_API_KEY"),
+ base_url=os.getenv("TRAINEE_BASE_URL"),
+ tokenizer_instance=self.tokenizer_instance,
+ )
+ self.search_config = self.config["search"]
+
+ if "traverse_strategy" in self.config:
+ self.traverse_strategy = TraverseStrategy(
+ **self.config["traverse_strategy"]
+ )
+
self.full_docs_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="full_docs"
)
self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="text_chunks"
)
- self.wiki_storage: JsonKVStorage = JsonKVStorage(
- self.working_dir, namespace="wiki"
- )
self.graph_storage: NetworkXStorage = NetworkXStorage(
self.working_dir, namespace="graph"
)
+ self.search_storage: JsonKVStorage = JsonKVStorage(
+ self.working_dir, namespace="search"
+ )
self.rephrase_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="rephrase"
)
- self.qa_storage: JsonKVStorage = JsonKVStorage(
- os.path.join(self.working_dir, "data", "graphgen", str(self.unique_id)), namespace=f"qa-{self.unique_id}"
+ self.qa_storage: JsonListStorage = JsonListStorage(
+ os.path.join(self.working_dir, "data", "graphgen", str(self.unique_id)),
+ namespace=f"qa-{self.unique_id}",
)
- async def async_split_chunks(self, data: Union[List[list], List[dict]], data_type: str) -> dict:
- # TODO: 是否进行指代消解
+ async def async_split_chunks(
+ self, data: List[Union[List, Dict]], data_type: str
+ ) -> dict:
+ # TODO: configurable whether to use coreference resolution
if len(data) == 0:
return {}
- new_docs = {}
inserting_chunks = {}
if data_type == "raw":
assert isinstance(data, list) and isinstance(data[0], dict)
# compute hash for each document
new_docs = {
- compute_content_hash(doc['content'], prefix="doc-"): {'content': doc['content']} for doc in data
+ compute_content_hash(doc["content"], prefix="doc-"): {
+ "content": doc["content"]
+ }
+ for doc in data
}
- _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
+ _add_doc_keys = await self.full_docs_storage.filter_keys(
+ list(new_docs.keys())
+ )
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if len(new_docs) == 0:
logger.warning("All docs are already in the storage")
@@ -100,63 +137,83 @@ class GraphGen:
cur_index = 1
doc_number = len(new_docs)
async for doc_key, doc in tqdm_async(
- new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
- ):
+ new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
+ ):
chunks = {
compute_content_hash(dp["content"], prefix="chunk-"): {
**dp,
- 'full_doc_id': doc_key
- } for dp in self.tokenizer_instance.chunk_by_token_size(doc["content"],
- self.chunk_overlap_size, self.chunk_size)
+ "full_doc_id": doc_key,
+ }
+ for dp in self.tokenizer_instance.chunk_by_token_size(
+ doc["content"], self.chunk_overlap_size, self.chunk_size
+ )
}
inserting_chunks.update(chunks)
if self.progress_bar is not None:
- self.progress_bar(
- cur_index / doc_number, f"Chunking {doc_key}"
- )
+ self.progress_bar(cur_index / doc_number, f"Chunking {doc_key}")
cur_index += 1
- _add_chunk_keys = await self.text_chunks_storage.filter_keys(list(inserting_chunks.keys()))
- inserting_chunks = {k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys}
+ _add_chunk_keys = await self.text_chunks_storage.filter_keys(
+ list(inserting_chunks.keys())
+ )
+ inserting_chunks = {
+ k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
+ }
elif data_type == "chunked":
assert isinstance(data, list) and isinstance(data[0], list)
new_docs = {
- compute_content_hash("".join(chunk['content']), prefix="doc-"): {'content': "".join(chunk['content'])}
- for doc in data for chunk in doc
+ compute_content_hash("".join(chunk["content"]), prefix="doc-"): {
+ "content": "".join(chunk["content"])
+ }
+ for doc in data
+ for chunk in doc
}
- _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
+ _add_doc_keys = await self.full_docs_storage.filter_keys(
+ list(new_docs.keys())
+ )
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if len(new_docs) == 0:
logger.warning("All docs are already in the storage")
return {}
logger.info("[New Docs] inserting %d docs", len(new_docs))
- async for doc in tqdm_async(data, desc="[1/4]Chunking documents", unit="doc"):
- doc_str = "".join([chunk['content'] for chunk in doc])
+ async for doc in tqdm_async(
+ data, desc="[1/4]Chunking documents", unit="doc"
+ ):
+ doc_str = "".join([chunk["content"] for chunk in doc])
for chunk in doc:
- chunk_key = compute_content_hash(chunk['content'], prefix="chunk-")
+ chunk_key = compute_content_hash(chunk["content"], prefix="chunk-")
inserting_chunks[chunk_key] = {
**chunk,
- 'full_doc_id': compute_content_hash(doc_str, prefix="doc-")
+ "full_doc_id": compute_content_hash(doc_str, prefix="doc-"),
}
- _add_chunk_keys = await self.text_chunks_storage.filter_keys(list(inserting_chunks.keys()))
- inserting_chunks = {k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys}
+ _add_chunk_keys = await self.text_chunks_storage.filter_keys(
+ list(inserting_chunks.keys())
+ )
+ inserting_chunks = {
+ k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
+ }
+ else:
+ raise ValueError(f"Unknown data type: {data_type}")
await self.full_docs_storage.upsert(new_docs)
await self.text_chunks_storage.upsert(inserting_chunks)
return inserting_chunks
- def insert(self, data: Union[List[list], List[dict]], data_type: str):
+ def insert(self):
loop = create_event_loop()
- loop.run_until_complete(self.async_insert(data, data_type))
+ loop.run_until_complete(self.async_insert())
- async def async_insert(self, data: Union[List[list], List[dict]], data_type: str):
+ async def async_insert(self):
"""
-
insert chunks into the graph
"""
+ input_file = self.config["input_file"]
+ data_type = self.config["input_data_type"]
+ data = read_file(input_file)
+
inserting_chunks = await self.async_split_chunks(data, data_type)
if len(inserting_chunks) == 0:
@@ -169,52 +226,96 @@ class GraphGen:
llm_client=self.synthesizer_llm_client,
kg_instance=self.graph_storage,
tokenizer_instance=self.tokenizer_instance,
- chunks=[Chunk(id=k, content=v['content']) for k, v in inserting_chunks.items()],
- progress_bar = self.progress_bar,
+ chunks=[
+ Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
+ ],
+ progress_bar=self.progress_bar,
)
if not _add_entities_and_relations:
logger.warning("No entities or relations extracted")
return
- logger.info("[Wiki Search] is %s", 'enabled' if self.if_web_search else 'disabled')
- if self.if_web_search:
- logger.info("[Wiki Search]...")
- _add_wiki_data = await search_wikipedia(
- llm_client= self.synthesizer_llm_client,
- wiki_search_client=self.wiki_client,
- knowledge_graph_instance=_add_entities_and_relations
- )
- await self.wiki_storage.upsert(_add_wiki_data)
-
await self._insert_done()
async def _insert_done(self):
tasks = []
- for storage_instance in [self.full_docs_storage, self.text_chunks_storage,
- self.graph_storage, self.wiki_storage]:
+ for storage_instance in [
+ self.full_docs_storage,
+ self.text_chunks_storage,
+ self.graph_storage,
+ self.search_storage,
+ ]:
if storage_instance is None:
continue
tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
await asyncio.gather(*tasks)
- def quiz(self, max_samples=1):
+ def search(self):
loop = create_event_loop()
- loop.run_until_complete(self.async_quiz(max_samples))
+ loop.run_until_complete(self.async_search())
- async def async_quiz(self, max_samples=1):
- await quiz(self.synthesizer_llm_client, self.graph_storage, self.rephrase_storage, max_samples)
- await self.rephrase_storage.index_done_callback()
+    async def async_search(self):
+        """Search the configured backends for graph entities and store the results."""
+        enabled = self.search_config["enabled"]
+        logger.info("Search is %s", "enabled" if enabled else "disabled")
+        if not enabled:
+            return
+
+        search_types = self.search_config["search_types"]
+        logger.info("[Search] %s ...", ", ".join(search_types))
+        all_nodes = await self.graph_storage.get_all_nodes()
+        all_nodes_names = [node[0] for node in all_nodes]
+        # Results are upserted into search_storage below, so that is the store to
+        # de-duplicate against. full_docs_storage is keyed by "doc-" content
+        # hashes, so it would never filter out an entity name.
+        # NOTE(review): assumes search_all keys its result by entity - confirm.
+        new_search_entities = await self.search_storage.filter_keys(all_nodes_names)
+        logger.info("[Search] Found %d entities to search", len(new_search_entities))
+
+        _add_search_data = await search_all(
+            search_types=search_types, search_entities=new_search_entities
+        )
+        if not _add_search_data:
+            return
+        await self.search_storage.upsert(_add_search_data)
+        logger.info("[Search] %d entities searched", len(_add_search_data))
+
+        # Format search results for inserting
+        search_results = [
+            {"content": value}
+            for search_data in _add_search_data.values()
+            for value in search_data.values()
+        ]
+        # TODO: fix insert after search - search_results is not passed on yet;
+        # async_insert() re-reads self.config["input_file"] instead.
+        await self.async_insert()
- def judge(self, re_judge=False, skip=False):
+ def quiz(self):
loop = create_event_loop()
- loop.run_until_complete(self.async_judge(re_judge, skip))
+ loop.run_until_complete(self.async_quiz())
+
+ async def async_quiz(self):
+ max_samples = self.config["quiz_and_judge_strategy"]["quiz_samples"]
+ await quiz(
+ self.synthesizer_llm_client,
+ self.graph_storage,
+ self.rephrase_storage,
+ max_samples,
+ )
+ await self.rephrase_storage.index_done_callback()
- async def async_judge(self, re_judge=False, skip=False):
- if skip:
- _update_relations = await skip_judge_statement(self.graph_storage)
- else:
- _update_relations = await judge_statement(self.trainee_llm_client, self.graph_storage,
- self.rephrase_storage, re_judge)
+ def judge(self):
+ loop = create_event_loop()
+ loop.run_until_complete(self.async_judge())
+
+ async def async_judge(self):
+ re_judge = self.config["quiz_and_judge_strategy"]["re_judge"]
+ _update_relations = await judge_statement(
+ self.trainee_llm_client,
+ self.graph_storage,
+ self.rephrase_storage,
+ re_judge,
+ )
await _update_relations.index_done_callback()
def traverse(self):
@@ -222,26 +323,60 @@ class GraphGen:
loop.run_until_complete(self.async_traverse())
async def async_traverse(self):
- if self.traverse_strategy.qa_form == "atomic":
- results = await traverse_graph_atomically(self.synthesizer_llm_client,
- self.tokenizer_instance,
- self.graph_storage,
- self.traverse_strategy,
- self.text_chunks_storage,
- self.progress_bar)
- elif self.traverse_strategy.qa_form == "multi_hop":
- results = await traverse_graph_for_multi_hop(self.synthesizer_llm_client,
- self.tokenizer_instance,
- self.graph_storage,
- self.traverse_strategy,
- self.text_chunks_storage,
- self.progress_bar)
- elif self.traverse_strategy.qa_form == "aggregated":
- results = await traverse_graph_by_edge(self.synthesizer_llm_client, self.tokenizer_instance,
- self.graph_storage, self.traverse_strategy, self.text_chunks_storage,
- self.progress_bar)
+ output_data_type = self.config["output_data_type"]
+
+ if output_data_type == "atomic":
+ results = await traverse_graph_atomically(
+ self.synthesizer_llm_client,
+ self.tokenizer_instance,
+ self.graph_storage,
+ self.traverse_strategy,
+ self.text_chunks_storage,
+ self.progress_bar,
+ )
+ elif output_data_type == "multi_hop":
+ results = await traverse_graph_for_multi_hop(
+ self.synthesizer_llm_client,
+ self.tokenizer_instance,
+ self.graph_storage,
+ self.traverse_strategy,
+ self.text_chunks_storage,
+ self.progress_bar,
+ )
+ elif output_data_type == "aggregated":
+ results = await traverse_graph_by_edge(
+ self.synthesizer_llm_client,
+ self.tokenizer_instance,
+ self.graph_storage,
+ self.traverse_strategy,
+ self.text_chunks_storage,
+ self.progress_bar,
+ )
else:
- raise ValueError(f"Unknown qa_form: {self.traverse_strategy.qa_form}")
+ raise ValueError(f"Unknown qa_form: {output_data_type}")
+
+ results = format_generation_results(
+ results, output_data_format=self.config["output_data_format"]
+ )
+
+ await self.qa_storage.upsert(results)
+ await self.qa_storage.index_done_callback()
+
+ def generate_reasoning(self, method_params):
+ loop = create_event_loop()
+ loop.run_until_complete(self.async_generate_reasoning(method_params))
+
+ async def async_generate_reasoning(self, method_params):
+ results = await generate_cot(
+ self.graph_storage,
+ self.synthesizer_llm_client,
+ method_params=method_params,
+ )
+
+ results = format_generation_results(
+ results, output_data_format=self.config["output_data_format"]
+ )
+
await self.qa_storage.upsert(results)
await self.qa_storage.index_done_callback()
@@ -252,7 +387,7 @@ class GraphGen:
async def async_clear(self):
await self.full_docs_storage.drop()
await self.text_chunks_storage.drop()
- await self.wiki_storage.drop()
+ await self.search_storage.drop()
await self.graph_storage.clear()
await self.rephrase_storage.drop()
await self.qa_storage.drop()
diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py
index c2f9e71405a656ba2203b8cf55f8dc7225d0770b..f715335874578b32366228dc9a67a3703073bbe3 100644
--- a/graphgen/models/__init__.py
+++ b/graphgen/models/__init__.py
@@ -1,22 +1,20 @@
-from .text.chunk import Chunk
-from .text.text_pair import TextPair
-
-from .llm.topk_token_model import Token, TopkTokenModel
-from .llm.openai_model import OpenAIModel
-from .llm.tokenizer import Tokenizer
-
-from .storage.networkx_storage import NetworkXStorage
-from .storage.json_storage import JsonKVStorage
-
-from .search.wiki_search import WikiSearch
-
+from .community.community_detector import CommunityDetector
from .evaluate.length_evaluator import LengthEvaluator
from .evaluate.mtld_evaluator import MTLDEvaluator
from .evaluate.reward_evaluator import RewardEvaluator
from .evaluate.uni_evaluator import UniEvaluator
-
+from .llm.openai_model import OpenAIModel
+from .llm.tokenizer import Tokenizer
+from .llm.topk_token_model import Token, TopkTokenModel
+from .search.db.uniprot_search import UniProtSearch
+from .search.kg.wiki_search import WikiSearch
+from .search.web.bing_search import BingSearch
+from .search.web.google_search import GoogleSearch
+from .storage.json_storage import JsonKVStorage, JsonListStorage
+from .storage.networkx_storage import NetworkXStorage
from .strategy.travserse_strategy import TraverseStrategy
-
+from .text.chunk import Chunk
+from .text.text_pair import TextPair
__all__ = [
# llm models
@@ -28,8 +26,12 @@ __all__ = [
"Chunk",
"NetworkXStorage",
"JsonKVStorage",
+ "JsonListStorage",
# search models
"WikiSearch",
+ "GoogleSearch",
+ "BingSearch",
+ "UniProtSearch",
# evaluate models
"TextPair",
"LengthEvaluator",
@@ -38,4 +40,6 @@ __all__ = [
"UniEvaluator",
# strategy models
"TraverseStrategy",
+ # community models
+ "CommunityDetector",
]
diff --git a/graphgen/models/community/__init__.py b/graphgen/models/community/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graphgen/models/community/community_detector.py b/graphgen/models/community/community_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..0041f4c4a4a57648078ebe650a5c9702d7a17eb5
--- /dev/null
+++ b/graphgen/models/community/community_detector.py
@@ -0,0 +1,95 @@
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Dict, List
+
+from graphgen.models.storage.networkx_storage import NetworkXStorage
+
+
+@dataclass
+class CommunityDetector:
+ """Class for community detection algorithms."""
+
+ graph_storage: NetworkXStorage = None
+ method: str = "leiden"
+ method_params: Dict[str, Any] = None
+
+ async def detect_communities(self) -> Dict[str, int]:
+ if self.method == "leiden":
+ return await self._leiden_communities(**self.method_params or {})
+ raise ValueError(f"Unknown community detection method: {self.method}")
+
+ async def get_graph(self):
+ return await self.graph_storage.get_graph()
+
+    async def _leiden_communities(
+        self, max_size: int = None, **kwargs
+    ) -> Dict[str, int]:
+        """
+        Detect communities using the Leiden algorithm.
+        If max_size is given, any community larger than max_size will be split
+        into smaller sub-communities each having at most max_size nodes.
+        Recognised kwargs: random_seed (default 42) and use_lcc (default False).
+        """
+        import igraph as ig
+        from leidenalg import ModularityVertexPartition, find_partition
+
+        # Only the edge list is handed to igraph, so nodes without edges never
+        # become vertices; the graph from storage is read here, never pruned.
+        graph = await self.get_graph()
+        ig_graph = ig.Graph.TupleList(graph.edges(), directed=False)
+
+        random_seed = kwargs.get("random_seed", 42)
+        use_lcc = kwargs.get("use_lcc", False)
+
+        communities: Dict[str, int] = {}
+        if use_lcc:
+            lcc = ig_graph.components().giant()
+            partition = find_partition(lcc, ModularityVertexPartition, seed=random_seed)
+            for part, cluster in enumerate(partition):
+                for v in cluster:
+                    communities[lcc.vs[v]["name"]] = part
+        else:
+            offset = 0
+            for component in ig_graph.components():
+                subgraph = ig_graph.induced_subgraph(component)
+                partition = find_partition(
+                    subgraph, ModularityVertexPartition, seed=random_seed
+                )
+                for part, cluster in enumerate(partition):
+                    for v in cluster:
+                        original_node = subgraph.vs[v]["name"]
+                        communities[original_node] = part + offset
+                offset += len(partition)
+
+        # split large communities if max_size is specified
+        if max_size is None or max_size <= 0:
+            return communities
+
+        return await self._split_communities(communities, max_size)
+
+ @staticmethod
+ async def _split_communities(
+ communities: Dict[str, int], max_size: int
+ ) -> Dict[str, int]:
+ """
+ Split communities larger than max_size into smaller sub-communities.
+ """
+ cid2nodes: Dict[int, List[str]] = defaultdict(list)
+ for node, cid in communities.items():
+ cid2nodes[cid].append(node)
+
+ new_communities: Dict[str, int] = {}
+ new_cid = 0
+ for cid, nodes in cid2nodes.items():
+ if len(nodes) <= max_size:
+ for n in nodes:
+ new_communities[n] = new_cid
+ new_cid += 1
+ else:
+ for start in range(0, len(nodes), max_size):
+ sub = nodes[start : start + max_size]
+ for n in sub:
+ new_communities[n] = new_cid
+ new_cid += 1
+
+ return new_communities
diff --git a/graphgen/models/llm/openai_model.py b/graphgen/models/llm/openai_model.py
index 6973c1cec13a3a0842eb50e5eec839c6a16612af..2c04432f1502eb80d90cfa7cd50cd1ddc622e3a5 100644
--- a/graphgen/models/llm/openai_model.py
+++ b/graphgen/models/llm/openai_model.py
@@ -1,18 +1,21 @@
import math
+import re
from dataclasses import dataclass, field
-from typing import List, Dict, Optional
+from typing import Dict, List, Optional
+
import openai
-from openai import AsyncOpenAI, RateLimitError, APIConnectionError, APITimeoutError
+from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError
from tenacity import (
retry,
+ retry_if_exception_type,
stop_after_attempt,
wait_exponential,
- retry_if_exception_type,
)
-from graphgen.models.llm.topk_token_model import TopkTokenModel, Token
-from graphgen.models.llm.tokenizer import Tokenizer
from graphgen.models.llm.limitter import RPM, TPM
+from graphgen.models.llm.tokenizer import Tokenizer
+from graphgen.models.llm.topk_token_model import Token, TopkTokenModel
+
def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
token_logprobs = response.choices[0].logprobs.content
@@ -20,13 +23,23 @@ def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
for token_prob in token_logprobs:
prob = math.exp(token_prob.logprob)
candidate_tokens = [
- Token(t.token, math.exp(t.logprob))
- for t in token_prob.top_logprobs
+ Token(t.token, math.exp(t.logprob)) for t in token_prob.top_logprobs
]
token = Token(token_prob.token, prob, top_candidates=candidate_tokens)
tokens.append(token)
return tokens
+
+def filter_think_tags(text: str) -> str:
+    """
+    Remove <think>...</think> reasoning blocks (the tags and everything between them) from the text.
+    Falls back to the stripped original text when nothing would remain after filtering.
+    """
+    think_pattern = re.compile(r"<think>.*?</think>", re.DOTALL)
+    filtered_text = think_pattern.sub("", text).strip()
+    return filtered_text if filtered_text else text.strip()
+
+
@dataclass
class OpenAIModel(TopkTokenModel):
model_name: str = "gpt-4o-mini"
@@ -42,12 +55,13 @@ class OpenAIModel(TopkTokenModel):
rpm: RPM = field(default_factory=lambda: RPM(rpm=1000))
tpm: TPM = field(default_factory=lambda: TPM(tpm=50000))
+ tokenizer_instance: Tokenizer = field(default_factory=Tokenizer)
def __post_init__(self):
assert self.api_key is not None, "Please provide api key to access openai api."
- if self.api_key == "":
- self.api_key = "none"
- self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
+ self.client = AsyncOpenAI(
+ api_key=self.api_key or "dummy", base_url=self.base_url
+ )
def _pre_generate(self, text: str, history: List[str]) -> Dict:
kwargs = {
@@ -69,16 +83,19 @@ class OpenAIModel(TopkTokenModel):
assert len(history) % 2 == 0, "History should have even number of elements."
messages = history + messages
- kwargs['messages']= messages
+ kwargs["messages"] = messages
return kwargs
-
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
- retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
+ retry=retry_if_exception_type(
+ (RateLimitError, APIConnectionError, APITimeoutError)
+ ),
)
- async def generate_topk_per_token(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
+ async def generate_topk_per_token(
+ self, text: str, history: Optional[List[str]] = None
+ ) -> List[Token]:
kwargs = self._pre_generate(text, history)
if self.topk_per_token > 0:
kwargs["logprobs"] = True
@@ -87,9 +104,8 @@ class OpenAIModel(TopkTokenModel):
# Limit max_tokens to 1 to avoid long completions
kwargs["max_tokens"] = 1
- completion = await self.client.chat.completions.create( # pylint: disable=E1125
- model=self.model_name,
- **kwargs
+ completion = await self.client.chat.completions.create( # pylint: disable=E1125
+ model=self.model_name, **kwargs
)
tokens = get_top_response_tokens(completion)
@@ -99,32 +115,41 @@ class OpenAIModel(TopkTokenModel):
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
- retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
+ retry=retry_if_exception_type(
+ (RateLimitError, APIConnectionError, APITimeoutError)
+ ),
)
- async def generate_answer(self, text: str, history: Optional[List[str]] = None, temperature: int = 0) -> str:
+ async def generate_answer(
+ self, text: str, history: Optional[List[str]] = None, temperature: int = 0
+ ) -> str:
kwargs = self._pre_generate(text, history)
kwargs["temperature"] = temperature
prompt_tokens = 0
- for message in kwargs['messages']:
- prompt_tokens += len(Tokenizer().encode_string(message['content']))
- estimated_tokens = prompt_tokens + kwargs['max_tokens']
+ for message in kwargs["messages"]:
+ prompt_tokens += len(
+ self.tokenizer_instance.encode_string(message["content"])
+ )
+ estimated_tokens = prompt_tokens + kwargs["max_tokens"]
if self.request_limit:
await self.rpm.wait(silent=True)
await self.tpm.wait(estimated_tokens, silent=True)
- completion = await self.client.chat.completions.create( # pylint: disable=E1125
- model=self.model_name,
- **kwargs
+ completion = await self.client.chat.completions.create( # pylint: disable=E1125
+ model=self.model_name, **kwargs
)
if hasattr(completion, "usage"):
- self.token_usage.append({
- "prompt_tokens": completion.usage.prompt_tokens,
- "completion_tokens": completion.usage.completion_tokens,
- "total_tokens": completion.usage.total_tokens,
- })
- return completion.choices[0].message.content
-
- async def generate_inputs_prob(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
+ self.token_usage.append(
+ {
+ "prompt_tokens": completion.usage.prompt_tokens,
+ "completion_tokens": completion.usage.completion_tokens,
+ "total_tokens": completion.usage.total_tokens,
+ }
+ )
+ return filter_think_tags(completion.choices[0].message.content)
+
+ async def generate_inputs_prob(
+ self, text: str, history: Optional[List[str]] = None
+ ) -> List[Token]:
raise NotImplementedError
diff --git a/graphgen/models/search/db/__init__.py b/graphgen/models/search/db/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graphgen/models/search/db/uniprot_search.py b/graphgen/models/search/db/uniprot_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..96bdd99cb604b858741c4e269ca9e25f8221f823
--- /dev/null
+++ b/graphgen/models/search/db/uniprot_search.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+
+import requests
+from fastapi import HTTPException
+
+from graphgen.utils import logger
+
+UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search"
+
+
+@dataclass
+class UniProtSearch:
+    """
+    UniProt Search client to search with UniProt.
+    1) Get the protein by accession number.
+    2) Search with keywords or protein names.
+
+    All requests go through ``_safe_get``, which raises ``HTTPException`` on a
+    non-2xx response.
+    """
+
+    def get_entry(self, accession: str) -> dict:
+        """
+        Get the UniProt entry by accession number (e.g., P04637).
+
+        :param accession: UniProtKB accession of the entry to fetch.
+        :return: The decoded JSON body of the response.
+        :raises HTTPException: If UniProt answers with a non-2xx status.
+        """
+        # UNIPROT_BASE is the ".../uniprotkb/search" endpoint. Single entries
+        # are served one level up, at ".../uniprotkb/<accession>.json", so drop
+        # the trailing path segment before appending the accession.
+        entry_base = UNIPROT_BASE.rsplit("/", 1)[0]
+        url = f"{entry_base}/{accession}.json"
+        return self._safe_get(url).json()
+
+    def search(
+        self,
+        query: str,
+        *,
+        size: int = 10,
+        cursor: str = None,
+        fields: list[str] = None,
+    ) -> dict:
+        """
+        Search UniProt with a query string.
+
+        :param query: The search query.
+        :param size: The number of results to return.
+        :param cursor: The cursor for pagination; omitted from the request when falsy.
+        :param fields: The fields to return in the response; sent as a
+            comma-separated list and omitted when falsy.
+        :return: A dictionary containing the search results.
+        :raises HTTPException: If UniProt answers with a non-2xx status.
+        """
+        params = {
+            "query": query,
+            "size": size,
+        }
+        if cursor:
+            params["cursor"] = cursor
+        if fields:
+            params["fields"] = ",".join(fields)
+        return self._safe_get(UNIPROT_BASE, params=params).json()
+
+    @staticmethod
+    def _safe_get(url: str, params: dict = None) -> requests.Response:
+        """
+        Issue a GET request asking for JSON, with a 10 second timeout.
+
+        :param url: Target URL.
+        :param params: Optional query-string parameters.
+        :return: The response object when the status is OK.
+        :raises HTTPException: With the upstream status code when the response
+            is not OK; the upstream body is logged, not forwarded.
+        """
+        r = requests.get(
+            url,
+            params=params,
+            headers={"Accept": "application/json"},
+            timeout=10,
+        )
+        if not r.ok:
+            logger.error("Search engine error: %s", r.text)
+            raise HTTPException(r.status_code, "Search engine error.")
+        return r
diff --git a/graphgen/models/search/kg/__init__.py b/graphgen/models/search/kg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graphgen/models/search/wiki_search.py b/graphgen/models/search/kg/wiki_search.py
similarity index 87%
rename from graphgen/models/search/wiki_search.py
rename to graphgen/models/search/kg/wiki_search.py
index db312a2bfc333b485725712faa7be92fcedc43dd..e9513f21d6f8b307f736632d20cd8da472fd8925 100644
--- a/graphgen/models/search/wiki_search.py
+++ b/graphgen/models/search/kg/wiki_search.py
@@ -1,8 +1,9 @@
-from typing import List, Union
from dataclasses import dataclass
+from typing import List, Union
import wikipedia
from wikipedia import set_lang
+
from graphgen.utils import detect_main_language, logger
@@ -13,9 +14,9 @@ class WikiSearch:
assert language in ["en", "zh"], "Only support English and Chinese"
set_lang(language)
- async def search(self, query: str) -> Union[List[str], None]:
+ async def search(self, query: str, num_results: int = 1) -> Union[List[str], None]:
self.set_language(detect_main_language(query))
- return wikipedia.search(query)
+ return wikipedia.search(query, results=num_results, suggestion=False)
async def summary(self, query: str) -> Union[str, None]:
self.set_language(detect_main_language(query))
diff --git a/graphgen/models/search/web/__init__.py b/graphgen/models/search/web/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graphgen/models/search/web/bing_search.py b/graphgen/models/search/web/bing_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..a769ba7601d0a1d34eb2c32bfb40132457cea981
--- /dev/null
+++ b/graphgen/models/search/web/bing_search.py
@@ -0,0 +1,43 @@
+from dataclasses import dataclass
+
+import requests
+from fastapi import HTTPException
+
+from graphgen.utils import logger
+
+BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"
+BING_MKT = "en-US"
+
+
+@dataclass
+class BingSearch:
+    """
+    Thin client for the Bing Web Search v7 endpoint.
+    """
+
+    subscription_key: str
+
+    def search(self, query: str, num_results: int = 1):
+        """
+        Run a Bing web search and return the matching web pages.
+
+        :param query: The search query.
+        :param num_results: Upper bound on the number of results returned.
+        :return: Up to ``num_results`` entries of ``webPages.value`` from the
+            response, or an empty list when that key is missing.
+        :raises HTTPException: With the upstream status code when the request
+            does not succeed.
+        """
+        auth_headers = {"Ocp-Apim-Subscription-Key": self.subscription_key}
+        query_params = {"q": query, "mkt": BING_MKT, "count": num_results}
+        response = requests.get(
+            BING_SEARCH_V7_ENDPOINT,
+            headers=auth_headers,
+            params=query_params,
+            timeout=10,
+        )
+        if not response.ok:
+            logger.error("Search engine error: %s", response.text)
+            raise HTTPException(response.status_code, "Search engine error.")
+
+        payload = response.json()
+        try:
+            # Slice as well, in case the service returns more than requested.
+            return payload["webPages"]["value"][:num_results]
+        except KeyError:
+            logger.error("Error encountered: %s", payload)
+            return []
diff --git a/graphgen/models/search/web/google_search.py b/graphgen/models/search/web/google_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..1abfcdf35a97001a6e6f091fa9154b5728bc7c4f
--- /dev/null
+++ b/graphgen/models/search/web/google_search.py
@@ -0,0 +1,45 @@
+from dataclasses import dataclass
+
+import requests
+from fastapi import HTTPException
+
+from graphgen.utils import logger
+
+GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1"
+
+
+@dataclass
+class GoogleSearch:
+    """
+    Thin client for the Google Custom Search JSON endpoint.
+    """
+
+    def __init__(self, subscription_key: str, cx: str):
+        """
+        Store the credentials used for every request.
+
+        :param subscription_key: Your Google API subscription key.
+        :param cx: Your custom search engine ID.
+        """
+        self.subscription_key = subscription_key
+        self.cx = cx
+
+    def search(self, query: str, num_results: int = 1):
+        """
+        Run a Google custom search and return the matching items.
+
+        :param query: The search query.
+        :param num_results: Upper bound on the number of results returned.
+        :return: Up to ``num_results`` entries of ``items`` from the response,
+            or an empty list when that key is missing.
+        :raises HTTPException: With the upstream status code when the request
+            does not succeed.
+        """
+        query_params = {
+            "key": self.subscription_key,
+            "cx": self.cx,
+            "q": query,
+            "num": num_results,
+        }
+        response = requests.get(
+            GOOGLE_SEARCH_ENDPOINT, params=query_params, timeout=10
+        )
+        if not response.ok:
+            logger.error("Search engine error: %s", response.text)
+            raise HTTPException(response.status_code, "Search engine error.")
+
+        payload = response.json()
+        try:
+            # Slice as well, in case the service returns more than requested.
+            return payload["items"][:num_results]
+        except KeyError:
+            logger.error("Error encountered: %s", payload)
+            return []
diff --git a/graphgen/models/storage/base_storage.py b/graphgen/models/storage/base_storage.py
index 2e70a3cb6f1bc55910e241cf127ffdbc81f44ec0..c09df074c0b1199cc03ec1fdbe1b7b297aa88537 100644
--- a/graphgen/models/storage/base_storage.py
+++ b/graphgen/models/storage/base_storage.py
@@ -1,9 +1,11 @@
from dataclasses import dataclass
-from typing import Union, Generic, TypeVar
+from typing import Generic, TypeVar, Union
+
from graphgen.models.embed.embedding import EmbeddingFunc
T = TypeVar("T")
+
@dataclass
class StorageNameSpace:
working_dir: str = None
@@ -17,9 +19,25 @@ class StorageNameSpace:
@dataclass
-class BaseKVStorage(Generic[T], StorageNameSpace):
- embedding_func: EmbeddingFunc = None
+class BaseListStorage(Generic[T], StorageNameSpace):
+ """Abstract interface for list-shaped storage; every method raises NotImplementedError here."""
+
+ async def all_items(self) -> list[T]:
+ """Return all stored items."""
+ raise NotImplementedError
+
+ async def get_by_index(self, index: int) -> Union[T, None]:
+ """Return the item at ``index``; the signature allows ``None`` as a result."""
+ raise NotImplementedError
+
+ async def append(self, data: T):
+ """Add a single item."""
+ raise NotImplementedError
+
+ async def upsert(self, data: list[T]):
+ """Insert a batch of items."""
+ raise NotImplementedError
+
+ async def drop(self):
+ """Discard the stored items."""
+ raise NotImplementedError
+
+@dataclass
+class BaseKVStorage(Generic[T], StorageNameSpace):
async def all_keys(self) -> list[str]:
raise NotImplementedError
@@ -41,6 +59,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
async def drop(self):
raise NotImplementedError
+
@dataclass
class BaseGraphStorage(StorageNameSpace):
embedding_func: EmbeddingFunc = None
@@ -71,7 +90,9 @@ class BaseGraphStorage(StorageNameSpace):
) -> Union[dict, None]:
raise NotImplementedError
- async def update_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
+ async def update_edge(
+ self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+ ):
raise NotImplementedError
async def get_all_edges(self) -> Union[list[dict], None]:
diff --git a/graphgen/models/storage/json_storage.py b/graphgen/models/storage/json_storage.py
index e4f21e6c28cfec38c4c76303d026469b4370318e..b61572f51cdce8f1e173eba8db8fdf42a7037fff 100644
--- a/graphgen/models/storage/json_storage.py
+++ b/graphgen/models/storage/json_storage.py
@@ -1,8 +1,8 @@
import os
-
from dataclasses import dataclass
-from graphgen.utils import logger, load_json, write_json
-from graphgen.models.storage.base_storage import BaseKVStorage
+
+from graphgen.models.storage.base_storage import BaseKVStorage, BaseListStorage
+from graphgen.utils import load_json, logger, write_json
@dataclass
@@ -49,3 +49,39 @@ class JsonKVStorage(BaseKVStorage):
async def drop(self):
self._data = {}
+
+
+@dataclass
+class JsonListStorage(BaseListStorage):
+    """
+    List storage kept in memory and backed by ``<working_dir>/<namespace>.json``.
+
+    The file is read once in ``__post_init__``; changes stay in memory until
+    ``index_done_callback`` writes them back.
+    """
+
+    _data: list = None
+
+    def __post_init__(self):
+        self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
+        # Start from an empty list when load_json yields nothing usable.
+        self._data = load_json(self._file_name) or []
+        logger.info("Load List %s with %d data", self.namespace, len(self._data))
+
+    @property
+    def data(self):
+        """The underlying list (not a copy)."""
+        return self._data
+
+    async def all_items(self) -> list:
+        """Return the underlying list (not a copy)."""
+        return self._data
+
+    async def index_done_callback(self):
+        """Write the current list to the backing JSON file."""
+        write_json(self._data, self._file_name)
+
+    async def get_by_index(self, index: int):
+        """Return the item at ``index``, or ``None`` when it is out of range or negative."""
+        if index < 0 or index >= len(self._data):
+            return None
+        return self._data[index]
+
+    async def append(self, data):
+        """Append one item without any duplicate check."""
+        self._data.append(data)
+
+    async def upsert(self, data: list):
+        """
+        Add the items of ``data`` that are not stored yet.
+
+        :param data: Candidate items; compared with ``==`` / ``in``, so they do
+            not need to be hashable.
+        :return: The items that were actually added, in input order.
+        """
+        left_data = []
+        for item in data:
+            # Check against the live list so that duplicates inside the same
+            # batch are only stored once.
+            if item not in self._data:
+                self._data.append(item)
+                left_data.append(item)
+        return left_data
+
+    async def drop(self):
+        """Clear the in-memory list; the file is untouched until the next write."""
+        self._data = []
diff --git a/graphgen/models/vis/__init__.py b/graphgen/models/vis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graphgen/models/vis/community_visualizer.py b/graphgen/models/vis/community_visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..055510141b4ba13d475eb3017fcef15e5aa91de5
--- /dev/null
+++ b/graphgen/models/vis/community_visualizer.py
@@ -0,0 +1,48 @@
+from dataclasses import dataclass
+from typing import Dict
+
+import matplotlib.pyplot as plt
+import networkx as nx
+
+
+@dataclass
+class Visualizer:
+    """
+    Class for visualizing graphs using NetworkX and Matplotlib.
+
+    Nodes are coloured by the community id found in ``communities``; nodes
+    without an entry get colour index 0.
+    """
+
+    graph: nx.Graph = None
+    communities: Dict[str, int] = None
+    layout: str = "spring"
+    # NOTE(review): max_nodes is not read by visualize(); confirm whether a
+    # size cap was intended.
+    max_nodes: int = 1000
+    node_size: int = 10
+    alpha: float = 0.6
+
+    def visualize(self, save_path: str = None):
+        """
+        Draw the graph and either save it or show it.
+
+        :param save_path: When truthy, the figure is written there (300 dpi)
+            and closed; otherwise ``plt.show()`` is called.
+        :raises ValueError: If ``layout`` is anything other than ``"spring"``.
+        """
+        n = self.graph.number_of_nodes()
+        if self.layout == "spring":
+            # Spread nodes less as the graph grows, but never below 0.1.
+            # An empty graph would otherwise divide by zero.
+            k = max(0.1, 1.0 / (n**0.5)) if n > 0 else 0.1
+            pos = nx.spring_layout(self.graph, k=k, seed=42)
+        else:
+            raise ValueError(f"Unknown layout: {self.layout}")
+
+        fig = plt.figure(figsize=(10, 10))
+
+        # communities defaults to None; treat that as "no community info".
+        communities = self.communities or {}
+        node_colors = [communities.get(node, 0) for node in self.graph.nodes()]
+
+        nx.draw_networkx_nodes(
+            self.graph,
+            pos,
+            node_size=self.node_size,
+            node_color=node_colors,
+            cmap=plt.cm.tab20,
+            alpha=self.alpha,
+        )
+        nx.draw_networkx_edges(self.graph, pos, alpha=0.3, width=0.2)
+        plt.axis("off")
+
+        if save_path:
+            plt.savefig(save_path, dpi=300, bbox_inches="tight")
+            print("Saved to", save_path)
+            # Release the figure; otherwise repeated calls accumulate open
+            # figures in pyplot's global state.
+            plt.close(fig)
+        else:
+            plt.show()
diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py
index 8ef14fdc5bfcbbac0d2e01eb111e643d916bbda4..f74e013a0ebbded2e2d8f94d0e97249bf4744b6c 100644
--- a/graphgen/operators/__init__.py
+++ b/graphgen/operators/__init__.py
@@ -1,16 +1,22 @@
-from .extract_kg import extract_kg
+from graphgen.operators.generate.generate_cot import generate_cot
+from graphgen.operators.kg.extract_kg import extract_kg
+from graphgen.operators.search.search_all import search_all
+
+from .judge import judge_statement
from .quiz import quiz
-from .judge import judge_statement, skip_judge_statement
-from .search_wikipedia import search_wikipedia
-from .traverse_graph import traverse_graph_by_edge, traverse_graph_atomically, traverse_graph_for_multi_hop
+from .traverse_graph import (
+ traverse_graph_atomically,
+ traverse_graph_by_edge,
+ traverse_graph_for_multi_hop,
+)
__all__ = [
"extract_kg",
"quiz",
"judge_statement",
- "skip_judge_statement",
- "search_wikipedia",
+ "search_all",
"traverse_graph_by_edge",
"traverse_graph_atomically",
- "traverse_graph_for_multi_hop"
+ "traverse_graph_for_multi_hop",
+ "generate_cot",
]
diff --git a/graphgen/operators/generate/__init__.py b/graphgen/operators/generate/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graphgen/operators/generate/generate_cot.py b/graphgen/operators/generate/generate_cot.py
new file mode 100644
index 0000000000000000000000000000000000000000..b87bce2bbd2c6eb0588c867454f3d6b386cc5e98
--- /dev/null
+++ b/graphgen/operators/generate/generate_cot.py
@@ -0,0 +1,117 @@
+import asyncio
+from typing import Dict, List, Tuple
+
+from tqdm.asyncio import tqdm as tqdm_async
+
+from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIModel
+from graphgen.templates import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
+from graphgen.utils import compute_content_hash, detect_main_language
+
+
+async def generate_cot(
+    graph_storage: NetworkXStorage,
+    synthesizer_llm_client: OpenAIModel,
+    method_params: Dict = None,
+):
+    """
+    Generate chain-of-thought (CoT) samples, one per detected graph community.
+
+    For each community the node and intra-community edge descriptions are
+    collected, the LLM is asked to design a question plus reasoning path, and
+    then asked again to produce the final answer following that path.
+
+    :param graph_storage: Graph to read nodes and edges from.
+    :param synthesizer_llm_client: LLM client used for both prompts.
+    :param method_params: Community detection options; ``"method"`` selects the
+        algorithm (default ``"leiden"``) and the whole dict is forwarded to
+        ``CommunityDetector``. ``None`` is treated as an empty dict.
+    :return: Dict keyed by the content hash of each question, with
+        ``question``, ``reasoning_path`` and ``answer`` entries; empty when no
+        community is found.
+    :raises ValueError: If a designed template lacks the expected
+        question / reasoning-path markers.
+    """
+    # The signature defaults to None; normalise so the lookup below cannot fail.
+    method_params = method_params or {}
+    method = method_params.get("method", "leiden")
+    detector = CommunityDetector(
+        graph_storage=graph_storage, method=method, method_params=method_params
+    )
+
+    node_to_community = await detector.detect_communities()
+
+    # Invert node -> community id into community id -> member nodes.
+    communities: Dict[int, List[str]] = {}
+    for node, community_id in node_to_community.items():
+        communities.setdefault(community_id, []).append(node)
+
+    if not communities:
+        return {}
+
+    semaphore = asyncio.Semaphore(value=1000)
+
+    async def _generate_from_single_community(
+        c_id: int, nodes: List[str]
+    ) -> Tuple[int, Tuple[str, str, str]]:
+        """Build (question, reasoning path, answer) for a single community."""
+        async with semaphore:
+            entities: List[str] = []
+            relationships: List[str] = []
+            # Set for O(1) membership checks; ``nodes`` keeps the iteration order.
+            members = set(nodes)
+
+            for n in nodes:
+                node_data = await graph_storage.get_node(n)
+                if node_data is not None:
+                    entities.append(f"({n}: {node_data.get('description')})")
+
+                edges = await graph_storage.get_node_edges(n)
+                for edge in edges:
+                    target = edge[1]
+                    # Only keep edges that stay inside the community.
+                    if target in members:
+                        edge_data = await graph_storage.get_edge(n, target)
+                        relationships.append(
+                            f"({n}) - [{edge_data['description']}] -> ({target})"
+                        )
+
+            entities_str = "\n".join(entities)
+            relationships_str = "\n".join(relationships)
+
+            language = (
+                "English"
+                if detect_main_language(entities_str + relationships_str) == "en"
+                else "Chinese"
+            )
+
+            prompt = COT_TEMPLATE_DESIGN_PROMPT[language]["TEMPLATE"].format(
+                entities=entities_str,
+                relationships=relationships_str,
+            )
+
+            cot_template = await synthesizer_llm_client.generate_answer(prompt)
+
+            # The reply is split on the Chinese or English section markers,
+            # whichever pair is present.
+            if "问题:" in cot_template and "推理路径设计:" in cot_template:
+                question = cot_template.split("问题:")[1].split("推理路径设计:")[0].strip()
+                reasoning_path = cot_template.split("推理路径设计:")[1].strip()
+            elif (
+                "Question:" in cot_template and "Reasoning-Path Design:" in cot_template
+            ):
+                question = (
+                    cot_template.split("Question:")[1]
+                    .split("Reasoning-Path Design:")[0]
+                    .strip()
+                )
+                reasoning_path = cot_template.split("Reasoning-Path Design:")[1].strip()
+            else:
+                raise ValueError("COT template format is incorrect.")
+
+            prompt = COT_GENERATION_PROMPT[language]["TEMPLATE"].format(
+                entities=entities_str,
+                relationships=relationships_str,
+                question=question,
+                reasoning_template=reasoning_path,
+            )
+
+            cot_answer = await synthesizer_llm_client.generate_answer(prompt)
+
+            return c_id, (question, reasoning_path, cot_answer)
+
+    cid_nodes = list(communities.items())
+
+    cot_results: Dict = {}
+    async for coro in tqdm_async(
+        asyncio.as_completed(
+            [_generate_from_single_community(cid, nodes) for cid, nodes in cid_nodes]
+        ),
+        total=len(cid_nodes),
+        desc="[Generating COT] Generating CoT data from communities",
+        unit="community",
+    ):
+        _, (q, r, a) = await coro
+        # Keyed by the question hash, so identical questions overwrite each other.
+        cot_results[compute_content_hash(q)] = {
+            "question": q,
+            "reasoning_path": r,
+            "answer": a,
+        }
+
+    return cot_results
diff --git a/graphgen/operators/judge.py b/graphgen/operators/judge.py
index 0292e1e40819a85b191bffd32ac622bc6811ddf0..61e9d33ebdd88936d06fdb69d08e52611a1fb647 100644
--- a/graphgen/operators/judge.py
+++ b/graphgen/operators/judge.py
@@ -1,17 +1,20 @@
-import math
import asyncio
+import math
+
from tqdm.asyncio import tqdm as tqdm_async
-from graphgen.models import NetworkXStorage, OpenAIModel, JsonKVStorage
-from graphgen.utils import logger, yes_no_loss_entropy
+
+from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIModel
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
+from graphgen.utils import logger, yes_no_loss_entropy
-async def judge_statement( # pylint: disable=too-many-statements
- trainee_llm_client: OpenAIModel,
- graph_storage: NetworkXStorage,
- rephrase_storage: JsonKVStorage,
- re_judge: bool = False,
- max_concurrent: int = 1000) -> NetworkXStorage:
+async def judge_statement( # pylint: disable=too-many-statements
+ trainee_llm_client: OpenAIModel,
+ graph_storage: NetworkXStorage,
+ rephrase_storage: JsonKVStorage,
+ re_judge: bool = False,
+ max_concurrent: int = 1000,
+) -> NetworkXStorage:
"""
Get all edges and nodes and judge them
@@ -34,7 +37,12 @@ async def judge_statement( # pylint: disable=too-many-statements
edge_data = edge[2]
if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
- logger.info("Edge %s -> %s already judged, loss: %s, skip", source_id, target_id, edge_data["loss"])
+ logger.info(
+ "Edge %s -> %s already judged, loss: %s, skip",
+ source_id,
+ target_id,
+ edge_data["loss"],
+ )
return source_id, target_id, edge_data
description = edge_data["description"]
@@ -47,17 +55,27 @@ async def judge_statement( # pylint: disable=too-many-statements
gts = [gt for _, gt in descriptions]
for description, gt in descriptions:
judgement = await trainee_llm_client.generate_topk_per_token(
- STATEMENT_JUDGEMENT_PROMPT['TEMPLATE'].format(statement=description)
+ STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
+ statement=description
+ )
)
judgements.append(judgement[0].top_candidates)
loss = yes_no_loss_entropy(judgements, gts)
- logger.info("Edge %s -> %s description: %s loss: %s", source_id, target_id, description, loss)
+ logger.info(
+ "Edge %s -> %s description: %s loss: %s",
+ source_id,
+ target_id,
+ description,
+ loss,
+ )
edge_data["loss"] = loss
- except Exception as e: # pylint: disable=broad-except
- logger.error("Error in judging relation %s -> %s: %s", source_id, target_id, e)
+ except Exception as e: # pylint: disable=broad-except
+ logger.error(
+ "Error in judging relation %s -> %s: %s", source_id, target_id, e
+ )
logger.info("Use default loss 0.1")
edge_data["loss"] = -math.log(0.1)
@@ -68,9 +86,9 @@ async def judge_statement( # pylint: disable=too-many-statements
results = []
for result in tqdm_async(
- asyncio.as_completed([_judge_single_relation(edge) for edge in edges]),
- total=len(edges),
- desc="Judging relations"
+ asyncio.as_completed([_judge_single_relation(edge) for edge in edges]),
+ total=len(edges),
+ desc="Judging relations",
):
results.append(await result)
@@ -82,7 +100,9 @@ async def judge_statement( # pylint: disable=too-many-statements
node_data = node[1]
if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
- logger.info("Node %s already judged, loss: %s, skip", node_id, node_data["loss"])
+ logger.info(
+ "Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
+ )
return node_id, node_data
description = node_data["description"]
@@ -95,16 +115,20 @@ async def judge_statement( # pylint: disable=too-many-statements
gts = [gt for _, gt in descriptions]
for description, gt in descriptions:
judgement = await trainee_llm_client.generate_topk_per_token(
- STATEMENT_JUDGEMENT_PROMPT['TEMPLATE'].format(statement=description)
+ STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
+ statement=description
+ )
)
judgements.append(judgement[0].top_candidates)
loss = yes_no_loss_entropy(judgements, gts)
- logger.info("Node %s description: %s loss: %s", node_id, description, loss)
+ logger.info(
+ "Node %s description: %s loss: %s", node_id, description, loss
+ )
node_data["loss"] = loss
- except Exception as e: # pylint: disable=broad-except
+ except Exception as e: # pylint: disable=broad-except
logger.error("Error in judging entity %s: %s", node_id, e)
logger.info("Use default loss 0.1")
node_data["loss"] = -math.log(0.1)
@@ -116,72 +140,9 @@ async def judge_statement( # pylint: disable=too-many-statements
results = []
for result in tqdm_async(
- asyncio.as_completed([_judge_single_entity(node) for node in nodes]),
- total=len(nodes),
- desc="Judging entities"
- ):
- results.append(await result)
-
- return graph_storage
-
-async def skip_judge_statement(
- graph_storage: NetworkXStorage,
- max_concurrent: int = 1000
-):
- """
- Skip the judgement of the statement
- :param graph_storage: graph storage instance
- :param max_concurrent: max concurrent
- :return:
- """
- semaphore = asyncio.Semaphore(max_concurrent)
-
- async def _skip_single_relation(
- edge: tuple,
- ):
- async with semaphore:
- source_id = edge[0]
- target_id = edge[1]
- edge_data = edge[2]
-
- if "loss" in edge_data and edge_data["loss"] is not None:
- logger.info("Edge %s -> %s already judged, loss: %s, skip", source_id, target_id, edge_data["loss"])
- return source_id, target_id, edge_data
-
- edge_data["loss"] = -math.log(0.1)
- await graph_storage.update_edge(source_id, target_id, edge_data)
- return source_id, target_id, edge_data
-
- edges = await graph_storage.get_all_edges()
- results = []
- for result in tqdm_async(
- asyncio.as_completed([_skip_single_relation(edge) for edge in edges]),
- total=len(edges),
- desc="Skipping judgement of relations"
- ):
- results.append(await result)
-
- async def _skip_single_entity(
- node: tuple,
- ):
- async with semaphore:
- node_id = node[0]
- node_data = node[1]
-
- if "loss" in node_data and node_data["loss"] is not None:
- logger.info("Node %s already judged, loss: %s, skip", node_id, node_data["loss"])
- return node_id, node_data
-
- node_data["loss"] = -math.log(0.1)
- await graph_storage.update_node(node_id, node_data)
- return node_id, node_data
-
- nodes = await graph_storage.get_all_nodes()
- results = []
- for result in tqdm_async(
- asyncio.as_completed([_skip_single_entity(node) for node in nodes]),
- total=len(nodes),
- desc="Skipping judgement of entities"
+ asyncio.as_completed([_judge_single_entity(node) for node in nodes]),
+ total=len(nodes),
+ desc="Judging entities",
):
results.append(await result)
diff --git a/graphgen/operators/kg/__init__.py b/graphgen/operators/kg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graphgen/operators/extract_kg.py b/graphgen/operators/kg/extract_kg.py
similarity index 67%
rename from graphgen/operators/extract_kg.py
rename to graphgen/operators/kg/extract_kg.py
index 3fad55254730639eddc93f913ba77ad8b4cdf470..406e400b379ce613175f2515ba8448a651e80fd5 100644
--- a/graphgen/operators/extract_kg.py
+++ b/graphgen/operators/kg/extract_kg.py
@@ -1,27 +1,33 @@
-import re
import asyncio
-from typing import List
+import re
from collections import defaultdict
+from typing import List
import gradio as gr
from tqdm.asyncio import tqdm as tqdm_async
+
from graphgen.models import Chunk, OpenAIModel, Tokenizer
from graphgen.models.storage.base_storage import BaseGraphStorage
+from graphgen.operators.kg.merge_kg import merge_edges, merge_nodes
from graphgen.templates import KG_EXTRACTION_PROMPT
-from graphgen.utils import (logger, pack_history_conversations, split_string_by_multi_markers,
- handle_single_entity_extraction, handle_single_relationship_extraction,
- detect_if_chinese)
-from graphgen.operators.merge_kg import merge_nodes, merge_edges
+from graphgen.utils import (
+ detect_if_chinese,
+ handle_single_entity_extraction,
+ handle_single_relationship_extraction,
+ logger,
+ pack_history_conversations,
+ split_string_by_multi_markers,
+)
# pylint: disable=too-many-statements
async def extract_kg(
- llm_client: OpenAIModel,
- kg_instance: BaseGraphStorage,
- tokenizer_instance: Tokenizer,
- chunks: List[Chunk],
- progress_bar: gr.Progress = None,
- max_concurrent: int = 1000
+ llm_client: OpenAIModel,
+ kg_instance: BaseGraphStorage,
+ tokenizer_instance: Tokenizer,
+ chunks: List[Chunk],
+ progress_bar: gr.Progress = None,
+ max_concurrent: int = 1000,
):
"""
:param llm_client: Synthesizer LLM model to extract entities and relationships
@@ -50,25 +56,25 @@ async def extract_kg(
)
final_result = await llm_client.generate_answer(hint_prompt)
- logger.info('First result: %s', final_result)
+ logger.info("First result: %s", final_result)
history = pack_history_conversations(hint_prompt, final_result)
for loop_index in range(max_loop):
if_loop_result = await llm_client.generate_answer(
- text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"],
- history=history
+ text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
break
glean_result = await llm_client.generate_answer(
- text=KG_EXTRACTION_PROMPT[language]["CONTINUE"],
- history=history
+ text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
)
- logger.info('Loop %s glean: %s', loop_index, glean_result)
+ logger.info("Loop %s glean: %s", loop_index, glean_result)
- history += pack_history_conversations(KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result)
+ history += pack_history_conversations(
+ KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
+ )
final_result += glean_result
if loop_index == max_loop - 1:
break
@@ -76,8 +82,9 @@ async def extract_kg(
records = split_string_by_multi_markers(
final_result,
[
- KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
- KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"]],
+ KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
+ KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
+ ],
)
nodes = defaultdict(list)
@@ -87,16 +94,20 @@ async def extract_kg(
record = re.search(r"\((.*)\)", record)
if record is None:
continue
- record = record.group(1) # 提取括号内的内容
+ record = record.group(1) # 提取括号内的内容
record_attributes = split_string_by_multi_markers(
record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
)
- entity = await handle_single_entity_extraction(record_attributes, chunk_id)
+ entity = await handle_single_entity_extraction(
+ record_attributes, chunk_id
+ )
if entity is not None:
nodes[entity["entity_name"]].append(entity)
continue
- relation = await handle_single_relationship_extraction(record_attributes, chunk_id)
+ relation = await handle_single_relationship_extraction(
+ record_attributes, chunk_id
+ )
if relation is not None:
edges[(relation["src_id"], relation["tgt_id"])].append(relation)
return dict(nodes), dict(edges)
@@ -106,17 +117,25 @@ async def extract_kg(
async for result in tqdm_async(
asyncio.as_completed([_process_single_content(c) for c in chunks]),
total=len(chunks),
- desc="[3/4]Extracting entities and relationships from chunks",
+ desc="[2/4]Extracting entities and relationships from chunks",
unit="chunk",
):
try:
if progress_bar is not None:
- progress_bar(len(results) / chunk_number, desc="[3/4]Extracting entities and relationships from chunks")
+ progress_bar(
+ len(results) / chunk_number,
+ desc="[3/4]Extracting entities and relationships from chunks",
+ )
results.append(await result)
if progress_bar is not None and len(results) == chunk_number:
- progress_bar(1, desc="[3/4]Extracting entities and relationships from chunks")
- except Exception as e: # pylint: disable=broad-except
- logger.error("Error occurred while extracting entities and relationships from chunks: %s", e)
+ progress_bar(
+ 1, desc="[3/4]Extracting entities and relationships from chunks"
+ )
+ except Exception as e: # pylint: disable=broad-except
+ logger.error(
+ "Error occurred while extracting entities and relationships from chunks: %s",
+ e,
+ )
nodes = defaultdict(list)
edges = defaultdict(list)
diff --git a/graphgen/operators/merge_kg.py b/graphgen/operators/kg/merge_kg.py
similarity index 76%
rename from graphgen/operators/merge_kg.py
rename to graphgen/operators/kg/merge_kg.py
index 33aa1395b26f4ed0d5754cfde0de5967ee296b2a..30379e66cb1ccbedcdf5ae2d8fa4c318ff9671af 100644
--- a/graphgen/operators/merge_kg.py
+++ b/graphgen/operators/kg/merge_kg.py
@@ -1,19 +1,21 @@
-from collections import Counter
import asyncio
+from collections import Counter
+
from tqdm.asyncio import tqdm as tqdm_async
-from graphgen.utils.format import split_string_by_multi_markers
-from graphgen.utils import logger, detect_main_language
-from graphgen.models import TopkTokenModel, Tokenizer
+from graphgen.models import Tokenizer, TopkTokenModel
from graphgen.models.storage.base_storage import BaseGraphStorage
-from graphgen.templates import KG_SUMMARIZATION_PROMPT, KG_EXTRACTION_PROMPT
+from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
+from graphgen.utils import detect_main_language, logger
+from graphgen.utils.format import split_string_by_multi_markers
+
async def _handle_kg_summary(
entity_or_relation_name: str,
description: str,
llm_client: TopkTokenModel,
tokenizer_instance: Tokenizer,
- max_summary_tokens: int = 200
+ max_summary_tokens: int = 200,
) -> str:
"""
处理实体或关系的描述信息
@@ -33,17 +35,19 @@ async def _handle_kg_summary(
KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
tokens = tokenizer_instance.encode_string(description)
- if len(tokens) < max_summary_tokens:
+ if len(tokens) < max_summary_tokens:
return description
use_description = tokenizer_instance.decode_tokens(tokens[:max_summary_tokens])
prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
entity_name=entity_or_relation_name,
- description_list=use_description.split(''),
- **KG_SUMMARIZATION_PROMPT["FORMAT"]
+ description_list=use_description.split(""),
+ **KG_SUMMARIZATION_PROMPT["FORMAT"],
)
new_description = await llm_client.generate_answer(prompt)
- logger.info("Entity or relation %s summary: %s", entity_or_relation_name, new_description)
+ logger.info(
+ "Entity or relation %s summary: %s", entity_or_relation_name, new_description
+ )
return new_description
@@ -52,7 +56,7 @@ async def merge_nodes(
kg_instance: BaseGraphStorage,
llm_client: TopkTokenModel,
tokenizer_instance: Tokenizer,
- max_concurrent: int = 1000
+ max_concurrent: int = 1000,
):
"""
Merge nodes
@@ -77,39 +81,34 @@ async def merge_nodes(
if node is not None:
entity_types.append(node["entity_type"])
source_ids.extend(
- split_string_by_multi_markers(node["source_id"], [''])
+ split_string_by_multi_markers(node["source_id"], [""])
)
descriptions.append(node["description"])
# 统计当前节点数据和已有节点数据的entity_type出现次数,取出现次数最多的entity_type
entity_type = sorted(
- Counter(
- [dp["entity_type"] for dp in node_data] + entity_types
- ).items(),
+ Counter([dp["entity_type"] for dp in node_data] + entity_types).items(),
key=lambda x: x[1],
reverse=True,
)[0][0]
- description = ''.join(
+ description = "".join(
sorted(set([dp["description"] for dp in node_data] + descriptions))
)
description = await _handle_kg_summary(
entity_name, description, llm_client, tokenizer_instance
)
- source_id = ''.join(
+ source_id = "".join(
set([dp["source_id"] for dp in node_data] + source_ids)
)
node_data = {
"entity_type": entity_type,
"description": description,
- "source_id": source_id
+ "source_id": source_id,
}
- await kg_instance.upsert_node(
- entity_name,
- node_data=node_data
- )
+ await kg_instance.upsert_node(entity_name, node_data=node_data)
node_data["entity_name"] = entity_name
return node_data
@@ -125,7 +124,7 @@ async def merge_nodes(
):
try:
entities_data.append(await result)
- except Exception as e: # pylint: disable=broad-except
+ except Exception as e: # pylint: disable=broad-except
logger.error("Error occurred while inserting entities into storage: %s", e)
@@ -134,7 +133,7 @@ async def merge_edges(
kg_instance: BaseGraphStorage,
llm_client: TopkTokenModel,
tokenizer_instance: Tokenizer,
- max_concurrent: int = 1000
+ max_concurrent: int = 1000,
):
"""
Merge edges
@@ -157,14 +156,14 @@ async def merge_edges(
edge = await kg_instance.get_edge(src_id, tgt_id)
if edge is not None:
source_ids.extend(
- split_string_by_multi_markers(edge["source_id"], [''])
+ split_string_by_multi_markers(edge["source_id"], [""])
)
descriptions.append(edge["description"])
- description = ''.join(
+ description = "".join(
sorted(set([dp["description"] for dp in edge_data] + descriptions))
)
- source_id = ''.join(
+ source_id = "".join(
set([dp["source_id"] for dp in edge_data] + source_ids)
)
@@ -175,8 +174,8 @@ async def merge_edges(
node_data={
"source_id": source_id,
"description": description,
- "entity_type": "UNKNOWN"
- }
+ "entity_type": "UNKNOWN",
+ },
)
description = await _handle_kg_summary(
@@ -186,24 +185,20 @@ async def merge_edges(
await kg_instance.upsert_edge(
src_id,
tgt_id,
- edge_data={
- "source_id": source_id,
- "description": description
- }
+ edge_data={"source_id": source_id, "description": description},
)
- edge_data = {
- "src_id": src_id,
- "tgt_id": tgt_id,
- "description": description
- }
+ edge_data = {"src_id": src_id, "tgt_id": tgt_id, "description": description}
return edge_data
logger.info("Inserting relationships into storage...")
relationships_data = []
for result in tqdm_async(
asyncio.as_completed(
- [process_single_edge(src_id, tgt_id, v) for (src_id, tgt_id), v in edges_data.items()]
+ [
+ process_single_edge(src_id, tgt_id, v)
+ for (src_id, tgt_id), v in edges_data.items()
+ ]
),
total=len(edges_data),
desc="Inserting relationships into storage",
@@ -211,5 +206,7 @@ async def merge_edges(
):
try:
relationships_data.append(await result)
- except Exception as e: # pylint: disable=broad-except
- logger.error("Error occurred while inserting relationships into storage: %s", e)
+ except Exception as e: # pylint: disable=broad-except
+ logger.error(
+ "Error occurred while inserting relationships into storage: %s", e
+ )
diff --git a/graphgen/operators/split_graph.py b/graphgen/operators/kg/split_kg.py
similarity index 74%
rename from graphgen/operators/split_graph.py
rename to graphgen/operators/kg/split_kg.py
index e2e2b5cab36f0c7f5e1193712792a58755293456..a3307a86c76db799af326d55979eccbeca29f751 100644
--- a/graphgen/operators/split_graph.py
+++ b/graphgen/operators/kg/split_kg.py
@@ -1,14 +1,16 @@
import random
from collections import defaultdict
+
from tqdm.asyncio import tqdm as tqdm_async
-from graphgen.utils import logger
from graphgen.models import NetworkXStorage, TraverseStrategy
+from graphgen.utils import logger
+
async def _get_node_info(
node_id: str,
graph_storage: NetworkXStorage,
-)-> dict:
+) -> dict:
"""
Get node info
@@ -17,10 +19,7 @@ async def _get_node_info(
:return: node info
"""
node_data = await graph_storage.get_node(node_id)
- return {
- "node_id": node_id,
- **node_data
- }
+ return {"node_id": node_id, **node_data}
def _get_level_n_edges_by_max_width(
@@ -33,7 +32,7 @@ def _get_level_n_edges_by_max_width(
bidirectional: bool,
max_extra_edges: int,
edge_sampling: str,
- loss_strategy: str = "only_edge"
+ loss_strategy: str = "only_edge",
) -> list:
"""
Get level n edges for an edge.
@@ -71,10 +70,17 @@ def _get_level_n_edges_by_max_width(
if len(candidate_edges) >= max_extra_edges:
if loss_strategy == "both":
- er_tuples = [([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) for edge in candidate_edges]
- candidate_edges = _sort_tuples(er_tuples, edge_sampling)[:max_extra_edges]
+ er_tuples = [
+ ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
+ for edge in candidate_edges
+ ]
+ candidate_edges = _sort_tuples(er_tuples, edge_sampling)[
+ :max_extra_edges
+ ]
elif loss_strategy == "only_edge":
- candidate_edges = _sort_edges(candidate_edges, edge_sampling)[:max_extra_edges]
+ candidate_edges = _sort_edges(candidate_edges, edge_sampling)[
+ :max_extra_edges
+ ]
else:
raise ValueError(f"Invalid loss strategy: {loss_strategy}")
@@ -101,16 +107,16 @@ def _get_level_n_edges_by_max_width(
def _get_level_n_edges_by_max_tokens(
- edge_adj_list: dict,
- node_dict: dict,
- edges: list,
- nodes: list,
- src_edge: tuple,
- max_depth: int,
- bidirectional: bool,
- max_tokens: int,
- edge_sampling: str,
- loss_strategy: str = "only_edge"
+ edge_adj_list: dict,
+ node_dict: dict,
+ edges: list,
+ nodes: list,
+ src_edge: tuple,
+ max_depth: int,
+ bidirectional: bool,
+ max_tokens: int,
+ edge_sampling: str,
+ loss_strategy: str = "only_edge",
) -> list:
"""
Get level n edges for an edge.
@@ -129,8 +135,11 @@ def _get_level_n_edges_by_max_tokens(
"""
src_id, tgt_id, src_edge_data = src_edge
- max_tokens -= (src_edge_data["length"] + nodes[node_dict[src_id]][1]["length"]
- + nodes[node_dict[tgt_id]][1]["length"])
+ max_tokens -= (
+ src_edge_data["length"]
+ + nodes[node_dict[src_id]][1]["length"]
+ + nodes[node_dict[tgt_id]][1]["length"]
+ )
level_n_edges = []
@@ -151,7 +160,10 @@ def _get_level_n_edges_by_max_tokens(
break
if loss_strategy == "both":
- er_tuples = [([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) for edge in candidate_edges]
+ er_tuples = [
+ ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
+ for edge in candidate_edges
+ ]
candidate_edges = _sort_tuples(er_tuples, edge_sampling)
elif loss_strategy == "only_edge":
candidate_edges = _sort_edges(candidate_edges, edge_sampling)
@@ -196,15 +208,22 @@ def _sort_tuples(er_tuples: list, edge_sampling: str) -> list:
if edge_sampling == "random":
er_tuples = random.sample(er_tuples, len(er_tuples))
elif edge_sampling == "min_loss":
- er_tuples = sorted(er_tuples, key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"])
+ er_tuples = sorted(
+ er_tuples,
+ key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
+ )
elif edge_sampling == "max_loss":
- er_tuples = sorted(er_tuples, key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
- reverse=True)
+ er_tuples = sorted(
+ er_tuples,
+ key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
+ reverse=True,
+ )
else:
raise ValueError(f"Invalid edge sampling: {edge_sampling}")
edges = [edge for _, edge in er_tuples]
return edges
+
def _sort_edges(edges: list, edge_sampling: str) -> list:
"""
Sort edges with edge sampling strategy
@@ -223,11 +242,12 @@ def _sort_edges(edges: list, edge_sampling: str) -> list:
raise ValueError(f"Invalid edge sampling: {edge_sampling}")
return edges
-async def get_batches_with_strategy( # pylint: disable=too-many-branches
+
+async def get_batches_with_strategy( # pylint: disable=too-many-branches
nodes: list,
edges: list,
graph_storage: NetworkXStorage,
- traverse_strategy: TraverseStrategy
+ traverse_strategy: TraverseStrategy,
):
expand_method = traverse_strategy.expand_method
if expand_method == "max_width":
@@ -256,7 +276,10 @@ async def get_batches_with_strategy( # pylint: disable=too-many-branches
node_dict[node_name] = i
if traverse_strategy.loss_strategy == "both":
- er_tuples = [([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) for edge in edges]
+ er_tuples = [
+ ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
+ for edge in edges
+ ]
edges = _sort_tuples(er_tuples, edge_sampling)
elif traverse_strategy.loss_strategy == "only_edge":
edges = _sort_edges(edges, edge_sampling)
@@ -279,21 +302,36 @@ async def get_batches_with_strategy( # pylint: disable=too-many-branches
src_id = edge[0]
tgt_id = edge[1]
- _process_nodes.extend([await get_cached_node_info(src_id),
- await get_cached_node_info(tgt_id)])
+ _process_nodes.extend(
+ [await get_cached_node_info(src_id), await get_cached_node_info(tgt_id)]
+ )
_process_edges.append(edge)
if expand_method == "max_width":
level_n_edges = _get_level_n_edges_by_max_width(
- edge_adj_list, node_dict, edges, nodes, edge, max_depth,
- traverse_strategy.bidirectional, traverse_strategy.max_extra_edges,
- edge_sampling, traverse_strategy.loss_strategy
+ edge_adj_list,
+ node_dict,
+ edges,
+ nodes,
+ edge,
+ max_depth,
+ traverse_strategy.bidirectional,
+ traverse_strategy.max_extra_edges,
+ edge_sampling,
+ traverse_strategy.loss_strategy,
)
else:
level_n_edges = _get_level_n_edges_by_max_tokens(
- edge_adj_list, node_dict, edges, nodes, edge, max_depth,
- traverse_strategy.bidirectional, traverse_strategy.max_tokens,
- edge_sampling, traverse_strategy.loss_strategy
+ edge_adj_list,
+ node_dict,
+ edges,
+ nodes,
+ edge,
+ max_depth,
+ traverse_strategy.bidirectional,
+ traverse_strategy.max_tokens,
+ edge_sampling,
+ traverse_strategy.loss_strategy,
)
for _edge in level_n_edges:
@@ -302,8 +340,12 @@ async def get_batches_with_strategy( # pylint: disable=too-many-branches
_process_edges.append(_edge)
# 去重
- _process_nodes = list({node['node_id']: node for node in _process_nodes}.values())
- _process_edges = list({(edge[0], edge[1]): edge for edge in _process_edges}.values())
+ _process_nodes = list(
+ {node["node_id"]: node for node in _process_nodes}.values()
+ )
+ _process_edges = list(
+ {(edge[0], edge[1]): edge for edge in _process_edges}.values()
+ )
processing_batches.append((_process_nodes, _process_edges))
@@ -312,15 +354,21 @@ async def get_batches_with_strategy( # pylint: disable=too-many-branches
# isolate nodes
isolated_node_strategy = traverse_strategy.isolated_node_strategy
if isolated_node_strategy == "add":
- processing_batches = await _add_isolated_nodes(nodes, processing_batches, graph_storage)
- logger.info("Processing batches after adding isolated nodes: %d", len(processing_batches))
+ processing_batches = await _add_isolated_nodes(
+ nodes, processing_batches, graph_storage
+ )
+ logger.info(
+ "Processing batches after adding isolated nodes: %d",
+ len(processing_batches),
+ )
return processing_batches
+
async def _add_isolated_nodes(
- nodes: list,
- processing_batches: list,
- graph_storage: NetworkXStorage,
+ nodes: list,
+ processing_batches: list,
+ graph_storage: NetworkXStorage,
) -> list:
visited_nodes = set()
for _process_nodes, _process_edges in processing_batches:
diff --git a/graphgen/operators/preprocess/__init__.py b/graphgen/operators/preprocess/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graphgen/operators/resolute_coreference.py b/graphgen/operators/preprocess/resolute_coreference.py
similarity index 60%
rename from graphgen/operators/resolute_coreference.py
rename to graphgen/operators/preprocess/resolute_coreference.py
index 4a1012fb55aa8d9aee0e1cd36cf4eed55f25fa8d..cdf702e23bb773266719790e957fbc8dd33ac637 100644
--- a/graphgen/operators/resolute_coreference.py
+++ b/graphgen/operators/preprocess/resolute_coreference.py
@@ -1,12 +1,13 @@
from typing import List
-from graphgen.models import Chunk
-from graphgen.models import OpenAIModel
-from graphgen.templates import COREFERENCE_RESOLUTION_TEMPLATE
+
+from graphgen.models import Chunk, OpenAIModel
+from graphgen.templates import COREFERENCE_RESOLUTION_PROMPT
from graphgen.utils import detect_main_language
+
async def resolute_coreference(
- llm_client: OpenAIModel,
- chunks: List[Chunk]) -> List[Chunk]:
+ llm_client: OpenAIModel, chunks: List[Chunk]
+) -> List[Chunk]:
"""
Resolute conference
@@ -23,9 +24,8 @@ async def resolute_coreference(
for _, chunk in enumerate(chunks[1:]):
language = detect_main_language(chunk.content)
result = await llm_client.generate_answer(
- COREFERENCE_RESOLUTION_TEMPLATE[language].format(
- reference = results[0].content,
- input_sentence = chunk.content
+ COREFERENCE_RESOLUTION_PROMPT[language].format(
+ reference=results[0].content, input_sentence=chunk.content
)
)
results.append(Chunk(id=chunk.id, content=result))
diff --git a/graphgen/operators/search/__init__.py b/graphgen/operators/search/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graphgen/operators/search/db/__init__.py b/graphgen/operators/search/db/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graphgen/operators/search/db/search_uniprot.py b/graphgen/operators/search/db/search_uniprot.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graphgen/operators/search/kg/__init__.py b/graphgen/operators/search/kg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graphgen/operators/search/kg/search_wikipedia.py b/graphgen/operators/search/kg/search_wikipedia.py
new file mode 100644
index 0000000000000000000000000000000000000000..05449fe1318fbd6cac61101da0a82812af68efec
--- /dev/null
+++ b/graphgen/operators/search/kg/search_wikipedia.py
@@ -0,0 +1,64 @@
+from tqdm.asyncio import tqdm_asyncio as tqdm_async
+
+from graphgen.models import WikiSearch
+from graphgen.utils import logger
+
+
+async def _process_single_entity(
+    entity_name: str,
+    wiki_search_client: WikiSearch,
+) -> str | None:
+    """
+    Look up one entity on Wikipedia and return a page summary.
+
+    :param entity_name: name used as the search query
+    :param wiki_search_client: client whose ``search`` and ``summary`` are awaited
+    :return: the summary, or None if nothing was found or the lookup raised
+    """
+    candidates = await wiki_search_client.search(entity_name)
+    if not candidates:
+        return None
+
+    page_summary = None
+    try:
+        # NOTE(review): the last hit is summarized, not the first; confirm intended.
+        page_summary = await wiki_search_client.summary(candidates[-1])
+        logger.info(
+            "Entity %s search result: %s summary: %s",
+            entity_name,
+            str(candidates),
+            page_summary,
+        )
+    except Exception as e:  # pylint: disable=broad-except
+        logger.error("Error processing entity %s: %s", entity_name, str(e))
+
+    return page_summary
+
+
+async def search_wikipedia(
+    wiki_search_client: WikiSearch,
+    entities: set[str],
+) -> dict:
+    """
+    Collect Wikipedia summaries for a set of entity names.
+
+    Entities are looked up one after another; any entity without a summary,
+    or whose lookup raises, is left out of the result.
+
+    :param wiki_search_client: wiki search client
+    :param entities: entity names to search
+    :return: mapping of entity name to its summary
+    """
+    summaries = {}
+
+    async for name in tqdm_async(
+        entities, desc="Searching Wikipedia", total=len(entities)
+    ):
+        try:
+            found = await _process_single_entity(name, wiki_search_client)
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error("Error processing entity %s: %s", name, str(e))
+            continue
+        if found:
+            summaries[name] = found
+    return summaries
diff --git a/graphgen/operators/search/search_all.py b/graphgen/operators/search/search_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7ecbea14d94944c043c122ac0397d81135deee4
--- /dev/null
+++ b/graphgen/operators/search/search_all.py
@@ -0,0 +1,87 @@
+"""
+To use Google Web Search API,
+follow the instructions [here](https://developers.google.com/custom-search/v1/overview)
+to get your Google search api key.
+
+To use Bing Web Search API,
+follow the instructions [here](https://www.microsoft.com/en-us/bing/apis/bing-web-search-api)
+and obtain your Bing subscription key.
+"""
+
+import os
+
+from graphgen.utils import logger
+
+
+async def search_all(
+    search_types: dict, search_entities: set[str]
+) -> dict[str, dict[str, str]]:
+    """
+    Run each requested search backend over the entities and merge the results.
+
+    :param search_types: backend names to run (only the keys are iterated);
+        "wikipedia", "google" and "bing" are handled, "uniprot" raises, and
+        any other name is logged as unsupported and skipped
+    :param search_entities: entity names to search
+    :return: mapping of entity name to {backend name: description}
+    :raises KeyError: if an environment variable read for Google or Bing is unset
+    :raises NotImplementedError: if "uniprot" is requested
+    """
+
+    results = {}
+
+    for search_type in search_types:
+        if search_type == "wikipedia":
+            from graphgen.models import WikiSearch
+            from graphgen.operators.search.kg.search_wikipedia import search_wikipedia
+
+            wiki_search_client = WikiSearch()
+
+            wiki_results = await search_wikipedia(wiki_search_client, search_entities)
+            for entity_name, description in wiki_results.items():
+                if description:
+                    # Merge into any entry an earlier backend created for this
+                    # entity; assigning a fresh dict here would discard it.
+                    results.setdefault(entity_name, {})["wikipedia"] = description
+        elif search_type == "google":
+            from graphgen.models import GoogleSearch
+            from graphgen.operators.search.web.search_google import search_google
+
+            google_search_client = GoogleSearch(
+                subscription_key=os.environ["GOOGLE_SEARCH_API_KEY"],
+                cx=os.environ["GOOGLE_SEARCH_CX"],
+            )
+
+            google_results = await search_google(google_search_client, search_entities)
+            for entity_name, description in google_results.items():
+                if description:
+                    results.setdefault(entity_name, {})["google"] = description
+        elif search_type == "bing":
+            from graphgen.models import BingSearch
+            from graphgen.operators.search.web.search_bing import search_bing
+
+            bing_search_client = BingSearch(
+                subscription_key=os.environ["BING_SEARCH_API_KEY"]
+            )
+
+            bing_results = await search_bing(bing_search_client, search_entities)
+            for entity_name, description in bing_results.items():
+                if description:
+                    results.setdefault(entity_name, {})["bing"] = description
+        elif search_type == "uniprot":
+            # from graphgen.models import UniProtSearch
+            # from graphgen.operators.search.db.search_uniprot import search_uniprot
+            #
+            # uniprot_search_client = UniProtSearch()
+            #
+            # uniprot_results = await search_uniprot(
+            #     uniprot_search_client, search_entities
+            # )
+            raise NotImplementedError(
+                "Processing of UniProt search results is not implemented yet."
+            )
+
+        else:
+            logger.error("Search type %s is not supported yet.", search_type)
+
+    return results
diff --git a/graphgen/operators/search/web/__init__.py b/graphgen/operators/search/web/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graphgen/operators/search/web/search_bing.py b/graphgen/operators/search/web/search_bing.py
new file mode 100644
index 0000000000000000000000000000000000000000..69f65f7b528a83202f46aa006c312a5aea07d2c6
--- /dev/null
+++ b/graphgen/operators/search/web/search_bing.py
@@ -0,0 +1,74 @@
+import trafilatura
+from tqdm.asyncio import tqdm_asyncio as tqdm_async
+
+from graphgen.models import BingSearch
+from graphgen.utils import logger
+
+
+async def _process_single_entity(
+    entity_name: str, bing_search_client: BingSearch
+) -> str | None:
+    """
+    Process single entity by searching Bing.
+
+    The page behind the first search result is fetched and its text is
+    extracted with trafilatura.
+
+    :param entity_name: The name of the entity to search.
+    :param bing_search_client: The Bing search client.
+    :return: Extracted page text, or None if there is no search result, the
+        page could not be fetched, or no text could be extracted.
+    """
+    search_results = bing_search_client.search(entity_name)
+    if not search_results:
+        return None
+
+    # Get more details from the first search result
+    first_result = search_results[0]
+    url = first_result["url"]
+    content = trafilatura.fetch_url(url)
+    if content is None:
+        # fetch_url returns None when the download fails.
+        logger.error("Entity %s: could not fetch %s", entity_name, url)
+        return None
+
+    summary = trafilatura.extract(content, include_comments=False, include_links=False)
+    if not summary:
+        # extract returns None when it finds no text; calling .strip() on
+        # None would raise AttributeError.
+        logger.error("Entity %s: no text extracted from %s", entity_name, url)
+        return None
+
+    summary = summary.strip()
+    logger.info(
+        "Entity %s search result: %s",
+        entity_name,
+        summary,
+    )
+    return summary
+
+
+async def search_bing(
+    bing_search_client: BingSearch,
+    entities: set[str],
+) -> dict[str, str]:
+    """
+    Search with Bing and return the contexts.
+
+    :param bing_search_client: Bing search client
+    :param entities: entity names to search
+    :return: mapping of entity name to extracted text; entities that yield
+        no text, or whose processing raises, are omitted
+    """
+    bing_data = {}
+
+    async for entity in tqdm_async(
+        entities, desc="Searching Bing", total=len(entities)
+    ):
+        try:
+            summary = await _process_single_entity(entity, bing_search_client)
+            if summary:
+                bing_data[entity] = summary
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error("Error processing entity %s: %s", entity, str(e))
+    return bing_data
diff --git a/graphgen/operators/search/web/search_google.py b/graphgen/operators/search/web/search_google.py
new file mode 100644
index 0000000000000000000000000000000000000000..803ce107a39583b65946c24894dead5679b3f973
--- /dev/null
+++ b/graphgen/operators/search/web/search_google.py
@@ -0,0 +1,74 @@
+import trafilatura
+from tqdm.asyncio import tqdm_asyncio as tqdm_async
+
+from graphgen.models import GoogleSearch
+from graphgen.utils import logger
+
+
+async def _process_single_entity(
+    entity_name: str, google_search_client: GoogleSearch
+) -> str | None:
+    """
+    Process single entity by searching Google.
+
+    The page behind the first search result is fetched and its text is
+    extracted with trafilatura.
+
+    :param entity_name: The name of the entity to search.
+    :param google_search_client: The Google search client.
+    :return: Extracted page text, or None if there is no search result, the
+        page could not be fetched, or no text could be extracted.
+    """
+    search_results = google_search_client.search(entity_name)
+    if not search_results:
+        return None
+
+    # Get more details from the first search result
+    first_result = search_results[0]
+    url = first_result["link"]
+    content = trafilatura.fetch_url(url)
+    if content is None:
+        # fetch_url returns None when the download fails.
+        logger.error("Entity %s: could not fetch %s", entity_name, url)
+        return None
+
+    summary = trafilatura.extract(content, include_comments=False, include_links=False)
+    if not summary:
+        # extract returns None when it finds no text; calling .strip() on
+        # None would raise AttributeError.
+        logger.error("Entity %s: no text extracted from %s", entity_name, url)
+        return None
+
+    summary = summary.strip()
+    logger.info(
+        "Entity %s search result: %s",
+        entity_name,
+        summary,
+    )
+    return summary
+
+
+async def search_google(
+    google_search_client: GoogleSearch,
+    entities: set[str],
+) -> dict:
+    """
+    Search with Google and return the contexts.
+
+    :param google_search_client: Google search client
+    :param entities: entity names to search
+    :return: mapping of entity name to extracted text; entities that yield
+        no text, or whose processing raises, are omitted
+    """
+    google_data = {}
+
+    async for entity in tqdm_async(
+        entities, desc="Searching Google", total=len(entities)
+    ):
+        try:
+            summary = await _process_single_entity(entity, google_search_client)
+            if summary:
+                google_data[entity] = summary
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error("Error processing entity %s: %s", entity, str(e))
+    return google_data
diff --git a/graphgen/operators/search_wikipedia.py b/graphgen/operators/search_wikipedia.py
deleted file mode 100644
index d3d7e28314eebd3ba72cb731290a069c96fe7e97..0000000000000000000000000000000000000000
--- a/graphgen/operators/search_wikipedia.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import asyncio
-from graphgen.models import WikiSearch, OpenAIModel
-from graphgen.models.storage.base_storage import BaseGraphStorage
-from graphgen.templates import SEARCH_JUDGEMENT_PROMPT
-from graphgen.utils import logger
-
-
-async def _process_single_entity(entity_name: str,
- description: str,
- llm_client: OpenAIModel,
- wiki_search_client: WikiSearch) -> tuple[str, None] | tuple[str, str]:
- """
- Process single entity
-
- """
- search_results = await wiki_search_client.search(entity_name)
- if not search_results:
- return entity_name, None
- examples = "\n".join(SEARCH_JUDGEMENT_PROMPT["EXAMPLES"])
- search_results.append("None of the above")
-
- search_results_str = "\n".join([f"{i + 1}. {sr}" for i, sr in enumerate(search_results)])
- prompt = SEARCH_JUDGEMENT_PROMPT["TEMPLATE"].format(
- examples=examples,
- entity_name=entity_name,
- description=description,
- search_results=search_results_str,
- )
- response = await llm_client.generate_answer(prompt)
-
- try:
- response = response.strip()
- response = int(response)
- if response < 1 or response >= len(search_results):
- response = None
- else:
- response = await wiki_search_client.summary(search_results[response - 1])
- except ValueError:
- response = None
-
- logger.info("Entity %s search result: %s response: %s", entity_name, str(search_results), response)
-
- return entity_name, response
-
-async def search_wikipedia(llm_client: OpenAIModel,
- wiki_search_client: WikiSearch,
- knowledge_graph_instance: BaseGraphStorage,) -> dict:
- """
- Search wikipedia for entities
-
- :param llm_client: LLM model
- :param wiki_search_client: wiki search client
- :param knowledge_graph_instance: knowledge graph instance
- :return: nodes with search results
- """
-
-
- nodes = await knowledge_graph_instance.get_all_nodes()
- nodes = list(nodes)
- wiki_data = {}
-
- tasks = [
- _process_single_entity(node[0].strip('"'), node[1]["description"], llm_client, wiki_search_client)
- for node in nodes
- ]
-
- for task in asyncio.as_completed(tasks):
- result = await task
- wiki_data[result[0]] = result[1]
-
- return wiki_data
diff --git a/graphgen/operators/traverse_graph.py b/graphgen/operators/traverse_graph.py
index 947033ed233311d75f349afcc40a3661c0a09bb8..da1b668544806ee2faffb49c97fd72d90a74935f 100644
--- a/graphgen/operators/traverse_graph.py
+++ b/graphgen/operators/traverse_graph.py
@@ -1,49 +1,67 @@
import asyncio
-import gradio as gr
+import gradio as gr
from tqdm.asyncio import tqdm as tqdm_async
-from graphgen.models import OpenAIModel, NetworkXStorage, TraverseStrategy, Tokenizer, JsonKVStorage
-from graphgen.templates import ANSWER_REPHRASING_PROMPT, QUESTION_GENERATION_PROMPT, MULTI_HOP_GENERATION_PROMPT
-from graphgen.utils import detect_main_language, compute_content_hash, logger
-from graphgen.operators.split_graph import get_batches_with_strategy
-
-
-async def _pre_tokenize(graph_storage: NetworkXStorage,
- tokenizer: Tokenizer,
- edges: list,
- nodes: list) -> tuple:
+from graphgen.models import (
+ JsonKVStorage,
+ NetworkXStorage,
+ OpenAIModel,
+ Tokenizer,
+ TraverseStrategy,
+)
+from graphgen.operators.kg.split_kg import get_batches_with_strategy
+from graphgen.templates import (
+ ANSWER_REPHRASING_PROMPT,
+ MULTI_HOP_GENERATION_PROMPT,
+ QUESTION_GENERATION_PROMPT,
+)
+from graphgen.utils import compute_content_hash, detect_main_language, logger
+
+
+async def _pre_tokenize(
+ graph_storage: NetworkXStorage, tokenizer: Tokenizer, edges: list, nodes: list
+) -> tuple:
sem = asyncio.Semaphore(1000)
+
async def handle_edge(edge: tuple) -> tuple:
async with sem:
- if 'length' not in edge[2]:
- edge[2]['length'] = len(
- await asyncio.get_event_loop().run_in_executor(None,
- tokenizer.encode_string,
- edge[2]['description']))
+ if "length" not in edge[2]:
+ edge[2]["length"] = len(
+ await asyncio.get_event_loop().run_in_executor(
+ None, tokenizer.encode_string, edge[2]["description"]
+ )
+ )
return edge
async def handle_node(node: dict) -> dict:
async with sem:
- if 'length' not in node[1]:
- node[1]['length'] = len(
- await asyncio.get_event_loop().run_in_executor(None,
- tokenizer.encode_string,
- node[1]['description']))
+ if "length" not in node[1]:
+ node[1]["length"] = len(
+ await asyncio.get_event_loop().run_in_executor(
+ None, tokenizer.encode_string, node[1]["description"]
+ )
+ )
return node
new_edges = []
new_nodes = []
- for result in tqdm_async(asyncio.as_completed([handle_edge(edge) for edge in edges]),
- total=len(edges), desc="Pre-tokenizing edges"):
+ for result in tqdm_async(
+ asyncio.as_completed([handle_edge(edge) for edge in edges]),
+ total=len(edges),
+ desc="Pre-tokenizing edges",
+ ):
new_edge = await result
await graph_storage.update_edge(new_edge[0], new_edge[1], new_edge[2])
new_edges.append(new_edge)
- for result in tqdm_async(asyncio.as_completed([handle_node(node) for node in nodes]),
- total=len(nodes), desc="Pre-tokenizing nodes"):
+ for result in tqdm_async(
+ asyncio.as_completed([handle_node(node) for node in nodes]),
+ total=len(nodes),
+ desc="Pre-tokenizing nodes",
+ ):
new_node = await result
await graph_storage.update_node(new_node[0], new_node[1])
new_nodes.append(new_node)
@@ -51,60 +69,75 @@ async def _pre_tokenize(graph_storage: NetworkXStorage,
await graph_storage.index_done_callback()
return new_edges, new_nodes
-async def _construct_rephrasing_prompt(_process_nodes: list,
- _process_edges: list,
- text_chunks_storage: JsonKVStorage,
- add_context: bool = False
- ) -> str:
+
+async def _construct_rephrasing_prompt(
+ _process_nodes: list,
+ _process_edges: list,
+ text_chunks_storage: JsonKVStorage,
+ add_context: bool = False,
+) -> str:
entities = [
- f"{_process_node['node_id']}: {_process_node['description']}" for _process_node in _process_nodes
+ f"{_process_node['node_id']}: {_process_node['description']}"
+ for _process_node in _process_nodes
]
relations = [
f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
for _process_edge in _process_edges
]
- entities_str = "\n".join([f"{index + 1}. {entity}" for index, entity in enumerate(entities)])
- relations_str = "\n".join([f"{index + 1}. {relation}" for index, relation in enumerate(relations)])
- language = "Chinese" if detect_main_language(entities_str + relations_str) == "zh" else "English"
+ entities_str = "\n".join(
+ [f"{index + 1}. {entity}" for index, entity in enumerate(entities)]
+ )
+ relations_str = "\n".join(
+ [f"{index + 1}. {relation}" for index, relation in enumerate(relations)]
+ )
+ language = (
+ "Chinese"
+ if detect_main_language(entities_str + relations_str) == "zh"
+ else "English"
+ )
if add_context:
- original_ids = ([node['source_id'].split('')[0] for node in _process_nodes] +
- [edge[2]['source_id'].split('')[0] for edge in _process_edges])
+ original_ids = [
+ node["source_id"].split("")[0] for node in _process_nodes
+ ] + [edge[2]["source_id"].split("")[0] for edge in _process_edges]
original_ids = list(set(original_ids))
original_text = await text_chunks_storage.get_by_ids(original_ids)
- original_text = "\n".join([f"{index + 1}. {text['content']}" for index, text in enumerate(original_text)])
+ original_text = "\n".join(
+ [
+ f"{index + 1}. {text['content']}"
+ for index, text in enumerate(original_text)
+ ]
+ )
- prompt = ANSWER_REPHRASING_PROMPT[language]['CONTEXT_TEMPLATE'].format(
+ prompt = ANSWER_REPHRASING_PROMPT[language]["CONTEXT_TEMPLATE"].format(
language=language,
original_text=original_text,
entities=entities_str,
- relationships=relations_str
+ relationships=relations_str,
)
return prompt
- prompt = ANSWER_REPHRASING_PROMPT[language]['TEMPLATE'].format(
- language=language,
- entities=entities_str,
- relationships=relations_str
+ prompt = ANSWER_REPHRASING_PROMPT[language]["TEMPLATE"].format(
+ language=language, entities=entities_str, relationships=relations_str
)
return prompt
-def get_loss_tercile(losses: list) -> (float, float):
- losses = sorted(losses)
- q1_index = int(len(losses) * (1 / 3))
- q2_index = int(len(losses) * (2 / 3))
-
- return losses[q1_index], losses[q2_index]
def get_average_loss(batch: tuple, loss_strategy: str) -> float:
- if loss_strategy == "only_edge":
- return sum(edge[2]['loss'] for edge in batch[1]) / len(batch[1])
- if loss_strategy == "both":
- return sum(edge[2]['loss'] for edge in batch[1]) + sum(node['loss'] for node in batch[0]) / \
- (len(batch[0]) + len(batch[1]))
- raise ValueError("Invalid loss strategy")
+ try:
+ if loss_strategy == "only_edge":
+ return sum(edge[2]["loss"] for edge in batch[1]) / len(batch[1])
+ if loss_strategy == "both":
+ return sum(edge[2]["loss"] for edge in batch[1]) + sum(
+ node["loss"] for node in batch[0]
+ ) / (len(batch[0]) + len(batch[1]))
+ raise ValueError("Invalid loss strategy")
+ except Exception as e: # pylint: disable=broad-except
+ logger.error("Error calculating average loss: %s", e)
+ return -1.0
+
def _post_process_synthetic_data(data):
block = data.split("\n\n")
@@ -113,26 +146,18 @@ def _post_process_synthetic_data(data):
if "Question:" in line and "Answer:" in line:
question = line.split("Question:")[1].split("Answer:")[0].strip()
answer = line.split("Answer:")[1].strip()
- qas.append({
- "question": question,
- "answer": answer
- })
+ qas.append({"question": question, "answer": answer})
elif "问题:" in line and "答案:" in line:
question = line.split("问题:")[1].split("答案:")[0].strip()
answer = line.split("答案:")[1].strip()
- qas.append({
- "question": question,
- "answer": answer
- })
+ qas.append({"question": question, "answer": answer})
elif "问题:" in line and "回答:" in line:
question = line.split("问题:")[1].split("回答:")[0].strip()
answer = line.split("回答:")[1].strip()
- qas.append({
- "question": question,
- "answer": answer
- })
+ qas.append({"question": question, "answer": answer})
return qas
+
async def traverse_graph_by_edge(
llm_client: OpenAIModel,
tokenizer: Tokenizer,
@@ -140,7 +165,7 @@ async def traverse_graph_by_edge(
traverse_strategy: TraverseStrategy,
text_chunks_storage: JsonKVStorage,
progress_bar: gr.Progress = None,
- max_concurrent: int = 1000
+ max_concurrent: int = 1000,
) -> dict:
"""
Traverse the graph
@@ -158,28 +183,24 @@ async def traverse_graph_by_edge(
semaphore = asyncio.Semaphore(max_concurrent)
async def _process_nodes_and_edges(
- _process_nodes: list,
- _process_edges: list,
+ _process_nodes: list,
+ _process_edges: list,
) -> str:
prompt = await _construct_rephrasing_prompt(
- _process_nodes,
- _process_edges,
- text_chunks_storage,
- add_context = False
+ _process_nodes, _process_edges, text_chunks_storage, add_context=False
)
context = await llm_client.generate_answer(prompt)
# post-process the context
if context.startswith("Rephrased Text:"):
- context = context[len("Rephrased Text:"):].strip()
+ context = context[len("Rephrased Text:") :].strip()
elif context.startswith("重述文本:"):
- context = context[len("重述文本:"):].strip()
+ context = context[len("重述文本:") :].strip()
return context
async def _process_single_batch(
- _process_batch: tuple,
- question_type: str = "single"
+ _process_batch: tuple, question_type: str = "single"
) -> dict:
async with semaphore:
context = await _process_nodes_and_edges(
@@ -188,21 +209,26 @@ async def traverse_graph_by_edge(
)
language = "Chinese" if detect_main_language(context) == "zh" else "English"
- pre_length = sum(node['length'] for node in _process_batch[0]) \
- + sum(edge[2]['length'] for edge in _process_batch[1])
+ pre_length = sum(node["length"] for node in _process_batch[0]) + sum(
+ edge[2]["length"] for edge in _process_batch[1]
+ )
if question_type == "single":
question = await llm_client.generate_answer(
- QUESTION_GENERATION_PROMPT[language]['SINGLE_TEMPLATE'].format(
+ QUESTION_GENERATION_PROMPT[language]["SINGLE_TEMPLATE"].format(
answer=context
)
)
if question.startswith("Question:"):
- question = question[len("Question:"):].strip()
+ question = question[len("Question:") :].strip()
elif question.startswith("问题:"):
- question = question[len("问题:"):].strip()
+ question = question[len("问题:") :].strip()
- logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
+ logger.info(
+ "%d nodes and %d edges processed",
+ len(_process_batch[0]),
+ len(_process_batch[1]),
+ )
logger.info("Pre-length: %s", pre_length)
logger.info("Question: %s", question)
logger.info("Answer: %s", context)
@@ -211,12 +237,14 @@ async def traverse_graph_by_edge(
compute_content_hash(context): {
"question": question,
"answer": context,
- "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy)
+ "loss": get_average_loss(
+ _process_batch, traverse_strategy.loss_strategy
+ ),
}
}
content = await llm_client.generate_answer(
- QUESTION_GENERATION_PROMPT[language]['MULTI_TEMPLATE'].format(
+ QUESTION_GENERATION_PROMPT[language]["MULTI_TEMPLATE"].format(
doc=context
)
)
@@ -224,19 +252,27 @@ async def traverse_graph_by_edge(
if len(qas) == 0:
print(content)
- logger.error("Error occurred while processing batch, question or answer is None")
+ logger.error(
+ "Error occurred while processing batch, question or answer is None"
+ )
return {}
final_results = {}
- logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
+ logger.info(
+ "%d nodes and %d edges processed",
+ len(_process_batch[0]),
+ len(_process_batch[1]),
+ )
logger.info("Pre-length: %s", pre_length)
for qa in qas:
- logger.info("Question: %s", qa['question'])
- logger.info("Answer: %s", qa['answer'])
- final_results[compute_content_hash(qa['question'])] = {
- "question": qa['question'],
- "answer": qa['answer'],
- "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy)
+ logger.info("Question: %s", qa["question"])
+ logger.info("Answer: %s", qa["answer"])
+ final_results[compute_content_hash(qa["question"])] = {
+ "question": qa["question"],
+ "answer": qa["answer"],
+ "loss": get_average_loss(
+ _process_batch, traverse_strategy.loss_strategy
+ ),
}
return final_results
@@ -247,22 +283,25 @@ async def traverse_graph_by_edge(
edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
processing_batches = await get_batches_with_strategy(
- nodes,
- edges,
- graph_storage,
- traverse_strategy
+ nodes, edges, graph_storage, traverse_strategy
)
- for result in tqdm_async(asyncio.as_completed(
- [_process_single_batch(batch) for batch in processing_batches]
- ), total=len(processing_batches), desc="[4/4]Generating QAs"):
+ for result in tqdm_async(
+ asyncio.as_completed(
+ [_process_single_batch(batch) for batch in processing_batches]
+ ),
+ total=len(processing_batches),
+ desc="[4/4]Generating QAs",
+ ):
try:
if progress_bar is not None:
- progress_bar(len(results) / len(processing_batches), desc="[4/4]Generating QAs")
+ progress_bar(
+ len(results) / len(processing_batches), desc="[4/4]Generating QAs"
+ )
results.update(await result)
if progress_bar is not None and len(results) == len(processing_batches):
progress_bar(1, desc="[4/4]Generating QAs")
- except Exception as e: # pylint: disable=broad-except
+ except Exception as e: # pylint: disable=broad-except
logger.error("Error occurred while generating QA: %s", e)
return results
@@ -275,7 +314,7 @@ async def traverse_graph_atomically(
traverse_strategy: TraverseStrategy,
text_chunks_storage: JsonKVStorage,
progress_bar: gr.Progress = None,
- max_concurrent: int = 1000
+ max_concurrent: int = 1000,
) -> dict:
"""
Traverse the graph atomicly
@@ -292,22 +331,21 @@ async def traverse_graph_atomically(
assert traverse_strategy.qa_form == "atomic"
semaphore = asyncio.Semaphore(max_concurrent)
- async def _generate_question(
- node_or_edge: tuple
- ):
+
+ async def _generate_question(node_or_edge: tuple):
if len(node_or_edge) == 2:
- des = node_or_edge[0] + ": " + node_or_edge[1]['description']
- loss = node_or_edge[1]['loss']
+ des = node_or_edge[0] + ": " + node_or_edge[1]["description"]
+ loss = node_or_edge[1]["loss"]
else:
- des = node_or_edge[2]['description']
- loss = node_or_edge[2]['loss']
+ des = node_or_edge[2]["description"]
+ loss = node_or_edge[2]["loss"]
async with semaphore:
try:
language = "Chinese" if detect_main_language(des) == "zh" else "English"
qa = await llm_client.generate_answer(
- QUESTION_GENERATION_PROMPT[language]['SINGLE_QA_TEMPLATE'].format(
+ QUESTION_GENERATION_PROMPT[language]["SINGLE_QA_TEMPLATE"].format(
doc=des
)
)
@@ -321,8 +359,8 @@ async def traverse_graph_atomically(
else:
return {}
- question = question.strip("\"")
- answer = answer.strip("\"")
+ question = question.strip('"')
+ answer = answer.strip('"')
logger.info("Question: %s", question)
logger.info("Answer: %s", answer)
@@ -330,10 +368,10 @@ async def traverse_graph_atomically(
compute_content_hash(question): {
"question": question,
"answer": answer,
- "loss": loss
+ "loss": loss,
}
}
- except Exception as e: # pylint: disable=broad-except
+ except Exception as e: # pylint: disable=broad-except
logger.error("Error occurred while generating question: %s", e)
return {}
@@ -345,24 +383,26 @@ async def traverse_graph_atomically(
tasks = []
for node in nodes:
- if "" in node[1]['description']:
- description_list = node[1]['description'].split("")
+ if "" in node[1]["description"]:
+ description_list = node[1]["description"].split("")
for item in description_list:
- tasks.append((node[0], {"description": item, 'loss': node[1]['loss']}))
+ tasks.append((node[0], {"description": item, "loss": node[1]["loss"]}))
else:
tasks.append((node[0], node[1]))
for edge in edges:
- if "" in edge[2]['description']:
- description_list = edge[2]['description'].split("")
+ if "" in edge[2]["description"]:
+ description_list = edge[2]["description"].split("")
for item in description_list:
- tasks.append((edge[0], edge[1], {"description": item, 'loss': edge[2]['loss']}))
+ tasks.append(
+ (edge[0], edge[1], {"description": item, "loss": edge[2]["loss"]})
+ )
else:
tasks.append((edge[0], edge[1], edge[2]))
for result in tqdm_async(
asyncio.as_completed([_generate_question(task) for task in tasks]),
total=len(tasks),
- desc="[4/4]Generating QAs"
+ desc="[4/4]Generating QAs",
):
try:
if progress_bar is not None:
@@ -370,10 +410,11 @@ async def traverse_graph_atomically(
results.update(await result)
if progress_bar is not None and len(results) == len(tasks):
progress_bar(1, desc="[4/4]Generating QAs")
- except Exception as e: # pylint: disable=broad-except
+ except Exception as e: # pylint: disable=broad-except
logger.error("Error occurred while generating QA: %s", e)
return results
+
async def traverse_graph_for_multi_hop(
llm_client: OpenAIModel,
tokenizer: Tokenizer,
@@ -381,7 +422,7 @@ async def traverse_graph_for_multi_hop(
traverse_strategy: TraverseStrategy,
text_chunks_storage: JsonKVStorage,
progress_bar: gr.Progress = None,
- max_concurrent: int = 1000
+ max_concurrent: int = 1000,
) -> dict:
"""
Traverse the graph for multi-hop
@@ -395,8 +436,6 @@ async def traverse_graph_for_multi_hop(
:param max_concurrent
:return: question and answer
"""
- assert traverse_strategy.qa_form == "multi_hop"
-
semaphore = asyncio.Semaphore(max_concurrent)
results = {}
@@ -406,24 +445,24 @@ async def traverse_graph_for_multi_hop(
edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
processing_batches = await get_batches_with_strategy(
- nodes,
- edges,
- graph_storage,
- traverse_strategy
+ nodes, edges, graph_storage, traverse_strategy
)
- async def _process_single_batch(
- _process_batch: tuple
- ) -> dict:
+ async def _process_single_batch(_process_batch: tuple) -> dict:
async with semaphore:
try:
- language = "Chinese" if detect_main_language(_process_batch[0][0]['description']) == "zh" else "English"
+ language = (
+ "Chinese"
+ if detect_main_language(_process_batch[0][0]["description"]) == "zh"
+ else "English"
+ )
_process_nodes = _process_batch[0]
_process_edges = _process_batch[1]
entities = [
- f"{_process_node['node_id']}: {_process_node['description']}" for _process_node in _process_nodes
+ f"{_process_node['node_id']}: {_process_node['description']}"
+ for _process_node in _process_nodes
]
relations = [
@@ -431,12 +470,18 @@ async def traverse_graph_for_multi_hop(
for _process_edge in _process_edges
]
- entities_str = "\n".join([f"{index + 1}. {entity}" for index, entity in enumerate(entities)])
- relations_str = "\n".join([f"{index + 1}. {relation}" for index, relation in enumerate(relations)])
+ entities_str = "\n".join(
+ [f"{index + 1}. {entity}" for index, entity in enumerate(entities)]
+ )
+ relations_str = "\n".join(
+ [
+ f"{index + 1}. {relation}"
+ for index, relation in enumerate(relations)
+ ]
+ )
prompt = MULTI_HOP_GENERATION_PROMPT[language].format(
- entities=entities_str,
- relationships=relations_str
+ entities=entities_str, relationships=relations_str
)
context = await llm_client.generate_answer(prompt)
@@ -451,8 +496,8 @@ async def traverse_graph_for_multi_hop(
else:
return {}
- question = question.strip("\"")
- answer = answer.strip("\"")
+ question = question.strip('"')
+ answer = answer.strip('"')
logger.info("Question: %s", question)
logger.info("Answer: %s", answer)
@@ -461,25 +506,31 @@ async def traverse_graph_for_multi_hop(
compute_content_hash(question): {
"question": question,
"answer": answer,
- "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy),
+ "loss": get_average_loss(
+ _process_batch, traverse_strategy.loss_strategy
+ ),
}
}
- except Exception as e: # pylint: disable=broad-except
+ except Exception as e: # pylint: disable=broad-except
logger.error("Error occurred while processing batch: %s", e)
return {}
async for result in tqdm_async(
- asyncio.as_completed([_process_single_batch(batch) for batch in processing_batches]),
+ asyncio.as_completed(
+ [_process_single_batch(batch) for batch in processing_batches]
+ ),
total=len(processing_batches),
- desc="[4/4]Generating QAs"
+ desc="[4/4]Generating QAs",
):
try:
if progress_bar is not None:
- progress_bar(len(results) / len(processing_batches), desc="[4/4]Generating QAs")
+ progress_bar(
+ len(results) / len(processing_batches), desc="[4/4]Generating QAs"
+ )
results.update(await result)
if progress_bar is not None and len(results) == len(processing_batches):
progress_bar(1, desc="[4/4]Generating QAs")
- except Exception as e: # pylint: disable=broad-except
+ except Exception as e: # pylint: disable=broad-except
logger.error("Error occurred while generating QA: %s", e)
return results
diff --git a/graphgen/templates/__init__.py b/graphgen/templates/__init__.py
index 6e362d081792fa6beaf74f3ea1ca8fa846a3bf1b..a3d1e9ed5dfd20f0f08cb6c39f40bb1794b80ca4 100644
--- a/graphgen/templates/__init__.py
+++ b/graphgen/templates/__init__.py
@@ -1,9 +1,10 @@
+from .answer_rephrasing import ANSWER_REPHRASING_PROMPT
+from .community import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
+from .coreference_resolution import COREFERENCE_RESOLUTION_PROMPT
+from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT
from .kg_extraction import KG_EXTRACTION_PROMPT
from .kg_summarization import KG_SUMMARIZATION_PROMPT
+from .multi_hop_generation import MULTI_HOP_GENERATION_PROMPT
+from .question_generation import QUESTION_GENERATION_PROMPT
from .search_judgement import SEARCH_JUDGEMENT_PROMPT
-from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT
from .statement_judgement import STATEMENT_JUDGEMENT_PROMPT
-from .answer_rephrasing import ANSWER_REPHRASING_PROMPT
-from .question_generation import QUESTION_GENERATION_PROMPT
-from .multi_hop_generation import MULTI_HOP_GENERATION_PROMPT
-from .coreference_resolution import COREFERENCE_RESOLUTION_TEMPLATE
diff --git a/graphgen/templates/answer_rephrasing.py b/graphgen/templates/answer_rephrasing.py
index a33e9d9e0b0e2278fb3dec1a7fc144784b995887..fc988fa25edeedca98674268c9403c50f2ebb995 100644
--- a/graphgen/templates/answer_rephrasing.py
+++ b/graphgen/templates/answer_rephrasing.py
@@ -1,5 +1,4 @@
TEMPLATE_CONTEXT_EN: str = """---Role---
-
You are an NLP expert responsible for generating a logically structured and coherent rephrased version of the TEXT based on ENTITIES and RELATIONSHIPS provided below. You may refer to the original text to assist in generating the rephrased version, but ensure that the final output text meets the requirements.
Use {language} as output language.
@@ -51,12 +50,10 @@ To generate a version of the text that is rephrased and conveys the same meaning
"""
TEMPLATE_CONTEXT_ZH: str = """---角色---
-
你是一位NLP专家,负责根据下面提供的实体和关系生成逻辑结构清晰且连贯的文本重述版本。你可以参考原始文本辅助生成,但需要确保最终输出的文本符合要求。
使用{language}作为输出语言。
---目标---
-
生成文本的重述版本,使其传达与原始实体和关系描述相同的含义,同时:
1. 遵循清晰的逻辑流和结构
2. 建立适当的因果关系
@@ -101,7 +98,6 @@ TEMPLATE_CONTEXT_ZH: str = """---角色---
"""
TEMPLATE_EN: str = """---Role---
-
You are an NLP expert responsible for generating a logically structured and coherent rephrased version of the TEXT based on ENTITIES and RELATIONSHIPS provided below.
Use {language} as output language.
@@ -148,12 +144,10 @@ To generate a version of the text that is rephrased and conveys the same meaning
"""
TEMPLATE_ZH: str = """---角色---
-
你是一位NLP专家,负责根据下面提供的实体和关系生成逻辑结构清晰且连贯的文本重述版本。
使用{language}作为输出语言。
---目标---
-
生成文本的重述版本,使其传达与原始实体和关系描述相同的含义,同时:
1. 遵循清晰的逻辑流和结构
2. 建立适当的因果关系
@@ -207,13 +201,13 @@ Rephrased Text:
"""
-ANSWER_REPHRASING_PROMPT= {
+ANSWER_REPHRASING_PROMPT = {
"English": {
"TEMPLATE": TEMPLATE_EN + REQUIREMENT_EN,
- "CONTEXT_TEMPLATE": TEMPLATE_CONTEXT_EN + REQUIREMENT_EN
+ "CONTEXT_TEMPLATE": TEMPLATE_CONTEXT_EN + REQUIREMENT_EN,
},
"Chinese": {
"TEMPLATE": TEMPLATE_ZH + REQUIREMENT_ZH,
- "CONTEXT_TEMPLATE": TEMPLATE_CONTEXT_ZH + REQUIREMENT_ZH
- }
+ "CONTEXT_TEMPLATE": TEMPLATE_CONTEXT_ZH + REQUIREMENT_ZH,
+ },
}
diff --git a/graphgen/templates/community/__init__.py b/graphgen/templates/community/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4721d03e285fe2ef43778e808cf2481c86ddb78e
--- /dev/null
+++ b/graphgen/templates/community/__init__.py
@@ -0,0 +1,2 @@
+from .cot_generation import COT_GENERATION_PROMPT
+from .cot_template_design import COT_TEMPLATE_DESIGN_PROMPT
diff --git a/graphgen/templates/community/cot_generation.py b/graphgen/templates/community/cot_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..0494cd805d605d5ebdf5dd24b06bd90111c2a292
--- /dev/null
+++ b/graphgen/templates/community/cot_generation.py
@@ -0,0 +1,87 @@
+TEMPLATE_ZH = """根据给定的知识图谱原始信息及已生成的推理路径,产出一条符合模板要求、可直接用于下游训练或推理的 CoT 数据。\
+CoT(Chain-of-Thought,思维链)指在回答复杂问题时,把中间推理步骤一步一步显式写出来,使推理过程透明、可追溯,而不是直接给出最终答案。
+
+-输入格式-
+[Entities:]
+(实体名:实体描述)
+...
+
+[Relationships:]
+(来源实体)-[关系描述]->(目标实体)
+...
+
+[Question and Reasoning Path:]
+(问题)
+(推理路径)
+
+-输出要求-
+1. 每一步只完成一个不可分割的子任务,并用自然语言衔接,但是要避免生硬的连接词。
+2. 使用中文。
+3. 不要使用有序列表或编号。
+4. 请直接给出答案,不要生成无关信息。
+
+-真实数据-
+输入:
+[Entities:]:
+{entities}
+
+[Relationships:]:
+{relationships}
+
+[Question:]:
+{question}
+
+[Reasoning_Template:]:
+{reasoning_template}
+
+输出:
+
+"""
+
+TEMPLATE_EN = """Given the raw knowledge graph information and the provided reasoning-path, \
+produce one Chain-of-Thought (CoT) sample that strictly follows the template \
+and can be directly used for downstream training or inference.
+CoT (Chain-of-Thought) means that when answering a complex question, the intermediate reasoning steps are \
+explicitly written out one by one, making the reasoning process transparent and traceable instead of giving \
+only the final answer.
+
+-Input Format-
+[Entities:]:
+(ENTITY_NAME: ENTITY_DESCRIPTION)
+...
+
+[Relationships:]:
+(ENTITY_SOURCE)-[RELATIONSHIP_DESCRIPTION]->(ENTITY_TARGET)
+...
+
+[Question and Reasoning Path:]:
+(QUESTION)
+(REASONING_PATH)
+
+-Output Requirements-
+1. Each step completes a single, indivisible sub-task and is naturally connected, avoiding abrupt transition words.
+2. Use English.
+3. Do not use ordered lists or numbering.
+4. Do not generate extraneous information, just provide the answer.
+
+-Real Data-
+Input:
+[Entities:]:
+{entities}
+
+[Relationships:]:
+{relationships}
+
+[Question:]:
+{question}
+
+[Reasoning_Template:]:
+{reasoning_template}
+
+Output:
+"""
+
+COT_GENERATION_PROMPT = {
+ "Chinese": {"TEMPLATE": TEMPLATE_ZH},
+ "English": {"TEMPLATE": TEMPLATE_EN},
+}
diff --git a/graphgen/templates/community/cot_template_design.py b/graphgen/templates/community/cot_template_design.py
new file mode 100644
index 0000000000000000000000000000000000000000..04cfa2309c7f035124477084734a38c1b9a6a5d1
--- /dev/null
+++ b/graphgen/templates/community/cot_template_design.py
@@ -0,0 +1,107 @@
+TEMPLATE_ZH = """你是一位“元推理架构师”。你的任务不是回答问题,\
+而是根据给定的知识图谱中的实体和关系的名称以及描述信息,设计一条可复用、可泛化的 CoT 推理路径模板。\
+
+-步骤-
+1. 实体识别
+- 准确地识别[Entities:]章节中的实体信息,包括实体名、实体描述信息。
+- 实体信息的一般格式为:
+(实体名:实体描述)
+
+2. 关系识别
+- 准确地识别[Relationships:]章节中的关系信息,包括来源实体名、目标实体名、关系描述信息。
+- 关系信息的一般格式为:
+(来源实体名)-[关系描述]->(目标实体名)
+
+3. 图结构理解
+- 正确地将关系信息中的来源实体名与实体信息关联。
+- 根据提供的关系信息还原出图结构。
+
+4. 问题设计
+- 围绕知识图谱所表达的“核心主题”设计一个问题。
+- 问题必须能在图谱内部通过实体、关系或属性直接验证;避免主观判断。
+- 问题应该能够让模型进行足够的思考,充分利用图谱中的实体和关系,避免过于简单或无关的问题。
+
+5. 推理路径生成
+- 根据问题设计一个**可被后续模型直接执行的推理蓝图**。
+- 保持步骤最小化:每一步只解决一个“不可分割”的子问题。
+
+-约束条件-
+1. 不要在回答中描述你的思考过程,直接给出回复,只给出问题和推理路径设计,不要生成无关信息。
+2. 如果提供的描述信息相互矛盾,请解决矛盾并提供一个单一、连贯的逻辑。
+3. 避免使用停用词和过于常见的词汇。
+4. 不要出现具体数值或结论,不要出现“识别实体”、“识别关系”这类无意义的操作描述。
+5. 使用中文作为输出语言。
+6. 输出格式为:
+问题:
+推理路径设计:
+
+-真实数据-
+输入:
+[Entities:]:
+{entities}
+
+[Relationships:]:
+{relationships}
+
+输出:
+"""
+
+
+TEMPLATE_EN = """You are a “meta-reasoning architect”. \
+Your task is NOT to answer the question, but to design a reusable, generalizable CoT reasoning-path \
+template based solely on the names and descriptions of entities and \
+relationships in the provided knowledge graph.
+
+- Steps -
+1. Entity Recognition
+- Accurately recognize entity information in the [Entities:] section, including entity names and descriptions.
+- The general formats for entity information are:
+(ENTITY_NAME: ENTITY_DESCRIPTION)
+
+2. Relationship Recognition
+- Accurately recognize relationship information in the [Relationships:] section, including source_entity_name, target_entity_name, and relationship descriptions.
+- The general formats for relationship information are:
+(SOURCE_ENTITY_NAME)-[RELATIONSHIP_DESCRIPTION]->(TARGET_ENTITY_NAME)
+
+3. Graph Structure Understanding
+- Correctly associate the source entity name in the relationship information with the entity information.
+- Reconstruct the graph structure based on the provided relationship information.
+
+4. Question Design
+- Design a question around the "core theme" expressed by the knowledge graph.
+- The question must be verifiable directly within the graph through entities, relationships, or attributes; avoid subjective judgments.
+- The question should allow the model to think sufficiently, fully utilizing the entities and relationships in the graph, avoiding overly simple or irrelevant questions.
+
+5. Reasoning-Path Design
+- Output a **blueprint that any later model can directly execute**.
+- Keep steps minimal: each step solves one indivisible sub-problem.
+
+
+- Constraints -
+1. Do NOT describe your thinking; output only the question and the reasoning-path design.
+2. If the provided descriptions are contradictory, resolve conflicts and provide a single coherent logic.
+3. Avoid using stop words and overly common words.
+4. Do not include specific numerical values or conclusions, \
+and DO NOT describe meaningless operations like "Identify the entity" or "Identify the relationship".
+5. Use English as the output language.
+6. The output format is:
+Question:
+Reasoning-Path Design:
+
+Please design the question and the reasoning path based on the following [Entities:] and [Relationships:] provided.
+
+- Real Data -
+Input:
+[Entities:]:
+{entities}
+
+[Relationships:]:
+{relationships}
+
+Output:
+"""
+
+COT_TEMPLATE_DESIGN_PROMPT = {
+ "Chinese": {"TEMPLATE": TEMPLATE_ZH},
+ "English": {"TEMPLATE": TEMPLATE_EN},
+}
diff --git a/graphgen/templates/coreference_resolution.py b/graphgen/templates/coreference_resolution.py
index b29394ad6b6916fa95e6ce35e4df20cb6981c801..bc03e671411b6c1abd9b75ea9c1dbb976d7d331e 100644
--- a/graphgen/templates/coreference_resolution.py
+++ b/graphgen/templates/coreference_resolution.py
@@ -1,4 +1,3 @@
-# pylint: disable=C0301
TEMPLATE_ZH: str = """请根据参考文本识别并消解文本中的指代词,明确每个代词所指代的具体实体,并直接输出消解后的文本。
-示例-
@@ -16,7 +15,8 @@ TEMPLATE_ZH: str = """请根据参考文本识别并消解文本中的指代词
输出:
"""
-TEMPLATE_EN: str = """Please identify and resolve the pronouns in the reference text, specify the specific entities referred to by each pronoun, and directly output the resolved text.
+TEMPLATE_EN: str = """Please identify and resolve the pronouns in the reference text, \
+specify the specific entities referred to by each pronoun, and directly output the resolved text.
-Example-
Input:
@@ -33,7 +33,4 @@ Please directly output the rewritten sentence without any additional information
Output:
"""
-COREFERENCE_RESOLUTION_TEMPLATE = {
- "en": TEMPLATE_EN,
- "zh": TEMPLATE_ZH
-}
+COREFERENCE_RESOLUTION_PROMPT = {"en": TEMPLATE_EN, "zh": TEMPLATE_ZH}
diff --git a/graphgen/templates/search_judgement.py b/graphgen/templates/search_judgement.py
index ca9e7e12fae76e27743edb4c753b60484aafd5d5..e85b00974990959b07db5d2e44fdc209f5435667 100644
--- a/graphgen/templates/search_judgement.py
+++ b/graphgen/templates/search_judgement.py
@@ -17,7 +17,7 @@ Steps:
################
-Examples-
################
-{examples}
+{input_examples}
################
-Real Data-
diff --git a/graphgen/utils/__init__.py b/graphgen/utils/__init__.py
index 932f8df1c1ca1444bb0225060e2dadbe690223ce..b3c8e1e646661739ae64e8accedb8f515a4b220c 100644
--- a/graphgen/utils/__init__.py
+++ b/graphgen/utils/__init__.py
@@ -1,9 +1,16 @@
-from .log import logger, set_logger, parse_log
-from .loop import create_event_loop
-from .format import (pack_history_conversations, split_string_by_multi_markers,
- handle_single_entity_extraction, handle_single_relationship_extraction,
- load_json, write_json)
-from .hash import compute_content_hash, compute_args_hash
-from .detect_lang import detect_main_language, detect_if_chinese
from .calculate_confidence import yes_no_loss_entropy
+from .detect_lang import detect_if_chinese, detect_main_language
+from .file import read_file
+from .format import (
+ format_generation_results,
+ handle_single_entity_extraction,
+ handle_single_relationship_extraction,
+ load_json,
+ pack_history_conversations,
+ split_string_by_multi_markers,
+ write_json,
+)
+from .hash import compute_args_hash, compute_content_hash
from .help_nltk import NLTKHelper
+from .log import logger, parse_log, set_logger
+from .loop import create_event_loop
diff --git a/graphgen/utils/file.py b/graphgen/utils/file.py
new file mode 100644
index 0000000000000000000000000000000000000000..1129861677b4a0f8377b807ba4aca435b5253112
--- /dev/null
+++ b/graphgen/utils/file.py
@@ -0,0 +1,24 @@
+import json
+
+
+def read_file(input_file: str) -> list:
+ """
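+ Read data from a file, choosing the parser from the file extension (.jsonl, .json or .txt).
+ :param input_file: path to the input file; any other extension raises ValueError
+ :return: one parsed JSON value per line for .jsonl, the json.load result for .json, or a {"content": line} dict per non-empty line for .txt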
+ """
+
+ if input_file.endswith(".jsonl"):
+ with open(input_file, "r", encoding="utf-8") as f:
+ data = [json.loads(line) for line in f]
+ elif input_file.endswith(".json"):
+ with open(input_file, "r", encoding="utf-8") as f:
+ data = json.load(f)
+ elif input_file.endswith(".txt"):
+ with open(input_file, "r", encoding="utf-8") as f:
+ data = [line.strip() for line in f if line.strip()]
+ data = [{"content": line} for line in data]
+ else:
+ raise ValueError(f"Unsupported file format: {input_file}")
+
+ return data
diff --git a/graphgen/utils/format.py b/graphgen/utils/format.py
index 0a0c101d4a5b3badf7bf33f03d1e34cf86b27164..abc34c874a5b413a478e513d9f5109241f36c8a8 100644
--- a/graphgen/utils/format.py
+++ b/graphgen/utils/format.py
@@ -1,16 +1,19 @@
-import re
-import os
-import json
import html
-
+import json
+import os
+import re
from typing import Any
+from .log import logger
+
+
def pack_history_conversations(*args: str):
roles = ["user", "assistant"]
return [
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
]
+
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
"""Split a string by multiple markers"""
if not markers:
@@ -18,6 +21,7 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
results = re.split("|".join(re.escape(marker) for marker in markers), content)
return [r.strip() for r in results if r.strip()]
+
# Refer the utils functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
@@ -30,6 +34,7 @@ def clean_str(input: Any) -> str:
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
+
async def handle_single_entity_extraction(
record_attributes: list[str],
chunk_key: str,
@@ -50,9 +55,11 @@ async def handle_single_entity_extraction(
"source_id": entity_source_id,
}
+
def is_float_regex(value):
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
+
async def handle_single_relationship_extraction(
record_attributes: list[str],
chunk_key: str,
@@ -72,14 +79,56 @@ async def handle_single_relationship_extraction(
"source_id": edge_source_id,
}
+
def load_json(file_name):
if not os.path.exists(file_name):
return None
with open(file_name, encoding="utf-8") as f:
return json.load(f)
+
def write_json(json_obj, file_name):
if not os.path.exists(os.path.dirname(file_name)):
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=4, ensure_ascii=False)
+
+
+def format_generation_results(
+ results: dict[str, Any], output_data_format: str
+) -> list[dict[str, Any]]:
+ if output_data_format == "Alpaca":
+ logger.info("Output data format: Alpaca")
+ results = [
+ {
+ "instruction": item["question"],
+ "input": "",
+ "output": item["answer"],
+ }
+ for item in list(results.values())
+ ]
+ elif output_data_format == "Sharegpt":
+ logger.info("Output data format: Sharegpt")
+ results = [
+ {
+ "conversations": [
+ {"from": "human", "value": item["question"]},
+ {"from": "gpt", "value": item["answer"]},
+ ]
+ }
+ for item in list(results.values())
+ ]
+ elif output_data_format == "ChatML":
+ logger.info("Output data format: ChatML")
+ results = [
+ {
+ "messages": [
+ {"role": "user", "content": item["question"]},
+ {"role": "assistant", "content": item["answer"]},
+ ]
+ }
+ for item in list(results.values())
+ ]
+ else:
+ raise ValueError(f"Unknown output data format: {output_data_format}")
+ return results
diff --git a/hf-repo/LICENSE b/hf-repo/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac
--- /dev/null
+++ b/hf-repo/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/hf-repo/app.py b/hf-repo/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..10a914cefb4727756c1d93ec98fb20e3fbf0c081
--- /dev/null
+++ b/hf-repo/app.py
@@ -0,0 +1,586 @@
+import json
+import os
+import sys
+import tempfile
+
+import gradio as gr
+import pandas as pd
+from base import GraphGenParams
+from cache_utils import cleanup_workspace, setup_workspace
+from count_tokens import count_tokens
+from gradio_i18n import Translate
+from gradio_i18n import gettext as _
+from test_api import test_api_connection
+
+# pylint: disable=wrong-import-position
+root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(root_dir)
+
+from graphgen.graphgen import GraphGen
+from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy
+from graphgen.models.llm.limitter import RPM, TPM
+from graphgen.utils import set_logger
+
+css = """
+.center-row {
+ display: flex;
+ justify-content: center;
+ align-items: center;
+}
+"""
+
+
+def init_graph_gen(config: dict, env: dict) -> GraphGen:
+    """Build a GraphGen instance wired up for one web-UI run.
+
+    Args:
+        config: run settings; ``"tokenizer"`` and the nested
+            ``"traverse_strategy"`` mapping are read here.
+        env: connection settings (``*_MODEL``, ``*_BASE_URL``, ``*_API_KEY``
+            for the synthesizer and the trainee, plus ``RPM`` / ``TPM``).
+
+    Returns:
+        The GraphGen object with both LLM clients, the tokenizer and the
+        traversal strategy assigned.
+    """
+    # Fresh workspace under <root>/cache; logging goes to its log file only.
+    log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
+    set_logger(log_file, if_stream=False)
+
+    # NOTE(review): GraphGen is constructed without `config` — confirm its
+    # initializer tolerates an empty configuration.
+    graph_gen = GraphGen(working_dir=working_dir)
+
+    # Both clients share the same rate-limit settings but get their own
+    # limiter objects.
+    synthesizer_kwargs = {
+        "model_name": env.get("SYNTHESIZER_MODEL", ""),
+        "base_url": env.get("SYNTHESIZER_BASE_URL", ""),
+        "api_key": env.get("SYNTHESIZER_API_KEY", ""),
+        "request_limit": True,
+        "rpm": RPM(env.get("RPM", 1000)),
+        "tpm": TPM(env.get("TPM", 50000)),
+    }
+    graph_gen.synthesizer_llm_client = OpenAIModel(**synthesizer_kwargs)
+
+    trainee_kwargs = {
+        "model_name": env.get("TRAINEE_MODEL", ""),
+        "base_url": env.get("TRAINEE_BASE_URL", ""),
+        "api_key": env.get("TRAINEE_API_KEY", ""),
+        "request_limit": True,
+        "rpm": RPM(env.get("RPM", 1000)),
+        "tpm": TPM(env.get("TPM", 50000)),
+    }
+    graph_gen.trainee_llm_client = OpenAIModel(**trainee_kwargs)
+
+    tokenizer_name = config.get("tokenizer", "cl100k_base")
+    graph_gen.tokenizer_instance = Tokenizer(tokenizer_name)
+
+    # Every option is forwarded as-is (None when absent), except
+    # loss_strategy, which is always passed through str().
+    strategy_config = config.get("traverse_strategy", {})
+    passthrough_keys = (
+        "qa_form",
+        "expand_method",
+        "bidirectional",
+        "max_extra_edges",
+        "max_tokens",
+        "max_depth",
+        "edge_sampling",
+        "isolated_node_strategy",
+    )
+    strategy_kwargs = {key: strategy_config.get(key) for key in passthrough_keys}
+    strategy_kwargs["loss_strategy"] = str(strategy_config.get("loss_strategy"))
+    graph_gen.traverse_strategy = TraverseStrategy(**strategy_kwargs)
+
+    return graph_gen
+
+
+# pylint: disable=too-many-statements
+def _load_input_data(file, chunk_size):
+    """Read the uploaded file and return ``(records, data_type)``.
+
+    ``.jsonl`` files are parsed one JSON value per line and ``.txt`` files are
+    cut into pieces of ``chunk_size`` characters; both are reported as
+    ``"raw"``. ``.json`` files are loaded as a single document and reported
+    as ``"chunked"``.
+
+    Raises:
+        ValueError: if the extension is not one of the three above.
+    """
+    if file.endswith(".jsonl"):
+        with open(file, "r", encoding="utf-8") as f:
+            # A blank line is not valid JSON; skip it instead of aborting.
+            return [json.loads(line) for line in f if line.strip()], "raw"
+
+    if file.endswith(".json"):
+        with open(file, "r", encoding="utf-8") as f:
+            return list(json.load(f)), "chunked"
+
+    if file.endswith(".txt"):
+        # Flatten the text onto one line, then cut it into fixed-size chunks
+        # in the raw record format.
+        with open(file, "r", encoding="utf-8") as f:
+            content = "".join(line.strip() + " " for line in f)
+        size = int(chunk_size)
+        return [
+            {"content": content[i : i + size]} for i in range(0, len(content), size)
+        ], "raw"
+
+    raise ValueError(f"Unsupported file type: {file}")
+
+
+def run_graphgen(params, progress=gr.Progress()):
+    """Run the GraphGen pipeline for one web-UI request.
+
+    Args:
+        params: object carrying the UI values (model endpoints, API keys,
+            traversal options, uploaded file, current token table).
+        progress: Gradio progress tracker handed to the GraphGen instance.
+
+    Returns:
+        ``(output_file, token_table)``: the path of a temporary file holding
+        the generated data, and a ``gr.DataFrame`` whose third column is the
+        total number of tokens used.
+
+    Raises:
+        gr.Error: when no file was uploaded or any pipeline step fails.
+    """
+
+    def sum_tokens(client):
+        return sum(u["total_tokens"] for u in client.token_usage)
+
+    # Fail fast, before any API call or workspace creation, when nothing was
+    # uploaded. Previously this surfaced later as an AttributeError on None.
+    input_file = params.input_file
+    if isinstance(input_file, list):
+        input_file = input_file[0] if input_file else None
+    if not input_file:
+        raise gr.Error("Please upload an input file before running GraphGen.")
+
+    config = {
+        "if_trainee_model": params.if_trainee_model,
+        "input_file": params.input_file,
+        "tokenizer": params.tokenizer,
+        "quiz_samples": params.quiz_samples,
+        "traverse_strategy": {
+            "qa_form": params.qa_form,
+            "bidirectional": params.bidirectional,
+            "expand_method": params.expand_method,
+            "max_extra_edges": params.max_extra_edges,
+            "max_tokens": params.max_tokens,
+            "max_depth": params.max_depth,
+            "edge_sampling": params.edge_sampling,
+            "isolated_node_strategy": params.isolated_node_strategy,
+            "loss_strategy": params.loss_strategy,
+        },
+        "chunk_size": params.chunk_size,
+    }
+
+    env = {
+        "SYNTHESIZER_BASE_URL": params.synthesizer_url,
+        "SYNTHESIZER_MODEL": params.synthesizer_model,
+        "TRAINEE_BASE_URL": params.trainee_url,
+        "TRAINEE_MODEL": params.trainee_model,
+        "SYNTHESIZER_API_KEY": params.api_key,
+        "TRAINEE_API_KEY": params.trainee_api_key,
+        "RPM": params.rpm,
+        "TPM": params.tpm,
+    }
+
+    # Test API connection
+    test_api_connection(
+        env["SYNTHESIZER_BASE_URL"],
+        env["SYNTHESIZER_API_KEY"],
+        env["SYNTHESIZER_MODEL"],
+    )
+    if config["if_trainee_model"]:
+        test_api_connection(
+            env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]
+        )
+
+    # Initialize GraphGen
+    graph_gen = init_graph_gen(config, env)
+    graph_gen.clear()
+
+    graph_gen.progress_bar = progress
+
+    try:
+        # Load input data
+        data, data_type = _load_input_data(input_file, config.get("chunk_size", 512))
+
+        # Process the data
+        graph_gen.insert(data, data_type)
+
+        if config["if_trainee_model"]:
+            # Generate quiz
+            graph_gen.quiz(max_samples=config["quiz_samples"])
+
+            # Judge statements
+            graph_gen.judge()
+        else:
+            # Without a trainee model, edges are sampled at random.
+            graph_gen.traverse_strategy.edge_sampling = "random"
+            # Skip judge statements
+            graph_gen.judge(skip=True)
+
+        # Traverse graph
+        graph_gen.traverse(traverse_strategy=graph_gen.traverse_strategy)
+
+        # Save output
+        # NOTE(review): the suffix is ".jsonl" but one JSON document is
+        # written — confirm which format consumers expect.
+        output_data = graph_gen.qa_storage.data
+        with tempfile.NamedTemporaryFile(
+            mode="w", suffix=".jsonl", delete=False, encoding="utf-8"
+        ) as tmpfile:
+            json.dump(output_data, tmpfile, ensure_ascii=False)
+            output_file = tmpfile.name
+
+        synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
+        trainee_tokens = (
+            sum_tokens(graph_gen.trainee_llm_client)
+            if config["if_trainee_model"]
+            else 0
+        )
+        total_tokens = synthesizer_tokens + trainee_tokens
+
+        # Keep the first two cells of the existing token table and fill the
+        # third with the measured usage.
+        data_frame = params.token_counter
+        try:
+            updated_rows = [
+                [data_frame.iloc[0, 0], data_frame.iloc[0, 1], str(total_tokens)]
+            ]
+            data_frame = pd.DataFrame(updated_rows, columns=data_frame.columns)
+        except Exception as e:
+            raise gr.Error(f"DataFrame operation error: {str(e)}") from e
+
+        # NOTE(review): the UI table is declared with the header
+        # "Estimated Token Usage" while this one says "Expected Token Usage"
+        # — confirm which label is intended.
+        return output_file, gr.DataFrame(
+            label="Token Stats",
+            headers=["Source Text Token Count", "Expected Token Usage", "Token Used"],
+            datatype="str",
+            interactive=False,
+            value=data_frame,
+            visible=True,
+            wrap=True,
+        )
+
+    except gr.Error:
+        # Already user-facing; re-raise unchanged instead of wrapping it again.
+        raise
+
+    except Exception as e:  # pylint: disable=broad-except
+        raise gr.Error(f"Error occurred: {str(e)}") from e
+
+    finally:
+        # Clean up workspace
+        cleanup_workspace(graph_gen.working_dir)
+
+
+with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
+ # Header
+ gr.Image(
+ value=os.path.join(root_dir, "resources", "images", "logo.png"),
+ label="GraphGen Banner",
+ elem_id="banner",
+ interactive=False,
+ container=False,
+ show_download_button=False,
+ show_fullscreen_button=False,
+ )
+ lang_btn = gr.Radio(
+ choices=[
+ ("English", "en"),
+ ("简体中文", "zh"),
+ ],
+ value="en",
+ # label=_("Language"),
+ render=False,
+ container=False,
+ elem_classes=["center-row"],
+ )
+
+ gr.HTML(
+ """
+
+ """
+ )
+ with Translate(
+ os.path.join(root_dir, "webui", "translation.json"),
+ lang_btn,
+ placeholder_langs=["en", "zh"],
+ persistant=False, # True to save the language setting in the browser. Requires gradio >= 5.6.0
+ ):
+ lang_btn.render()
+
+ gr.Markdown(
+ value="# "
+ + _("Title")
+ + "\n\n"
+ + "### [GraphGen](https://github.com/open-sciencelab/GraphGen) "
+ + _("Intro")
+ )
+
+ if_trainee_model = gr.Checkbox(
+ label=_("Use Trainee Model"), value=False, interactive=True
+ )
+
+ with gr.Accordion(label=_("Model Config"), open=False):
+ synthesizer_url = gr.Textbox(
+ label="Synthesizer URL",
+ value="https://api.siliconflow.cn/v1",
+ info=_("Synthesizer URL Info"),
+ interactive=True,
+ )
+ synthesizer_model = gr.Textbox(
+ label="Synthesizer Model",
+ value="Qwen/Qwen2.5-7B-Instruct",
+ info=_("Synthesizer Model Info"),
+ interactive=True,
+ )
+ trainee_url = gr.Textbox(
+ label="Trainee URL",
+ value="https://api.siliconflow.cn/v1",
+ info=_("Trainee URL Info"),
+ interactive=True,
+ visible=if_trainee_model.value is True,
+ )
+ trainee_model = gr.Textbox(
+ label="Trainee Model",
+ value="Qwen/Qwen2.5-7B-Instruct",
+ info=_("Trainee Model Info"),
+ interactive=True,
+ visible=if_trainee_model.value is True,
+ )
+ trainee_api_key = gr.Textbox(
+ label=_("SiliconFlow Token for Trainee Model"),
+ type="password",
+ value="",
+ info="https://cloud.siliconflow.cn/account/ak",
+ visible=if_trainee_model.value is True,
+ )
+
+ with gr.Accordion(label=_("Generation Config"), open=False):
+ chunk_size = gr.Slider(
+ label="Chunk Size",
+ minimum=256,
+ maximum=4096,
+ value=512,
+ step=256,
+ interactive=True,
+ )
+ tokenizer = gr.Textbox(
+ label="Tokenizer", value="cl100k_base", interactive=True
+ )
+ qa_form = gr.Radio(
+ choices=["atomic", "multi_hop", "aggregated"],
+ label="QA Form",
+ value="aggregated",
+ interactive=True,
+ )
+ quiz_samples = gr.Number(
+ label="Quiz Samples",
+ value=2,
+ minimum=1,
+ interactive=True,
+ visible=if_trainee_model.value is True,
+ )
+ bidirectional = gr.Checkbox(
+ label="Bidirectional", value=True, interactive=True
+ )
+
+ expand_method = gr.Radio(
+ choices=["max_width", "max_tokens"],
+ label="Expand Method",
+ value="max_tokens",
+ interactive=True,
+ )
+ max_extra_edges = gr.Slider(
+ minimum=1,
+ maximum=10,
+ value=5,
+ label="Max Extra Edges",
+ step=1,
+ interactive=True,
+ visible=expand_method.value == "max_width",
+ )
+ max_tokens = gr.Slider(
+ minimum=64,
+ maximum=1024,
+ value=256,
+ label="Max Tokens",
+ step=64,
+ interactive=True,
+ visible=(expand_method.value != "max_width"),
+ )
+
+ max_depth = gr.Slider(
+ minimum=1,
+ maximum=5,
+ value=2,
+ label="Max Depth",
+ step=1,
+ interactive=True,
+ )
+ edge_sampling = gr.Radio(
+ choices=["max_loss", "min_loss", "random"],
+ label="Edge Sampling",
+ value="max_loss",
+ interactive=True,
+ visible=if_trainee_model.value is True,
+ )
+ isolated_node_strategy = gr.Radio(
+ choices=["add", "ignore"],
+ label="Isolated Node Strategy",
+ value="ignore",
+ interactive=True,
+ )
+ loss_strategy = gr.Radio(
+ choices=["only_edge", "both"],
+ label="Loss Strategy",
+ value="only_edge",
+ interactive=True,
+ )
+
+ with gr.Row(equal_height=True):
+ with gr.Column(scale=3):
+ api_key = gr.Textbox(
+ label=_("SiliconFlow Token"),
+ type="password",
+ value="",
+ info="https://cloud.siliconflow.cn/account/ak",
+ )
+ with gr.Column(scale=1):
+ test_connection_btn = gr.Button(_("Test Connection"))
+
+ with gr.Blocks():
+ with gr.Row(equal_height=True):
+ with gr.Column():
+ rpm = gr.Slider(
+ label="RPM",
+ minimum=10,
+ maximum=10000,
+ value=1000,
+ step=100,
+ interactive=True,
+ visible=True,
+ )
+ with gr.Column():
+ tpm = gr.Slider(
+ label="TPM",
+ minimum=5000,
+ maximum=5000000,
+ value=50000,
+ step=1000,
+ interactive=True,
+ visible=True,
+ )
+
+ with gr.Blocks():
+ with gr.Row(equal_height=True):
+ with gr.Column(scale=1):
+ upload_file = gr.File(
+ label=_("Upload File"),
+ file_count="single",
+ file_types=[".txt", ".json", ".jsonl"],
+ interactive=True,
+ )
+ examples_dir = os.path.join(root_dir, "webui", "examples")
+ gr.Examples(
+ examples=[
+ [os.path.join(examples_dir, "txt_demo.txt")],
+ [os.path.join(examples_dir, "raw_demo.jsonl")],
+ [os.path.join(examples_dir, "chunked_demo.json")],
+ ],
+ inputs=upload_file,
+ label=_("Example Files"),
+ examples_per_page=3,
+ )
+ with gr.Column(scale=1):
+ output = gr.File(
+ label="Output(See Github FAQ)",
+ file_count="single",
+ interactive=False,
+ )
+
+ with gr.Blocks():
+ token_counter = gr.DataFrame(
+ label="Token Stats",
+ headers=[
+ "Source Text Token Count",
+ "Estimated Token Usage",
+ "Token Used",
+ ],
+ datatype="str",
+ interactive=False,
+ visible=False,
+ wrap=True,
+ )
+
+ submit_btn = gr.Button(_("Run GraphGen"))
+
+        # Test Connection
+        def _test_connections(
+            use_trainee, syn_url, syn_key, syn_model, tr_url, tr_key, tr_model
+        ):
+            """Check the synthesizer endpoint and, when enabled, the trainee one."""
+            test_api_connection(syn_url, syn_key, syn_model)
+            # The checkbox is read per click. Testing `if_trainee_model.value`
+            # while the UI is being built only ever sees the initial False, so
+            # the trainee check would never be registered.
+            if use_trainee:
+                test_api_connection(tr_url, tr_key, tr_model)
+
+        test_connection_btn.click(
+            _test_connections,
+            inputs=[
+                if_trainee_model,
+                synthesizer_url,
+                api_key,
+                synthesizer_model,
+                trainee_url,
+                trainee_api_key,
+                trainee_model,
+            ],
+            outputs=[],
+        )
+
+ expand_method.change(
+ lambda method: (
+ gr.update(visible=method == "max_width"),
+ gr.update(visible=method != "max_width"),
+ ),
+ inputs=expand_method,
+ outputs=[max_extra_edges, max_tokens],
+ )
+
+ if_trainee_model.change(
+ lambda use_trainee: [gr.update(visible=use_trainee)] * 5,
+ inputs=if_trainee_model,
+ outputs=[
+ trainee_url,
+ trainee_model,
+ quiz_samples,
+ edge_sampling,
+ trainee_api_key,
+ ],
+ )
+
+ upload_file.change(
+ lambda x: (gr.update(visible=True)),
+ inputs=[upload_file],
+ outputs=[token_counter],
+ ).then(
+ count_tokens,
+ inputs=[upload_file, tokenizer, token_counter],
+ outputs=[token_counter],
+ )
+
+ # run GraphGen
+ submit_btn.click(
+ lambda x: (gr.update(visible=False)),
+ inputs=[token_counter],
+ outputs=[token_counter],
+ )
+
+ submit_btn.click(
+ lambda *args: run_graphgen(
+ GraphGenParams(
+ if_trainee_model=args[0],
+ input_file=args[1],
+ tokenizer=args[2],
+ qa_form=args[3],
+ bidirectional=args[4],
+ expand_method=args[5],
+ max_extra_edges=args[6],
+ max_tokens=args[7],
+ max_depth=args[8],
+ edge_sampling=args[9],
+ isolated_node_strategy=args[10],
+ loss_strategy=args[11],
+ synthesizer_url=args[12],
+ synthesizer_model=args[13],
+ trainee_model=args[14],
+ api_key=args[15],
+ chunk_size=args[16],
+ rpm=args[17],
+ tpm=args[18],
+ quiz_samples=args[19],
+ trainee_url=args[20],
+ trainee_api_key=args[21],
+ token_counter=args[22],
+ )
+ ),
+ inputs=[
+ if_trainee_model,
+ upload_file,
+ tokenizer,
+ qa_form,
+ bidirectional,
+ expand_method,
+ max_extra_edges,
+ max_tokens,
+ max_depth,
+ edge_sampling,
+ isolated_node_strategy,
+ loss_strategy,
+ synthesizer_url,
+ synthesizer_model,
+ trainee_model,
+ api_key,
+ chunk_size,
+ rpm,
+ tpm,
+ quiz_samples,
+ trainee_url,
+ trainee_api_key,
+ token_counter,
+ ],
+ outputs=[output, token_counter],
+ )
+
+if __name__ == "__main__":
+ demo.queue(api_open=False, default_concurrency_limit=2)
+ demo.launch(server_name="0.0.0.0")
diff --git a/hf-repo/graphgen/__init__.py b/hf-repo/graphgen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hf-repo/graphgen/evaluate.py b/hf-repo/graphgen/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..da74a308743149ad1ea31f8c43f8e01c8c160a1d
--- /dev/null
+++ b/hf-repo/graphgen/evaluate.py
@@ -0,0 +1,142 @@
+"""Evaluate the quality of the generated text using various metrics"""
+
+import os
+import json
+import argparse
+import pandas as pd
+from dotenv import load_dotenv
+from .models import LengthEvaluator, MTLDEvaluator, RewardEvaluator, TextPair, UniEvaluator
+from .utils import logger, set_logger
+
+sys_path = os.path.abspath(os.path.dirname(__file__))
+set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log"))
+
+load_dotenv()
+
+def evaluate_length(corpus, tokenizer_name):
+    """Return the average score a ``LengthEvaluator`` gives to ``corpus``.
+
+    Args:
+        corpus: items to score, passed straight to the evaluator.
+        tokenizer_name: tokenizer name used to construct the evaluator.
+    """
+    evaluator = LengthEvaluator(tokenizer_name=tokenizer_name)
+    logger.info("Length evaluator loaded")
+
+    average = evaluator.get_average_score(corpus)
+    logger.info("Length scores: %s", average)
+    return average
+
+def evaluate_mtld(corpus):
+    """Score ``corpus`` with an ``MTLDEvaluator``.
+
+    Returns:
+        ``(average_score, min_max_score)`` as reported by the evaluator.
+    """
+    evaluator = MTLDEvaluator()
+    logger.info("MTLD evaluator loaded")
+
+    average = evaluator.get_average_score(corpus)
+    logger.info("MTLD scores: %s", average)
+
+    score_range = evaluator.get_min_max_score(corpus)
+    logger.info("MTLD min max scores: %s", score_range)
+
+    return average, score_range
+
+def evaluate_reward(corpus, reward_model_names):
+    """Score ``corpus`` with each reward model in turn.
+
+    Args:
+        corpus: items to score, passed straight to each evaluator.
+        reward_model_names: iterable of reward model identifiers.
+
+    Returns:
+        One dict per model with ``reward_name`` (the part of the identifier
+        after the last ``/``), ``score`` and ``min_max_scores``.
+    """
+    collected = []
+    for model_id in reward_model_names:
+        evaluator = RewardEvaluator(reward_name=model_id)
+        logger.info("Loaded reward model: %s", model_id)
+
+        mean_score = evaluator.get_average_score(corpus)
+        logger.info("%s scores: %s", model_id, mean_score)
+
+        score_range = evaluator.get_min_max_score(corpus)
+        logger.info("%s min max scores: %s", model_id, score_range)
+
+        collected.append(
+            dict(
+                reward_name=model_id.rsplit('/', 1)[-1],
+                score=mean_score,
+                min_max_scores=score_range,
+            )
+        )
+
+        # Drop the evaluator and clear the CUDA cache before the next model.
+        del evaluator
+        clean_gpu_cache()
+    return collected
+
+def evaluate_uni(corpus, uni_model_name):
+    """Score ``corpus`` with a ``UniEvaluator``.
+
+    Returns:
+        A 6-tuple: the average naturalness, coherence and understandability
+        scores, followed by the min/max scores for the same three aspects in
+        the same order.
+    """
+    evaluator = UniEvaluator(model_name=uni_model_name)
+    logger.info("Uni evaluator loaded with model %s", uni_model_name)
+
+    averages = evaluator.get_average_score(corpus)
+    for aspect in averages:
+        logger.info("Uni %s scores: %s", aspect, averages[aspect])
+
+    ranges = evaluator.get_min_max_score(corpus)
+    for aspect in ranges:
+        logger.info("Uni %s min max scores: %s", aspect, ranges[aspect])
+
+    # Drop the evaluator and clear the CUDA cache before returning.
+    del evaluator
+    clean_gpu_cache()
+
+    aspects = ('naturalness', 'coherence', 'understandability')
+    return (
+        *(averages[aspect] for aspect in aspects),
+        *(ranges[aspect] for aspect in aspects),
+    )
+
+
+def clean_gpu_cache():
+    """Call ``torch.cuda.empty_cache()`` when CUDA is available; otherwise do nothing."""
+    # Imported here rather than at module level, as in the original.
+    import torch
+
+    if not torch.cuda.is_available():
+        return
+    torch.cuda.empty_cache()
+
+
+if __name__ == '__main__':
+    # Command-line entry point: score every *.json file in --folder and write
+    # one row per file to <output>/evaluation.csv.
+    import torch.multiprocessing as mp
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--folder', type=str, default='cache/data', help='folder to load data')
+    parser.add_argument('--output', type=str, default='cache/output', help='path to save output')
+
+    parser.add_argument('--tokenizer', type=str, default='cl100k_base', help='tokenizer name')
+    parser.add_argument('--reward', type=str, default='OpenAssistant/reward-model-deberta-v3-large-v2',
+                        help='Comma-separated list of reward models')
+    parser.add_argument('--uni', type=str, default='MingZhong/unieval-sum', help='uni model name')
+
+    args = parser.parse_args()
+
+    if not os.path.exists(args.folder):
+        raise ValueError(f"Folder {args.folder} does not exist")
+
+    # exist_ok avoids a crash if the directory appears between a check and
+    # the creation.
+    os.makedirs(args.output, exist_ok=True)
+
+    # Tolerate spaces around the commas and empty entries ("a, b," -> a, b).
+    reward_models = [name.strip() for name in args.reward.split(',') if name.strip()]
+
+    results = []
+
+    logger.info("Data loaded from %s", args.folder)
+    mp.set_start_method('spawn')
+
+    # Sorted so the CSV rows come out in a stable order; os.listdir() order
+    # is arbitrary.
+    for file in sorted(os.listdir(args.folder)):
+        if not file.endswith('.json'):
+            continue
+
+        logger.info("Processing %s", file)
+        with open(os.path.join(args.folder, file), 'r', encoding='utf-8') as f:
+            data = json.load(f)
+        # Each value of the loaded mapping supplies a question/answer pair.
+        data = [TextPair(
+            question=data[key]['question'],
+            answer=data[key]['answer']
+        ) for key in data]
+
+        length_scores = evaluate_length(data, args.tokenizer)
+        mtld_scores, min_max_mtld_scores = evaluate_mtld(data)
+        reward_scores = evaluate_reward(data, reward_models)
+        uni_naturalness_scores, uni_coherence_scores, uni_understandability_scores, \
+            min_max_uni_naturalness_scores, min_max_uni_coherence_scores, min_max_uni_understandability_scores \
+            = evaluate_uni(data, args.uni)
+
+        result = {
+            'file': file,
+            'number': len(data),
+            'length': length_scores,
+            'mtld': mtld_scores,
+            'mtld_min_max': min_max_mtld_scores,
+            'uni_naturalness': uni_naturalness_scores,
+            'uni_coherence': uni_coherence_scores,
+            'uni_understandability': uni_understandability_scores,
+            'uni_naturalness_min_max': min_max_uni_naturalness_scores,
+            'uni_coherence_min_max': min_max_uni_coherence_scores,
+            'uni_understandability_min_max': min_max_uni_understandability_scores
+        }
+        # One score column and one min/max column per reward model.
+        for reward_score in reward_scores:
+            result[reward_score['reward_name']] = reward_score['score']
+            result[f"{reward_score['reward_name']}_min_max"] = reward_score['min_max_scores']
+
+        results.append(result)
+
+    results = pd.DataFrame(results)
+    results.to_csv(os.path.join(args.output, 'evaluation.csv'), index=False)
diff --git a/hf-repo/graphgen/generate.py b/hf-repo/graphgen/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..eec168d6137cdc2ebaed8bab3ffef74adf70ddf3
--- /dev/null
+++ b/hf-repo/graphgen/generate.py
@@ -0,0 +1,103 @@
+import argparse
+import os
+import time
+from importlib.resources import files
+
+import yaml
+from dotenv import load_dotenv
+
+from .graphgen import GraphGen
+from .utils import logger, set_logger
+
+sys_path = os.path.abspath(os.path.dirname(__file__))
+
+load_dotenv()
+
+
+def set_working_dir(folder):
+    """Create ``folder`` plus its ``data/graphgen`` and ``logs`` sub-directories.
+
+    Directories that already exist are left untouched.
+    """
+    for sub_path in ((), ("data", "graphgen"), ("logs",)):
+        os.makedirs(os.path.join(folder, *sub_path), exist_ok=True)
+
+
+def save_config(config_path, global_config):
+    """Write ``global_config`` to ``config_path`` as block-style YAML.
+
+    Missing parent directories are created first. Non-ASCII characters are
+    written verbatim (``allow_unicode=True``).
+
+    Args:
+        config_path: destination file; may be a bare file name.
+        global_config: the configuration object to dump.
+    """
+    parent_dir = os.path.dirname(config_path)
+    # dirname() is "" for a bare file name and os.makedirs("") raises;
+    # exist_ok also covers the directory appearing concurrently.
+    if parent_dir:
+        os.makedirs(parent_dir, exist_ok=True)
+    with open(config_path, "w", encoding="utf-8") as config_file:
+        yaml.dump(
+            global_config, config_file, default_flow_style=False, allow_unicode=True
+        )
+
+
+def main():
+    """Command-line entry point for GraphGen.
+
+    Parses ``--config_file`` and ``--output_dir``, prepares the working
+    directory and logger, runs insert, optional search and the generation
+    pipeline chosen by ``output_data_type``, then stores a copy of the
+    configuration next to the generated data.
+
+    Raises:
+        ValueError: if ``output_data_type`` is not a supported value.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config_file",
+        help="Config parameters for GraphGen.",
+        default=files("graphgen").joinpath("configs", "aggregated_config.yaml"),
+        type=str,
+    )
+    # NOTE(review): with required=True the default below is never used —
+    # confirm whether the option is meant to be optional.
+    parser.add_argument(
+        "--output_dir",
+        help="Output directory for GraphGen.",
+        default=sys_path,
+        required=True,
+        type=str,
+    )
+
+    args = parser.parse_args()
+
+    working_dir = args.output_dir
+    set_working_dir(working_dir)
+
+    # NOTE(review): FullLoader can build more than plain data types; consider
+    # yaml.safe_load if config files may come from untrusted sources.
+    with open(args.config_file, "r", encoding="utf-8") as f:
+        config = yaml.load(f, Loader=yaml.FullLoader)
+
+    output_data_type = config["output_data_type"]
+    unique_id = int(time.time())
+    # Build the log path once so the logger and the log message agree.
+    log_path = os.path.join(
+        working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log"
+    )
+    set_logger(log_path, if_stream=True)
+    logger.info("GraphGen with unique ID %s logging to %s", unique_id, log_path)
+
+    graph_gen = GraphGen(working_dir=working_dir, unique_id=unique_id, config=config)
+
+    graph_gen.insert()
+
+    # A missing or empty "search" section means search is disabled, matching
+    # how the quiz/judge section is treated below.
+    if (config.get("search") or {}).get("enabled", False):
+        graph_gen.search()
+
+    # Use pipeline according to the output data type
+    if output_data_type in ["atomic", "aggregated", "multi_hop"]:
+        if (config.get("quiz_and_judge_strategy") or {}).get("enabled", False):
+            graph_gen.quiz()
+            graph_gen.judge()
+        else:
+            logger.warning(
+                "Quiz and Judge strategy is disabled. Edge sampling falls back to random."
+            )
+            graph_gen.traverse_strategy.edge_sampling = "random"
+        graph_gen.traverse()
+    elif output_data_type == "cot":
+        graph_gen.generate_reasoning(method_params=config["method_params"])
+    else:
+        raise ValueError(f"Unsupported output data type: {output_data_type}")
+
+    output_path = os.path.join(working_dir, "data", "graphgen", str(unique_id))
+    save_config(os.path.join(output_path, f"config-{unique_id}.yaml"), config)
+    logger.info("GraphGen completed successfully. Data saved to %s", output_path)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/hf-repo/graphgen/graphgen.py b/hf-repo/graphgen/graphgen.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b7b302ac5c2034b631689c3313389aa32160572
--- /dev/null
+++ b/hf-repo/graphgen/graphgen.py
@@ -0,0 +1,395 @@
+import asyncio
+import os
+import time
+from dataclasses import dataclass, field
+from typing import Dict, List, Union, cast
+
+import gradio as gr
+from tqdm.asyncio import tqdm as tqdm_async
+
+from .models import (
+ Chunk,
+ JsonKVStorage,
+ JsonListStorage,
+ NetworkXStorage,
+ OpenAIModel,
+ Tokenizer,
+ TraverseStrategy,
+)
+from .models.storage.base_storage import StorageNameSpace
+from .operators import (
+ extract_kg,
+ generate_cot,
+ judge_statement,
+ quiz,
+ search_all,
+ traverse_graph_atomically,
+ traverse_graph_by_edge,
+ traverse_graph_for_multi_hop,
+)
+from .utils import (
+ compute_content_hash,
+ create_event_loop,
+ format_generation_results,
+ logger,
+ read_file,
+)
+
+sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+
+
+@dataclass
+class GraphGen:
+    """
+    Orchestrates the GraphGen pipeline: chunk documents, extract a knowledge
+    graph, optionally search / quiz / judge, then generate QA data from the graph.
+
+    Storages are created under ``working_dir``; generated QA data is written to
+    ``<working_dir>/data/graphgen/<unique_id>``.
+    """
+
+    # Run identifier. A default_factory gives every instance its own timestamp;
+    # a plain ``int(time.time())`` default would be evaluated once at import time
+    # and shared by all instances created afterwards.
+    unique_id: int = field(default_factory=lambda: int(time.time()))
+    working_dir: str = os.path.join(sys_path, "cache")
+    config: Dict = field(default_factory=dict)
+
+    # llm
+    tokenizer_instance: Tokenizer = None
+    synthesizer_llm_client: OpenAIModel = None
+    trainee_llm_client: OpenAIModel = None
+
+    # text chunking
+    # TODO: make it configurable
+    chunk_size: int = 1024
+    chunk_overlap_size: int = 100
+
+    # search
+    search_config: dict = field(
+        default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}
+    )
+
+    # traversal
+    traverse_strategy: TraverseStrategy = None
+
+    # webui
+    progress_bar: gr.Progress = None
+
+    def __post_init__(self):
+        """
+        Build the tokenizer, both LLM clients and all storages.
+
+        The tokenizer and LLM clients are always rebuilt from ``config`` and the
+        ``SYNTHESIZER_*`` / ``TRAINEE_*`` environment variables, replacing any
+        value passed to the constructor.
+
+        :raises KeyError: if ``config`` has no ``"tokenizer"`` entry.
+        """
+        self.tokenizer_instance: Tokenizer = Tokenizer(
+            model_name=self.config["tokenizer"]
+        )
+        self.synthesizer_llm_client: OpenAIModel = OpenAIModel(
+            model_name=os.getenv("SYNTHESIZER_MODEL"),
+            api_key=os.getenv("SYNTHESIZER_API_KEY"),
+            base_url=os.getenv("SYNTHESIZER_BASE_URL"),
+            tokenizer_instance=self.tokenizer_instance,
+        )
+        self.trainee_llm_client: OpenAIModel = OpenAIModel(
+            model_name=os.getenv("TRAINEE_MODEL"),
+            api_key=os.getenv("TRAINEE_API_KEY"),
+            base_url=os.getenv("TRAINEE_BASE_URL"),
+            tokenizer_instance=self.tokenizer_instance,
+        )
+        # Fall back to the declared default (search disabled) when the config
+        # has no "search" section instead of failing with KeyError.
+        self.search_config = self.config.get("search", self.search_config)
+
+        if "traverse_strategy" in self.config:
+            self.traverse_strategy = TraverseStrategy(
+                **self.config["traverse_strategy"]
+            )
+
+        self.full_docs_storage: JsonKVStorage = JsonKVStorage(
+            self.working_dir, namespace="full_docs"
+        )
+        self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
+            self.working_dir, namespace="text_chunks"
+        )
+        self.graph_storage: NetworkXStorage = NetworkXStorage(
+            self.working_dir, namespace="graph"
+        )
+        self.search_storage: JsonKVStorage = JsonKVStorage(
+            self.working_dir, namespace="search"
+        )
+        self.rephrase_storage: JsonKVStorage = JsonKVStorage(
+            self.working_dir, namespace="rephrase"
+        )
+        # QA output lives in a per-run directory keyed by unique_id.
+        self.qa_storage: JsonListStorage = JsonListStorage(
+            os.path.join(self.working_dir, "data", "graphgen", str(self.unique_id)),
+            namespace=f"qa-{self.unique_id}",
+        )
+
+    async def async_split_chunks(
+        self, data: List[Union[List, Dict]], data_type: str
+    ) -> dict:
+        """
+        Split the input into chunks and persist the new documents and chunks.
+
+        :param data: for ``"raw"``, a list of dicts with a ``"content"`` key; for
+            ``"chunked"``, a list of documents, each a list of chunk dicts with a
+            ``"content"`` key.
+        :param data_type: ``"raw"`` or ``"chunked"``.
+        :return: mapping of chunk id to chunk dict (including ``full_doc_id``) for
+            the chunks that were inserted; ``{}`` when there is nothing new.
+        :raises ValueError: if ``data_type`` is neither ``"raw"`` nor ``"chunked"``.
+        """
+        # TODO: configurable whether to use coreference resolution
+        if len(data) == 0:
+            return {}
+
+        # Step 1: one entry per document, keyed by the hash of its full content.
+        if data_type == "raw":
+            assert isinstance(data, list) and isinstance(data[0], dict)
+            new_docs = {
+                compute_content_hash(doc["content"], prefix="doc-"): {
+                    "content": doc["content"]
+                }
+                for doc in data
+            }
+        elif data_type == "chunked":
+            assert isinstance(data, list) and isinstance(data[0], list)
+            # The document id is the hash of the concatenated chunk contents, the
+            # same value stored below as each chunk's "full_doc_id". (Keying the
+            # documents by individual chunk content would never match that id.)
+            new_docs = {}
+            for doc in data:
+                doc_str = "".join(chunk["content"] for chunk in doc)
+                new_docs[compute_content_hash(doc_str, prefix="doc-")] = {
+                    "content": doc_str
+                }
+        else:
+            raise ValueError(f"Unknown data type: {data_type}")
+
+        # Step 2: keep only the documents selected by filter_keys
+        # (presumably the keys not yet in storage; verify against JsonKVStorage).
+        _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
+        new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
+        if len(new_docs) == 0:
+            logger.warning("All docs are already in the storage")
+            return {}
+        logger.info("[New Docs] inserting %d docs", len(new_docs))
+
+        # Step 3: build the chunk table.
+        inserting_chunks = {}
+        if data_type == "raw":
+            cur_index = 1
+            doc_number = len(new_docs)
+            async for doc_key, doc in tqdm_async(
+                new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
+            ):
+                chunks = {
+                    compute_content_hash(dp["content"], prefix="chunk-"): {
+                        **dp,
+                        "full_doc_id": doc_key,
+                    }
+                    for dp in self.tokenizer_instance.chunk_by_token_size(
+                        doc["content"], self.chunk_overlap_size, self.chunk_size
+                    )
+                }
+                inserting_chunks.update(chunks)
+
+                if self.progress_bar is not None:
+                    self.progress_bar(cur_index / doc_number, f"Chunking {doc_key}")
+                    cur_index += 1
+        else:
+            async for doc in tqdm_async(
+                data, desc="[1/4]Chunking documents", unit="doc"
+            ):
+                doc_str = "".join([chunk["content"] for chunk in doc])
+                # Hash the document once, not once per chunk.
+                doc_key = compute_content_hash(doc_str, prefix="doc-")
+                for chunk in doc:
+                    chunk_key = compute_content_hash(chunk["content"], prefix="chunk-")
+                    inserting_chunks[chunk_key] = {**chunk, "full_doc_id": doc_key}
+
+        # Step 4: apply the same filter to the chunks, then persist both tables.
+        _add_chunk_keys = await self.text_chunks_storage.filter_keys(
+            list(inserting_chunks.keys())
+        )
+        inserting_chunks = {
+            k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
+        }
+
+        await self.full_docs_storage.upsert(new_docs)
+        await self.text_chunks_storage.upsert(inserting_chunks)
+
+        return inserting_chunks
+
+    def insert(self):
+        """Blocking wrapper: run :meth:`async_insert` to completion on an event loop."""
+        create_event_loop().run_until_complete(self.async_insert())
+
+    async def async_insert(self):
+        """
+        Read the configured input file, chunk it, and extract entities and
+        relations from the new chunks into the graph storage.
+
+        Returns early (after a warning) when there are no new chunks or when
+        extraction yields nothing; otherwise finishes by flushing the storages.
+        """
+        source_path = self.config["input_file"]
+        source_kind = self.config["input_data_type"]
+        records = read_file(source_path)
+
+        new_chunks = await self.async_split_chunks(records, source_kind)
+        if len(new_chunks) == 0:
+            logger.warning("All chunks are already in the storage")
+            return
+        logger.info("[New Chunks] inserting %d chunks", len(new_chunks))
+
+        logger.info("[Entity and Relation Extraction]...")
+        chunk_objects = [
+            Chunk(id=chunk_id, content=chunk["content"])
+            for chunk_id, chunk in new_chunks.items()
+        ]
+        extracted = await extract_kg(
+            llm_client=self.synthesizer_llm_client,
+            kg_instance=self.graph_storage,
+            tokenizer_instance=self.tokenizer_instance,
+            chunks=chunk_objects,
+            progress_bar=self.progress_bar,
+        )
+        if not extracted:
+            logger.warning("No entities or relations extracted")
+            return
+
+        await self._insert_done()
+
+    async def _insert_done(self):
+        """Run ``index_done_callback`` concurrently on the doc, chunk, graph and search storages."""
+        storages = (
+            self.full_docs_storage,
+            self.text_chunks_storage,
+            self.graph_storage,
+            self.search_storage,
+        )
+        # Storages that are None are skipped.
+        callbacks = [
+            cast(StorageNameSpace, storage).index_done_callback()
+            for storage in storages
+            if storage is not None
+        ]
+        await asyncio.gather(*callbacks)
+
+    def search(self):
+        """Blocking wrapper: run :meth:`async_search` to completion on an event loop."""
+        loop = create_event_loop()
+        loop.run_until_complete(self.async_search())
+
+    async def async_search(self):
+        """
+        Look up graph entities with the configured search backends and cache the
+        results in ``search_storage``.
+
+        Does nothing beyond logging when ``search_config["enabled"]`` is false.
+        """
+        enabled = self.search_config["enabled"]
+        logger.info("Search is %s", "enabled" if enabled else "disabled")
+        if not enabled:
+            return
+
+        logger.info("[Search] %s ...", ", ".join(self.search_config["search_types"]))
+        all_nodes = await self.graph_storage.get_all_nodes()
+        all_nodes_names = [node[0] for node in all_nodes]
+        # Filter against the search cache, which is where results are upserted
+        # below. Filtering entity names against full_docs_storage (keyed by
+        # "doc-<hash>" ids) never excluded anything, so every entity was searched
+        # again on each run.
+        new_search_entities = await self.search_storage.filter_keys(all_nodes_names)
+        logger.info("[Search] Found %d entities to search", len(new_search_entities))
+
+        _add_search_data = await search_all(
+            search_types=self.search_config["search_types"],
+            search_entities=new_search_entities,
+        )
+        if _add_search_data:
+            await self.search_storage.upsert(_add_search_data)
+            logger.info("[Search] %d entities searched", len(_add_search_data))
+
+            # TODO: fix insert after search. The fetched texts are not yet fed
+            # into extraction: async_insert() re-reads config["input_file"].
+            await self.async_insert()
+
+    def quiz(self):
+        """Blocking wrapper: run :meth:`async_quiz` to completion on an event loop."""
+        create_event_loop().run_until_complete(self.async_quiz())
+
+    async def async_quiz(self):
+        """
+        Run the ``quiz`` operator with the synthesizer LLM over the graph, writing
+        into ``rephrase_storage``, then flush that storage.
+
+        The sample limit comes from ``config["quiz_and_judge_strategy"]["quiz_samples"]``.
+        """
+        sample_limit = self.config["quiz_and_judge_strategy"]["quiz_samples"]
+        await quiz(
+            self.synthesizer_llm_client,
+            self.graph_storage,
+            self.rephrase_storage,
+            sample_limit,
+        )
+        await self.rephrase_storage.index_done_callback()
+
+    def judge(self):
+        """Blocking wrapper: run :meth:`async_judge` to completion on an event loop."""
+        create_event_loop().run_until_complete(self.async_judge())
+
+    async def async_judge(self):
+        """
+        Run ``judge_statement`` with the trainee LLM and flush the storage it returns.
+
+        ``config["quiz_and_judge_strategy"]["re_judge"]`` is forwarded as the
+        re-judge flag.
+        """
+        force_rejudge = self.config["quiz_and_judge_strategy"]["re_judge"]
+        judged_graph = await judge_statement(
+            self.trainee_llm_client,
+            self.graph_storage,
+            self.rephrase_storage,
+            force_rejudge,
+        )
+        await judged_graph.index_done_callback()
+
+    def traverse(self):
+        """Blocking wrapper: run :meth:`async_traverse` to completion on an event loop."""
+        create_event_loop().run_until_complete(self.async_traverse())
+
+    async def async_traverse(self):
+        """
+        Traverse the graph with the operator selected by
+        ``config["output_data_type"]``, format the results and store them in
+        ``qa_storage``.
+
+        :raises ValueError: if the output data type is not ``"atomic"``,
+            ``"multi_hop"`` or ``"aggregated"``.
+        """
+        output_data_type = self.config["output_data_type"]
+
+        # All three operators take the same positional arguments.
+        traversal_by_type = {
+            "atomic": traverse_graph_atomically,
+            "multi_hop": traverse_graph_for_multi_hop,
+            "aggregated": traverse_graph_by_edge,
+        }
+        if output_data_type not in traversal_by_type:
+            raise ValueError(f"Unknown qa_form: {output_data_type}")
+
+        raw_results = await traversal_by_type[output_data_type](
+            self.synthesizer_llm_client,
+            self.tokenizer_instance,
+            self.graph_storage,
+            self.traverse_strategy,
+            self.text_chunks_storage,
+            self.progress_bar,
+        )
+
+        formatted = format_generation_results(
+            raw_results, output_data_format=self.config["output_data_format"]
+        )
+
+        await self.qa_storage.upsert(formatted)
+        await self.qa_storage.index_done_callback()
+
+    def generate_reasoning(self, method_params):
+        """Blocking wrapper: run :meth:`async_generate_reasoning` on an event loop."""
+        create_event_loop().run_until_complete(
+            self.async_generate_reasoning(method_params)
+        )
+
+    async def async_generate_reasoning(self, method_params):
+        """
+        Generate chain-of-thought data from the graph with the synthesizer LLM,
+        format it per ``config["output_data_format"]`` and store it in ``qa_storage``.
+
+        :param method_params: forwarded unchanged to ``generate_cot``.
+        """
+        cot_results = await generate_cot(
+            self.graph_storage,
+            self.synthesizer_llm_client,
+            method_params=method_params,
+        )
+        formatted = format_generation_results(
+            cot_results, output_data_format=self.config["output_data_format"]
+        )
+
+        await self.qa_storage.upsert(formatted)
+        await self.qa_storage.index_done_callback()
+
+    def clear(self):
+        """Blocking wrapper: run :meth:`async_clear` to completion on an event loop."""
+        create_event_loop().run_until_complete(self.async_clear())
+
+    async def async_clear(self):
+        """Drop every key-value / list storage and clear the graph, one after another."""
+        await self.full_docs_storage.drop()
+        await self.text_chunks_storage.drop()
+        await self.search_storage.drop()
+        # The graph storage is cleared with clear() rather than drop().
+        await self.graph_storage.clear()
+        await self.rephrase_storage.drop()
+        await self.qa_storage.drop()
+
+        logger.info("All caches are cleared")
diff --git a/hf-repo/graphgen/judge.py b/hf-repo/graphgen/judge.py
new file mode 100644
index 0000000000000000000000000000000000000000..f05bdf1da816a0ba07ca682f300f968d48b29dd1
--- /dev/null
+++ b/hf-repo/graphgen/judge.py
@@ -0,0 +1,60 @@
+import os
+import argparse
+import asyncio
+from dotenv import load_dotenv
+
+from .models import NetworkXStorage, JsonKVStorage, OpenAIModel
+from .operators import judge_statement
+
+sys_path = os.path.abspath(os.path.dirname(__file__))
+
+load_dotenv()
+
+def calculate_average_loss(graph: NetworkXStorage):
+    """
+    Calculate the average ``loss`` attribute over all edges of the graph.
+
+    :param graph: storage whose ``get_all_edges()`` coroutine returns a sized
+        collection of ``(source, target, data)`` tuples
+    :return: mean of ``data['loss']`` as a float; ``0.0`` when the graph has no
+        edges (previously this raised ZeroDivisionError)
+    :raises KeyError: if an edge has no ``'loss'`` attribute
+    """
+    edges = asyncio.run(graph.get_all_edges())
+    if len(edges) == 0:
+        return 0.0
+    return sum(edge[2]['loss'] for edge in edges) / len(edges)
+
+
+
+if __name__ == '__main__':
+    # CLI: report the average edge loss, re-judge the graph with the trainee
+    # model, write the judged graph to --output and report the new average.
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument(
+        '--input',
+        type=str,
+        default=os.path.join(sys_path, "cache"),
+        help='path to load input graph',
+    )
+    arg_parser.add_argument(
+        '--output',
+        type=str,
+        default='cache/output/new_graph.graphml',
+        help='path to save output',
+    )
+    args = arg_parser.parse_args()
+
+    llm_client = OpenAIModel(
+        model_name=os.getenv("TRAINEE_MODEL"),
+        api_key=os.getenv("TRAINEE_API_KEY"),
+        base_url=os.getenv("TRAINEE_BASE_URL"),
+    )
+
+    graph_storage = NetworkXStorage(args.input, namespace="graph")
+    loss_before = calculate_average_loss(graph_storage)
+    print(f"Average loss of the graph: {loss_before}")
+
+    # NOTE(review): the rephrase cache is always read from <package>/cache, not
+    # from --input; confirm this is intended.
+    rephrase_storage = JsonKVStorage(os.path.join(sys_path, "cache"), namespace="rephrase")
+
+    new_graph = asyncio.run(
+        judge_statement(llm_client, graph_storage, rephrase_storage, re_judge=True)
+    )
+
+    graph_file = asyncio.run(graph_storage.get_graph())
+    new_graph.write_nx_graph(graph_file, args.output)
+
+    loss_after = calculate_average_loss(new_graph)
+    print(f"Average loss of the graph: {loss_after}")
diff --git a/hf-repo/graphgen/models/__init__.py b/hf-repo/graphgen/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f715335874578b32366228dc9a67a3703073bbe3
--- /dev/null
+++ b/hf-repo/graphgen/models/__init__.py
@@ -0,0 +1,45 @@
+from .community.community_detector import CommunityDetector
+from .evaluate.length_evaluator import LengthEvaluator
+from .evaluate.mtld_evaluator import MTLDEvaluator
+from .evaluate.reward_evaluator import RewardEvaluator
+from .evaluate.uni_evaluator import UniEvaluator
+from .llm.openai_model import OpenAIModel
+from .llm.tokenizer import Tokenizer
+from .llm.topk_token_model import Token, TopkTokenModel
+from .search.db.uniprot_search import UniProtSearch
+from .search.kg.wiki_search import WikiSearch
+from .search.web.bing_search import BingSearch
+from .search.web.google_search import GoogleSearch
+from .storage.json_storage import JsonKVStorage, JsonListStorage
+from .storage.networkx_storage import NetworkXStorage
+from .strategy.travserse_strategy import TraverseStrategy
+from .text.chunk import Chunk
+from .text.text_pair import TextPair
+
+__all__ = [
+ # llm models
+ "OpenAIModel",
+ "TopkTokenModel",
+ "Token",
+ "Tokenizer",
+ # storage models
+ "Chunk",
+ "NetworkXStorage",
+ "JsonKVStorage",
+ "JsonListStorage",
+ # search models
+ "WikiSearch",
+ "GoogleSearch",
+ "BingSearch",
+ "UniProtSearch",
+ # evaluate models
+ "TextPair",
+ "LengthEvaluator",
+ "MTLDEvaluator",
+ "RewardEvaluator",
+ "UniEvaluator",
+ # strategy models
+ "TraverseStrategy",
+ # community models
+ "CommunityDetector",
+]
diff --git a/hf-repo/graphgen/models/embed/__init__.py b/hf-repo/graphgen/models/embed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hf-repo/graphgen/models/embed/embedding.py b/hf-repo/graphgen/models/embed/embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..8213b90f08820414402547ead11854da4f9b308b
--- /dev/null
+++ b/hf-repo/graphgen/models/embed/embedding.py
@@ -0,0 +1,29 @@
+from dataclasses import dataclass
+import asyncio
+import numpy as np
+
+class UnlimitedSemaphore:
+    """No-op async context manager: entering and leaving never blocks."""
+
+    async def __aenter__(self):
+        return None
+
+    async def __aexit__(self, exc_type, exc, tb):
+        # Returning None lets any exception from the body propagate.
+        return None
+
+@dataclass
+class EmbeddingFunc:
+    """
+    Awaitable wrapper around an async embedding function that limits how many
+    calls run at once.
+
+    ``concurrent_limit == 0`` disables the limit.
+    """
+    embedding_dim: int
+    max_token_size: int
+    func: callable
+    concurrent_limit: int = 16
+
+    def __post_init__(self):
+        unlimited = self.concurrent_limit == 0
+        self._semaphore = (
+            UnlimitedSemaphore()
+            if unlimited
+            else asyncio.Semaphore(self.concurrent_limit)
+        )
+
+    async def __call__(self, *args, **kwargs) -> np.ndarray:
+        """Await ``func(*args, **kwargs)`` while holding the concurrency guard."""
+        async with self._semaphore:
+            embedding = await self.func(*args, **kwargs)
+        return embedding
diff --git a/hf-repo/graphgen/models/evaluate/__init__.py b/hf-repo/graphgen/models/evaluate/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hf-repo/graphgen/models/evaluate/base_evaluator.py b/hf-repo/graphgen/models/evaluate/base_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c5ae2d5252e24322db5c89c70364b514f6f7cde
--- /dev/null
+++ b/hf-repo/graphgen/models/evaluate/base_evaluator.py
@@ -0,0 +1,51 @@
+import asyncio
+
+from dataclasses import dataclass
+from tqdm.asyncio import tqdm as tqdm_async
+from graphgen.utils import create_event_loop
+from graphgen.models.text.text_pair import TextPair
+
+@dataclass
+class BaseEvaluator:
+    """
+    Base class for per-pair scoring. Subclasses implement :meth:`evaluate_single`;
+    this class runs it concurrently (at most ``max_concurrent`` at a time) and
+    aggregates the scores.
+    """
+    max_concurrent: int = 100
+    # Scores from the most recent get_average_score() call; None until then.
+    results: list[float] = None
+
+    def evaluate(self, pairs: list[TextPair]) -> list[float]:
+        """
+        Evaluate every pair and return one score per pair, in the order of ``pairs``.
+        """
+        return create_event_loop().run_until_complete(self.async_evaluate(pairs))
+
+    async def async_evaluate(self, pairs: list[TextPair]) -> list[float]:
+        """
+        Score all pairs concurrently.
+
+        Each task carries its index so that ``results[i]`` belongs to
+        ``pairs[i]``, even though tasks are consumed in completion order to
+        drive the progress bar.
+        """
+        semaphore = asyncio.Semaphore(self.max_concurrent)
+
+        async def evaluate_with_semaphore(index, pair):
+            async with semaphore:  # acquire the semaphore
+                return index, await self.evaluate_single(pair)
+
+        results = [None] * len(pairs)
+        for finished in tqdm_async(
+            asyncio.as_completed(
+                [evaluate_with_semaphore(i, pair) for i, pair in enumerate(pairs)]
+            ),
+            total=len(pairs),
+        ):
+            index, score = await finished
+            results[index] = score
+        return results
+
+    async def evaluate_single(self, pair: TextPair) -> float:
+        """Score one pair; must be overridden by subclasses."""
+        raise NotImplementedError()
+
+    def get_average_score(self, pairs: list[TextPair]) -> float:
+        """
+        Get the average score of a batch of texts and cache the individual
+        scores in ``self.results``. Returns ``0.0`` for an empty batch.
+        """
+        results = self.evaluate(pairs)
+        self.results = results
+        if not results:
+            return 0.0
+        return sum(results) / len(results)
+
+    def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
+        """
+        Get the min and max score of a batch of texts.
+
+        Reuses the scores cached by a previous :meth:`get_average_score` call;
+        ``pairs`` is only evaluated when nothing is cached yet.
+
+        :raises ValueError: if the cached score list is empty.
+        """
+        if self.results is None:
+            self.get_average_score(pairs)
+        return min(self.results), max(self.results)
diff --git a/hf-repo/graphgen/models/evaluate/length_evaluator.py b/hf-repo/graphgen/models/evaluate/length_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba53ff6b28b5ddfc6b20b079b326bd8f083167be
--- /dev/null
+++ b/hf-repo/graphgen/models/evaluate/length_evaluator.py
@@ -0,0 +1,22 @@
+from dataclasses import dataclass
+from graphgen.models.evaluate.base_evaluator import BaseEvaluator
+from graphgen.models.llm.tokenizer import Tokenizer
+from graphgen.models.text.text_pair import TextPair
+from graphgen.utils import create_event_loop
+
+
+@dataclass
+class LengthEvaluator(BaseEvaluator):
+    """Scores a pair by the number of tokens in its answer."""
+    tokenizer_name: str = "cl100k_base"
+
+    def __post_init__(self):
+        self.tokenizer = Tokenizer(model_name=self.tokenizer_name)
+
+    async def evaluate_single(self, pair: TextPair) -> float:
+        """Tokenize ``pair.answer`` in the default executor and return the token count."""
+        executor_loop = create_event_loop()
+        token_count = await executor_loop.run_in_executor(
+            None, self._calculate_length, pair.answer
+        )
+        return token_count
+
+    def _calculate_length(self, text: str) -> float:
+        """Return how many tokens the tokenizer produces for ``text``."""
+        return len(self.tokenizer.encode_string(text))
diff --git a/hf-repo/graphgen/models/evaluate/mtld_evaluator.py b/hf-repo/graphgen/models/evaluate/mtld_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ea68875e7e16aabe232f7c642a44752aa125e6c
--- /dev/null
+++ b/hf-repo/graphgen/models/evaluate/mtld_evaluator.py
@@ -0,0 +1,76 @@
+from dataclasses import dataclass, field
+from typing import Set
+
+from graphgen.models.evaluate.base_evaluator import BaseEvaluator
+from graphgen.models.text.text_pair import TextPair
+from graphgen.utils import detect_main_language, NLTKHelper, create_event_loop
+
+
+nltk_helper = NLTKHelper()
+
+@dataclass
+class MTLDEvaluator(BaseEvaluator):
+    """
+    Lexical-diversity metric for text (MTLD), computed on each pair's answer.
+    """
+    stopwords_en: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("english")))
+    stopwords_zh: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("chinese")))
+
+    async def evaluate_single(self, pair: TextPair) -> float:
+        """Compute the MTLD score of ``pair.answer`` in the default executor."""
+        loop = create_event_loop()
+        return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer)
+
+    def _calculate_mtld_score(self, text: str, threshold=0.72) -> float:
+        """
+        Compute MTLD as the mean of a forward and a backward pass.
+
+        Higher is better. Returns ``0.0`` for blank text or when no alphanumeric
+        non-stopword token remains.
+
+        :param text: text to score
+        :param threshold: type-token-ratio at which a segment counts as one factor
+        """
+        if not text or not text.strip():
+            return 0.0
+
+        lang = detect_main_language(text)
+        tokens = nltk_helper.word_tokenize(text, lang)
+
+        stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en
+        filtered_tokens = [
+            word for word in tokens if word not in stopwords and word.isalnum()
+        ]
+
+        if not filtered_tokens:
+            # Always return a float, matching the blank-text case above.
+            return 0.0
+
+        # forward MTLD
+        forward_factors = self._compute_factors(filtered_tokens, threshold)
+
+        # backward MTLD
+        backward_factors = self._compute_factors(filtered_tokens[::-1], threshold)
+
+        # average of both directions
+        return (forward_factors + backward_factors) / 2
+
+    @staticmethod
+    def _compute_factors(tokens: list, threshold: float) -> float:
+        """
+        One-directional MTLD: ``len(tokens)`` divided by the number of factors.
+
+        A factor is completed each time the running type-token ratio (TTR) of the
+        current segment drops to ``threshold`` or below; the trailing incomplete
+        segment contributes a partial factor. Returns ``len(tokens)`` when the
+        factor count is zero.
+        """
+        factors = 0.0
+        segment_length = 0
+        unique_words = set()
+
+        for token in tokens:
+            segment_length += 1
+            unique_words.add(token)
+
+            if len(unique_words) / segment_length <= threshold:
+                factors += 1
+                segment_length = 0
+                unique_words = set()
+
+        # Trailing incomplete segment. Its TTR is necessarily above the threshold
+        # (otherwise the loop would have closed it), so it only ever adds a
+        # partial factor and 1 - threshold is positive here.
+        if segment_length:
+            ttr = len(unique_words) / segment_length
+            factors += 1 - (ttr - threshold) / (1 - threshold)
+
+        return len(tokens) / factors if factors > 0 else float(len(tokens))
diff --git a/hf-repo/graphgen/models/evaluate/reward_evaluator.py b/hf-repo/graphgen/models/evaluate/reward_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e4c021c03cf25f15b45b734d4af2cf4da0cefde
--- /dev/null
+++ b/hf-repo/graphgen/models/evaluate/reward_evaluator.py
@@ -0,0 +1,101 @@
+from dataclasses import dataclass
+from tqdm import tqdm
+from graphgen.models.text.text_pair import TextPair
+
+
+@dataclass
+class RewardEvaluator:
+    """
+    Reward Model Evaluator.
+    OpenAssistant/reward-model-deberta-v3-large-v2: scores range over
+    [-inf, inf]; higher is better.
+    """
+    reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2"
+    max_length: int = 2560
+    # Scores from the most recent get_average_score() call; None until then.
+    results: list[float] = None
+
+    def __post_init__(self):
+        import torch
+        self.num_gpus = torch.cuda.device_count()
+
+    @staticmethod
+    def process_chunk(rank, pairs, reward_name, max_length, return_dict):
+        """
+        Worker entry point: score ``pairs`` with the reward model and store the
+        list of scores in ``return_dict[rank]``.
+
+        Runs on ``cuda:<rank>`` when CUDA is available, otherwise on the CPU.
+        """
+        import torch
+        from transformers import AutoModelForSequenceClassification, AutoTokenizer
+        if torch.cuda.is_available():
+            device = f'cuda:{rank}'
+            torch.cuda.set_device(rank)
+        else:
+            device = 'cpu'
+
+        rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
+        tokenizer = AutoTokenizer.from_pretrained(reward_name)
+        rank_model.to(device)
+        rank_model.eval()
+
+        results = []
+        with torch.no_grad():
+            for pair in tqdm(pairs):
+                inputs = tokenizer(
+                    pair.question,
+                    pair.answer,
+                    return_tensors="pt",
+                    max_length=max_length,
+                    truncation=True
+                )
+                inputs = {k: v.to(device) for k, v in inputs.items()}
+                score = rank_model(**inputs).logits[0].item()
+                results.append(score)
+
+        return_dict[rank] = results
+
+    def evaluate(self, pairs: list[TextPair]) -> list[float]:
+        """
+        Score all pairs, one worker process per GPU (a single CPU worker when no
+        GPU is present). Scores are returned in the order of ``pairs``.
+
+        :raises RuntimeError: if a worker exits without reporting its scores.
+        """
+        import torch.multiprocessing as mp
+        if not pairs:
+            return []
+
+        # With zero GPUs the original divided by zero; use one CPU worker instead.
+        num_workers = max(self.num_gpus, 1)
+        chunk_size = len(pairs) // num_workers
+        chunks = []
+        for i in range(num_workers):
+            start = i * chunk_size
+            # The last worker also takes the remainder.
+            end = len(pairs) if i == num_workers - 1 else start + chunk_size
+            chunks.append(pairs[start:end])
+
+        # multi-process; the manager is shut down when the block exits
+        with mp.Manager() as manager:
+            return_dict = manager.dict()
+            workers = []
+            for rank, chunk in enumerate(chunks):
+                if not chunk:
+                    # Nothing to score on this device: skip loading the model.
+                    continue
+                p = mp.Process(
+                    target=self.process_chunk,
+                    args=(rank, chunk, self.reward_name, self.max_length, return_dict)
+                )
+                p.start()
+                workers.append((rank, p))
+
+            for _, p in workers:
+                p.join()
+
+            # merge results in rank order so they line up with ``pairs``
+            results = []
+            for rank, p in workers:
+                if rank not in return_dict:
+                    raise RuntimeError(
+                        f"Reward worker {rank} exited with code {p.exitcode} "
+                        "without returning scores"
+                    )
+                results.extend(return_dict[rank])
+
+        return results
+
+    def get_average_score(self, pairs: list[TextPair]) -> float:
+        """
+        Get the average score of a batch of texts and cache the individual
+        scores in ``self.results``. Returns ``0.0`` for an empty batch.
+        """
+        results = self.evaluate(pairs)
+        self.results = results
+        if not results:
+            return 0.0
+        return sum(results) / len(results)
+
+    def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
+        """
+        Get the min and max score of a batch of texts.
+
+        Reuses the scores cached by :meth:`get_average_score` when present.
+
+        :raises ValueError: if the cached score list is empty.
+        """
+        if self.results is None:
+            self.get_average_score(pairs)
+        return min(self.results), max(self.results)
diff --git a/hf-repo/graphgen/models/evaluate/uni_evaluator.py b/hf-repo/graphgen/models/evaluate/uni_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..a334f0a9fa8e909c54339bdc7711d00659b50ff3
--- /dev/null
+++ b/hf-repo/graphgen/models/evaluate/uni_evaluator.py
@@ -0,0 +1,159 @@
+# https://github.com/maszhongming/UniEval/tree/main
+
+from dataclasses import dataclass, field
+from tqdm import tqdm
+from graphgen.models.text.text_pair import TextPair
+
+
+def _add_questions(dimension: str, question: str, answer: str):
+    """
+    Build the UniEval prompt for one evaluation dimension.
+
+    :param dimension: ``"naturalness"``, ``"coherence"`` or ``"understandability"``
+    :param question: dialogue history (only used for ``"coherence"``)
+    :param answer: response being judged
+    :raises NotImplementedError: for any other dimension
+    """
+    if dimension == "naturalness":
+        return 'question: Is this a natural response in the dialogue? response: ' + answer
+    if dimension == "coherence":
+        prompt = 'question: Is this a coherent response given the dialogue history? response: ' + answer
+        return prompt + ' dialogue history: ' + question
+    if dimension == "understandability":
+        return 'question: Is this an understandable response in the dialogue? response: ' + answer
+    raise NotImplementedError(
+        'The input format for this dimension is still undefined. Please customize it first.')
+
+@dataclass
+class UniEvaluator:
+    """
+    UniEval-based evaluator: for each configured dimension, scores a pair as
+    P("Yes") / (P("Yes") + P("No")) at the first decoded position.
+    """
+    model_name: str = "MingZhong/unieval-sum"
+    dimensions: list = field(default_factory=lambda: ['naturalness', 'coherence', 'understandability'])
+    max_length: int = 2560
+    # dimension -> per-pair scores from the last get_average_score() call
+    results: dict = None
+
+    def __post_init__(self):
+        import torch
+        self.num_gpus = torch.cuda.device_count()
+        self.results = {}
+
+    @staticmethod
+    def process_chunk(rank, pairs, model_name, max_length, dimension, return_dict):
+        """
+        Worker entry point: score ``pairs`` on one dimension and store the list
+        of scores in ``return_dict[rank]``.
+
+        Runs on ``cuda:<rank>`` when CUDA is available, otherwise on the CPU.
+        """
+        import torch
+        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+        if torch.cuda.is_available():
+            device = f'cuda:{rank}'
+            torch.cuda.set_device(rank)
+        else:
+            device = 'cpu'
+
+        rank_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        rank_model.to(device)
+        rank_model.eval()
+
+        softmax = torch.nn.Softmax(dim=1)
+
+        pos_id = tokenizer("Yes")["input_ids"][0]
+        neg_id = tokenizer("No")["input_ids"][0]
+
+        # The target is the same for every pair, so encode it once.
+        encoded_tgt = tokenizer(
+            "No",
+            max_length=max_length,
+            truncation=True,
+            padding=True,
+            return_tensors='pt'
+        )
+        tgt_tokens = encoded_tgt['input_ids'].to(device)[:, 0].unsqueeze(-1)
+
+        results = []
+        with torch.no_grad():
+            for pair in tqdm(pairs):
+                text = _add_questions(dimension, pair.question, pair.answer)
+
+                encoded_src = tokenizer(
+                    text,
+                    max_length=max_length,
+                    truncation=True,
+                    padding=True,
+                    return_tensors='pt'
+                )
+                src_tokens = encoded_src['input_ids'].to(device)
+                src_mask = encoded_src['attention_mask'].to(device)
+
+                output = rank_model(
+                    input_ids=src_tokens,
+                    attention_mask=src_mask,
+                    labels=tgt_tokens,
+                    use_cache=False
+                )
+
+                logits = output.logits.view(-1, rank_model.config.vocab_size)
+                probs = softmax(logits)
+
+                pos_score = probs[:, pos_id]  # Yes
+                neg_score = probs[:, neg_id]  # No
+                score = pos_score / (pos_score + neg_score)
+
+                results.append(score.item())
+
+        return_dict[rank] = results
+
+    def _evaluate_dimension(self, pairs, dimension):
+        """Score ``pairs`` on one dimension across worker processes, in input order."""
+        import torch.multiprocessing as mp
+        if not pairs:
+            return []
+
+        # With zero GPUs the original divided by zero; use one CPU worker instead.
+        num_workers = max(self.num_gpus, 1)
+        chunk_size = len(pairs) // num_workers
+        chunks = []
+        for i in range(num_workers):
+            start = i * chunk_size
+            # The last worker also takes the remainder.
+            end = len(pairs) if i == num_workers - 1 else start + chunk_size
+            chunks.append(pairs[start:end])
+
+        # multi-process; the manager is shut down when the block exits
+        with mp.Manager() as manager:
+            return_dict = manager.dict()
+            workers = []
+            for rank, chunk in enumerate(chunks):
+                if not chunk:
+                    # Nothing to score on this device: skip loading the model.
+                    continue
+                p = mp.Process(
+                    target=self.process_chunk,
+                    args=(rank, chunk, self.model_name, self.max_length, dimension, return_dict)
+                )
+                p.start()
+                workers.append((rank, p))
+
+            for _, p in workers:
+                p.join()
+
+            # merge results in rank order so they line up with ``pairs``
+            results = []
+            for rank, p in workers:
+                if rank not in return_dict:
+                    raise RuntimeError(
+                        f"UniEval worker {rank} exited with code {p.exitcode} "
+                        "without returning scores"
+                    )
+                results.extend(return_dict[rank])
+        return results
+
+    def evaluate(self, pairs: list[TextPair]) -> list[dict]:
+        """
+        Score all pairs on every configured dimension.
+
+        :return: one ``{dimension: [score, ...]}`` dict per dimension, in the
+            order of ``self.dimensions``
+        :raises RuntimeError: if a worker exits without reporting its scores.
+        """
+        return [
+            {dimension: self._evaluate_dimension(pairs, dimension)}
+            for dimension in self.dimensions
+        ]
+
+    def get_average_score(self, pairs: list[TextPair]) -> dict:
+        """
+        Get the average score per dimension of a batch of texts and cache the
+        individual scores in ``self.results``. A dimension with no scores
+        averages to ``0.0``.
+        """
+        results = self.evaluate(pairs)
+        final_results = {}
+        for result in results:
+            for key, value in result.items():
+                final_results[key] = sum(value) / len(value) if value else 0.0
+                self.results[key] = value
+        return final_results
+
+    def get_min_max_score(self, pairs: list[TextPair]) -> dict:
+        """
+        Get the (min, max) score per dimension of a batch of texts.
+
+        Reuses the scores cached by :meth:`get_average_score` when present.
+        ``__post_init__`` initialises the cache to ``{}``, so emptiness (not
+        ``None``) is what signals "not computed yet".
+
+        :raises ValueError: if a dimension has an empty score list.
+        """
+        if not self.results:
+            self.get_average_score(pairs)
+        final_results = {}
+        for key, value in self.results.items():
+            final_results[key] = min(value), max(value)
+        return final_results
diff --git a/hf-repo/graphgen/models/llm/__init__.py b/hf-repo/graphgen/models/llm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hf-repo/graphgen/models/llm/limitter.py b/hf-repo/graphgen/models/llm/limitter.py
new file mode 100644
index 0000000000000000000000000000000000000000..01cb1f709f17632652b36a1da0b21e963e823df0
--- /dev/null
+++ b/hf-repo/graphgen/models/llm/limitter.py
@@ -0,0 +1,88 @@
+import time
+from datetime import datetime, timedelta
+import asyncio
+
+from graphgen.utils import logger
+
+
+class RPM:
+    """Requests-per-minute limiter keyed on the wall-clock minute of the day."""
+
+    def __init__(self, rpm: int = 1000):
+        self.rpm = rpm
+        self.record = {'rpm_slot': self.get_minute_slot(), 'counter': 0}
+
+    def get_minute_slot(self):
+        """Return the current local time as minutes elapsed since midnight."""
+        now = datetime.fromtimestamp(time.time())
+        return now.hour * 60 + now.minute
+
+    async def wait(self, silent=False):
+        """
+        Register one request; if the current minute's budget is already used up,
+        sleep until the next minute starts first.
+
+        :param silent: suppress the sleep and debug log messages
+        """
+        current = time.time()
+        dt_object = datetime.fromtimestamp(current)
+        minute_slot = self.get_minute_slot()
+
+        if self.record['rpm_slot'] != minute_slot:
+            # A new minute has started: begin a fresh count.
+            self.record = {'rpm_slot': self.get_minute_slot(), 'counter': 0}
+        elif self.record['counter'] >= self.rpm:
+            # Budget exhausted: wait until the next minute boundary.
+            minute_start = dt_object.replace(second=0, microsecond=0)
+            _next = (minute_start + timedelta(minutes=1)).timestamp()
+            sleep_time = abs(_next - current)
+            if not silent:
+                logger.info('RPM sleep %s', sleep_time)
+            await asyncio.sleep(sleep_time)
+
+            self.record = {'rpm_slot': self.get_minute_slot(), 'counter': 0}
+
+        self.record['counter'] += 1
+
+        if not silent:
+            logger.debug(self.record)
+
+
+class TPM:
+    """Tokens-per-minute limiter keyed on the wall-clock minute of the day."""
+
+    def __init__(self, tpm: int = 20000):
+        self.tpm = tpm
+        self.record = {'tpm_slot': self.get_minute_slot(), 'counter': 0}
+
+    def get_minute_slot(self):
+        """Return the current local time as minutes elapsed since midnight."""
+        current_time = time.time()
+        dt_object = datetime.fromtimestamp(current_time)
+        total_minutes_since_midnight = dt_object.hour * 60 + dt_object.minute
+        return total_minutes_since_midnight
+
+    async def wait(self, token_count, silent=False):
+        """
+        Register ``token_count`` tokens; if that pushes the current minute over
+        the budget, sleep until the next minute starts and charge the tokens there.
+
+        :param token_count: number of tokens the upcoming request will use
+        :param silent: suppress the sleep and debug log messages
+        """
+        current = time.time()
+        dt_object = datetime.fromtimestamp(current)
+        minute_slot = self.get_minute_slot()
+
+        # First request in a new minute: start a fresh budget, no waiting.
+        if self.record['tpm_slot'] != minute_slot:
+            self.record = {'tpm_slot': minute_slot, 'counter': token_count}
+            return
+
+        # check TPM exceed
+        self.record['counter'] += token_count
+        if self.record['counter'] > self.tpm:
+            # wait until next minute
+            next_minute = dt_object.replace(
+                second=0, microsecond=0) + timedelta(minutes=1)
+            _next = next_minute.timestamp()
+            sleep_time = abs(_next - current)
+            # Honour ``silent`` here too, consistent with RPM.wait.
+            if not silent:
+                logger.info('TPM sleep %s', sleep_time)
+            await asyncio.sleep(sleep_time)
+
+            self.record = {
+                'tpm_slot': self.get_minute_slot(),
+                'counter': token_count
+            }
+
+        if not silent:
+            logger.debug(self.record)
diff --git a/hf-repo/graphgen/models/llm/openai_model.py b/hf-repo/graphgen/models/llm/openai_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c04432f1502eb80d90cfa7cd50cd1ddc622e3a5
--- /dev/null
+++ b/hf-repo/graphgen/models/llm/openai_model.py
@@ -0,0 +1,155 @@
+import math
+import re
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional
+
+import openai
+from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError
+from tenacity import (
+ retry,
+ retry_if_exception_type,
+ stop_after_attempt,
+ wait_exponential,
+)
+
+from graphgen.models.llm.limitter import RPM, TPM
+from graphgen.models.llm.tokenizer import Tokenizer
+from graphgen.models.llm.topk_token_model import Token, TopkTokenModel
+
+
+def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
+    """Convert the logprob entries of the first choice into ``Token`` objects.
+
+    Each returned token carries its own probability (``exp`` of the reported
+    logprob) and, as ``top_candidates``, the alternatives listed in that
+    entry's ``top_logprobs``.
+    """
+
+    def _as_token(entry) -> Token:
+        alternatives = [
+            Token(alt.token, math.exp(alt.logprob)) for alt in entry.top_logprobs
+        ]
+        return Token(
+            entry.token, math.exp(entry.logprob), top_candidates=alternatives
+        )
+
+    return [_as_token(entry) for entry in response.choices[0].logprobs.content]
+
+
+def filter_think_tags(text: str) -> str:
+    """
+    Remove <think> tags from the text.
+    If the text contains <think> and </think>, it removes everything between them and the tags themselves.
+    If nothing is left after removal, the stripped original text is returned instead.
+    """
+    think_pattern = re.compile(r"<think>.*?</think>", re.DOTALL)
+    filtered_text = think_pattern.sub("", text).strip()
+    return filtered_text if filtered_text else text.strip()
+
+
+@dataclass
+class OpenAIModel(TopkTokenModel):
+    """Async chat-completions client built on ``AsyncOpenAI``.
+
+    Sampling defaults (temperature, top-p, max tokens, top-k per token) are
+    read from the ``TopkTokenModel`` base class fields.
+    """
+
+    model_name: str = "gpt-4o-mini"
+    api_key: str = None
+    base_url: str = None
+
+    system_prompt: str = ""
+    json_mode: bool = False
+    seed: int = None
+
+    # Usage dicts (prompt/completion/total tokens) appended by generate_answer.
+    token_usage: list = field(default_factory=list)
+    # When True, generate_answer waits on the RPM/TPM limiters before each request.
+    request_limit: bool = False
+    rpm: RPM = field(default_factory=lambda: RPM(rpm=1000))
+    tpm: TPM = field(default_factory=lambda: TPM(tpm=50000))
+
+    tokenizer_instance: Tokenizer = field(default_factory=Tokenizer)
+
+    def __post_init__(self):
+        """Create the underlying ``AsyncOpenAI`` client."""
+        assert self.api_key is not None, "Please provide api key to access openai api."
+        # The "dummy" fallback only matters when asserts are disabled or the key is "".
+        self.client = AsyncOpenAI(
+            api_key=self.api_key or "dummy", base_url=self.base_url
+        )
+
+    def _pre_generate(self, text: str, history: List[str]) -> Dict:
+        """Build the keyword arguments shared by all chat-completion requests.
+
+        :param text: content of the user message
+        :param history: previous messages, prepended to the request as-is
+        :return: kwargs including sampling settings and ``messages``
+        """
+        kwargs = {
+            "temperature": self.temperature,
+            "top_p": self.topp,
+            "max_tokens": self.max_tokens,
+        }
+        # `is not None` so that an explicit seed of 0 is still forwarded.
+        if self.seed is not None:
+            kwargs["seed"] = self.seed
+        if self.json_mode:
+            kwargs["response_format"] = {"type": "json_object"}
+
+        messages = []
+        if self.system_prompt:
+            messages.append({"role": "system", "content": self.system_prompt})
+        messages.append({"role": "user", "content": text})
+
+        if history:
+            assert len(history) % 2 == 0, "History should have even number of elements."
+            # NOTE(review): history is placed before the system message, and
+            # generate_answer indexes every message with ["content"], so the
+            # elements are presumably message dicts rather than plain strings
+            # as the annotation says. Confirm against callers.
+            messages = history + messages
+
+        kwargs["messages"] = messages
+        return kwargs
+
+    @retry(
+        stop=stop_after_attempt(5),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(
+            (RateLimitError, APIConnectionError, APITimeoutError)
+        ),
+    )
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[List[str]] = None
+    ) -> List[Token]:
+        """Request a single completion token and return it with its top candidates.
+
+        :param text: content of the user message
+        :param history: optional previous messages
+        :return: tokens built from the logprobs of the first choice
+        """
+        kwargs = self._pre_generate(text, history)
+        if self.topk_per_token > 0:
+            kwargs["logprobs"] = True
+            kwargs["top_logprobs"] = self.topk_per_token
+
+        # Limit max_tokens to 1 to avoid long completions
+        kwargs["max_tokens"] = 1
+
+        completion = await self.client.chat.completions.create(  # pylint: disable=E1125
+            model=self.model_name, **kwargs
+        )
+
+        tokens = get_top_response_tokens(completion)
+
+        return tokens
+
+    @retry(
+        stop=stop_after_attempt(5),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(
+            (RateLimitError, APIConnectionError, APITimeoutError)
+        ),
+    )
+    async def generate_answer(
+        self, text: str, history: Optional[List[str]] = None, temperature: float = 0
+    ) -> str:
+        """Generate an answer and return it with any think block removed.
+
+        :param text: content of the user message
+        :param history: optional previous messages
+        :param temperature: sampling temperature overriding the instance default
+        :return: message content of the first choice, passed through filter_think_tags
+        """
+        kwargs = self._pre_generate(text, history)
+        kwargs["temperature"] = temperature
+
+        if self.request_limit:
+            # The estimate is only needed by the TPM limiter, so the prompt is
+            # tokenized only when limiting is enabled.
+            prompt_tokens = 0
+            for message in kwargs["messages"]:
+                prompt_tokens += len(
+                    self.tokenizer_instance.encode_string(message["content"])
+                )
+            estimated_tokens = prompt_tokens + kwargs["max_tokens"]
+
+            await self.rpm.wait(silent=True)
+            await self.tpm.wait(estimated_tokens, silent=True)
+
+        completion = await self.client.chat.completions.create(  # pylint: disable=E1125
+            model=self.model_name, **kwargs
+        )
+        # `usage` may be present but None; only record it when it carries data.
+        usage = getattr(completion, "usage", None)
+        if usage is not None:
+            self.token_usage.append(
+                {
+                    "prompt_tokens": usage.prompt_tokens,
+                    "completion_tokens": usage.completion_tokens,
+                    "total_tokens": usage.total_tokens,
+                }
+            )
+        return filter_think_tags(completion.choices[0].message.content)
+
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[List[str]] = None
+    ) -> List[Token]:
+        """Not supported by this client; always raises ``NotImplementedError``."""
+        raise NotImplementedError
diff --git a/hf-repo/graphgen/models/llm/tokenizer.py b/hf-repo/graphgen/models/llm/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a1c4b2206a6980fcd070defe00f7e5a339a1ead
--- /dev/null
+++ b/hf-repo/graphgen/models/llm/tokenizer.py
@@ -0,0 +1,73 @@
+from dataclasses import dataclass
+from typing import List
+import tiktoken
+
+try:
+ from transformers import AutoTokenizer
+ TRANSFORMERS_AVAILABLE = True
+except ImportError:
+ AutoTokenizer = None
+ TRANSFORMERS_AVAILABLE = False
+
+
+def get_tokenizer(tokenizer_name: str = "cl100k_base"):
+    """
+    Resolve a tokenizer by name.
+
+    Names known to tiktoken are served by tiktoken; any other name is handed
+    to Hugging Face ``AutoTokenizer.from_pretrained``.
+
+    :param tokenizer_name: tiktoken encoding name or Hugging Face model name
+    :return: tokenizer instance
+    :raises ValueError: if transformers is not installed or loading fails
+    """
+    if tokenizer_name in tiktoken.list_encoding_names():
+        return tiktoken.get_encoding(tokenizer_name)
+
+    # Guard clause: without transformers there is nothing left to try.
+    if not TRANSFORMERS_AVAILABLE:
+        raise ValueError("Hugging Face Transformers is not available, please install it first.")
+
+    try:
+        return AutoTokenizer.from_pretrained(tokenizer_name)
+    except Exception as e:
+        raise ValueError(f"Failed to load tokenizer from Hugging Face: {e}") from e
+
+@dataclass
+class Tokenizer:
+    """Thin wrapper around the tokenizer returned by ``get_tokenizer``."""
+
+    model_name: str = "cl100k_base"
+
+    def __post_init__(self):
+        self.tokenizer = get_tokenizer(self.model_name)
+
+    def encode_string(self, text: str) -> List[int]:
+        """
+        Encode text to tokens
+
+        :param text: text to encode
+        :return: tokens
+        """
+        return self.tokenizer.encode(text)
+
+    def decode_tokens(self, tokens: List[int]) -> str:
+        """
+        Decode tokens to text
+
+        :param tokens: tokens to decode
+        :return: text
+        """
+        return self.tokenizer.decode(tokens)
+
+    def chunk_by_token_size(
+        self, content: str, overlap_token_size=128, max_token_size=1024
+    ):
+        """
+        Split content into overlapping chunks measured in tokens
+
+        :param content: text to split
+        :param overlap_token_size: number of tokens shared by consecutive chunks
+        :param max_token_size: maximum number of tokens per chunk
+        :return: list of dicts with "tokens", "content" and "chunk_order_index"
+        :raises ValueError: if overlap_token_size is not smaller than max_token_size
+        """
+        stride = max_token_size - overlap_token_size
+        # A zero stride makes range() fail cryptically and a negative one
+        # silently yields no chunks, so reject both up front.
+        if stride <= 0:
+            raise ValueError(
+                "overlap_token_size must be smaller than max_token_size, "
+                f"got overlap_token_size={overlap_token_size}, "
+                f"max_token_size={max_token_size}"
+            )
+
+        tokens = self.encode_string(content)
+        results = []
+        for index, start in enumerate(range(0, len(tokens), stride)):
+            chunk_content = self.decode_tokens(
+                tokens[start : start + max_token_size]
+            )
+            results.append(
+                {
+                    "tokens": min(max_token_size, len(tokens) - start),
+                    "content": chunk_content.strip(),
+                    "chunk_order_index": index,
+                }
+            )
+        return results
diff --git a/hf-repo/graphgen/models/llm/topk_token_model.py b/hf-repo/graphgen/models/llm/topk_token_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7595cb1c9810f1c2d977d05b40228ef2a82b9f0
--- /dev/null
+++ b/hf-repo/graphgen/models/llm/topk_token_model.py
@@ -0,0 +1,48 @@
+import math
+from dataclasses import dataclass, field
+from typing import List, Union, Optional
+
+
+@dataclass
+class Token:
+    """A token with its probability and, optionally, alternative candidates."""
+
+    text: str
+    prob: float
+    top_candidates: List = field(default_factory=list)
+    ppl: Union[float, None] = field(default=None)
+
+    @property
+    def logprob(self) -> float:
+        """Natural logarithm of ``prob``; ``-inf`` when ``prob`` is exactly 0."""
+        # A probability obtained via exp() of a very negative logprob underflows
+        # to 0.0, for which math.log would raise a domain error.
+        if self.prob == 0:
+            return -math.inf
+        return math.log(self.prob)
+
+
+@dataclass
+class TopkTokenModel:
+    """Base class holding generation settings and the model interface.
+
+    All three ``generate_*`` coroutines raise ``NotImplementedError`` here and
+    are meant to be overridden by concrete models.
+    """
+
+    do_sample: bool = False
+    temperature: float = 0
+    max_tokens: int = 4096
+    repetition_penalty: float = 1.05
+    num_beams: int = 1
+    topk: int = 50
+    topp: float = 0.95
+
+    topk_per_token: int = 5  # number of topk tokens to generate for each token
+
+    async def generate_topk_per_token(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
+        """
+        Generate prob, text and candidates for each token of the model's output.
+        This function is used to visualize the inference process.
+
+        ``history`` is optional so the signature matches the other generate_* methods.
+        """
+        raise NotImplementedError
+
+    async def generate_inputs_prob(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
+        """
+        Generate prob and text for each token of the input text.
+        This function is used to visualize the ppl.
+        """
+        raise NotImplementedError
+
+    async def generate_answer(self, text: str, history: Optional[List[str]] = None) -> str:
+        """
+        Generate answer from the model.
+        """
+        raise NotImplementedError
diff --git a/hf-repo/graphgen/models/search/__init__.py b/hf-repo/graphgen/models/search/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hf-repo/graphgen/models/storage/__init__.py b/hf-repo/graphgen/models/storage/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hf-repo/graphgen/models/storage/base_storage.py b/hf-repo/graphgen/models/storage/base_storage.py
new file mode 100644
index 0000000000000000000000000000000000000000..c09df074c0b1199cc03ec1fdbe1b7b297aa88537
--- /dev/null
+++ b/hf-repo/graphgen/models/storage/base_storage.py
@@ -0,0 +1,115 @@
+from dataclasses import dataclass
+from typing import Generic, TypeVar, Union
+
+from graphgen.models.embed.embedding import EmbeddingFunc
+
+T = TypeVar("T")
+
+
+@dataclass
+class StorageNameSpace:
+    """Common base for storages: a working directory plus a namespace.
+
+    Both callbacks are no-ops here; subclasses override them to persist state.
+    """
+
+    working_dir: str = None
+    namespace: str = None
+
+    async def index_done_callback(self):
+        """commit the storage operations after indexing"""
+
+    async def query_done_callback(self):
+        """commit the storage operations after querying"""
+
+
+@dataclass
+class BaseListStorage(Generic[T], StorageNameSpace):
+    """Abstract list-like storage; every method must be overridden by subclasses."""
+
+    async def all_items(self) -> list[T]:
+        """Return all stored items."""
+        raise NotImplementedError
+
+    async def get_by_index(self, index: int) -> Union[T, None]:
+        """Return the item at ``index``, or None per the return annotation."""
+        raise NotImplementedError
+
+    async def append(self, data: T):
+        """Add a single item to the storage."""
+        raise NotImplementedError
+
+    async def upsert(self, data: list[T]):
+        """Insert a batch of items; exact semantics are defined by the subclass."""
+        raise NotImplementedError
+
+    async def drop(self):
+        """Remove all stored items."""
+        raise NotImplementedError
+
+
+@dataclass
+class BaseKVStorage(Generic[T], StorageNameSpace):
+    """Abstract key-value storage; every method must be overridden by subclasses."""
+
+    async def all_keys(self) -> list[str]:
+        """Return all stored keys."""
+        raise NotImplementedError
+
+    async def get_by_id(self, id: str) -> Union[T, None]:
+        """Return the value stored under ``id``, or None per the return annotation."""
+        raise NotImplementedError
+
+    async def get_by_ids(
+        self, ids: list[str], fields: Union[set[str], None] = None
+    ) -> list[Union[T, None]]:
+        """Return one entry per requested id, optionally restricted to ``fields``."""
+        raise NotImplementedError
+
+    async def filter_keys(self, data: list[str]) -> set[str]:
+        """Return the keys from ``data`` that do not exist in the storage."""
+        raise NotImplementedError
+
+    async def upsert(self, data: dict[str, T]):
+        """Insert a batch of key/value pairs; exact semantics are defined by the subclass."""
+        raise NotImplementedError
+
+    async def drop(self):
+        """Remove all stored entries."""
+        raise NotImplementedError
+
+
+@dataclass
+class BaseGraphStorage(StorageNameSpace):
+    """Abstract graph storage; every method must be overridden by subclasses."""
+
+    # NOTE(review): not used by any method in this class; presumably consumed
+    # by subclasses or callers. Confirm before relying on it.
+    embedding_func: EmbeddingFunc = None
+
+    async def has_node(self, node_id: str) -> bool:
+        """Return whether a node with ``node_id`` exists."""
+        raise NotImplementedError
+
+    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
+        """Return whether an edge between the two nodes exists."""
+        raise NotImplementedError
+
+    async def node_degree(self, node_id: str) -> int:
+        """Return the degree of the given node."""
+        raise NotImplementedError
+
+    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
+        """Return a degree value for the edge; the definition is left to the subclass."""
+        raise NotImplementedError
+
+    async def get_node(self, node_id: str) -> Union[dict, None]:
+        """Return the data of the given node, or None per the return annotation."""
+        raise NotImplementedError
+
+    async def update_node(self, node_id: str, node_data: dict[str, str]):
+        """Update the data of an existing node."""
+        raise NotImplementedError
+
+    async def get_all_nodes(self) -> Union[list[dict], None]:
+        """Return all nodes of the graph."""
+        raise NotImplementedError
+
+    async def get_edge(
+        self, source_node_id: str, target_node_id: str
+    ) -> Union[dict, None]:
+        """Return the data of the given edge, or None per the return annotation."""
+        raise NotImplementedError
+
+    async def update_edge(
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ):
+        """Update the data of an existing edge."""
+        raise NotImplementedError
+
+    async def get_all_edges(self) -> Union[list[dict], None]:
+        """Return all edges of the graph."""
+        raise NotImplementedError
+
+    async def get_node_edges(
+        self, source_node_id: str
+    ) -> Union[list[tuple[str, str]], None]:
+        """Return the edges attached to the given node, or None per the return annotation."""
+        raise NotImplementedError
+
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
+        """Insert a node or update it if it already exists."""
+        raise NotImplementedError
+
+    async def upsert_edge(
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ):
+        """Insert an edge or update it if it already exists."""
+        raise NotImplementedError
+
+    async def delete_node(self, node_id: str):
+        """Delete the given node from the graph."""
+        raise NotImplementedError
diff --git a/hf-repo/graphgen/models/storage/json_storage.py b/hf-repo/graphgen/models/storage/json_storage.py
new file mode 100644
index 0000000000000000000000000000000000000000..b61572f51cdce8f1e173eba8db8fdf42a7037fff
--- /dev/null
+++ b/hf-repo/graphgen/models/storage/json_storage.py
@@ -0,0 +1,87 @@
+import os
+from dataclasses import dataclass
+
+from graphgen.models.storage.base_storage import BaseKVStorage, BaseListStorage
+from graphgen.utils import load_json, logger, write_json
+
+
+@dataclass
+class JsonKVStorage(BaseKVStorage):
+    """In-memory key-value storage backed by ``<working_dir>/<namespace>.json``."""
+
+    _data: dict[str, str] = None
+
+    def __post_init__(self):
+        # Start from the persisted content, or an empty mapping when
+        # load_json returns a falsy value.
+        self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
+        loaded = load_json(self._file_name)
+        self._data = loaded or {}
+        logger.info("Load KV %s with %d data", self.namespace, len(self._data))
+
+    @property
+    def data(self):
+        """The underlying mapping (not a copy)."""
+        return self._data
+
+    async def all_keys(self) -> list[str]:
+        """Return every stored key."""
+        return list(self._data.keys())
+
+    async def index_done_callback(self):
+        """Write the in-memory mapping to the JSON file."""
+        write_json(self._data, self._file_name)
+
+    async def get_by_id(self, id):
+        """Return the value stored under ``id``, or None if absent."""
+        return self._data.get(id)
+
+    async def get_by_ids(self, ids, fields=None) -> list:
+        """Return one entry per id, in order.
+
+        Without ``fields`` the stored values are returned as-is (None when
+        missing). With ``fields`` each truthy value is reduced to the listed
+        keys; missing or falsy values yield None.
+        """
+        if fields is None:
+            return [self._data.get(key) for key in ids]
+
+        selected = []
+        for key in ids:
+            record = self._data.get(key)
+            if record:
+                selected.append(
+                    {name: value for name, value in record.items() if name in fields}
+                )
+            else:
+                selected.append(None)
+        return selected
+
+    async def filter_keys(self, data: list[str]) -> set[str]:
+        """Return the keys from ``data`` that are not stored yet."""
+        missing = set()
+        for key in data:
+            if key not in self._data:
+                missing.add(key)
+        return missing
+
+    async def upsert(self, data: dict):
+        """Insert only the pairs whose key is not stored yet and return them.
+
+        Existing keys are left untouched.
+        """
+        fresh = {}
+        for key, value in data.items():
+            if key not in self._data:
+                fresh[key] = value
+        self._data.update(fresh)
+        return fresh
+
+    async def drop(self):
+        """Replace the in-memory mapping with an empty one."""
+        self._data = {}
+
+
+@dataclass
+class JsonListStorage(BaseListStorage):
+    """In-memory list storage backed by ``<working_dir>/<namespace>.json``."""
+
+    _data: list = None
+
+    def __post_init__(self):
+        # Start from the persisted content, or an empty list when load_json
+        # returns a falsy value.
+        self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
+        loaded = load_json(self._file_name)
+        self._data = loaded or []
+        logger.info("Load List %s with %d data", self.namespace, len(self._data))
+
+    @property
+    def data(self):
+        """The underlying list (not a copy)."""
+        return self._data
+
+    async def all_items(self) -> list:
+        """Return the underlying list of items."""
+        return self._data
+
+    async def index_done_callback(self):
+        """Write the in-memory list to the JSON file."""
+        write_json(self._data, self._file_name)
+
+    async def get_by_index(self, index: int):
+        """Return the item at ``index``; None for negative or out-of-range indexes."""
+        if 0 <= index < len(self._data):
+            return self._data[index]
+        return None
+
+    async def append(self, data):
+        """Append a single item."""
+        self._data.append(data)
+
+    async def upsert(self, data: list):
+        """Append the items not already stored and return them.
+
+        Membership is checked against the list as it was before this call.
+        """
+        fresh = []
+        for item in data:
+            if item not in self._data:
+                fresh.append(item)
+        self._data.extend(fresh)
+        return fresh
+
+    async def drop(self):
+        """Replace the in-memory list with an empty one."""
+        self._data = []
diff --git a/hf-repo/graphgen/models/storage/networkx_storage.py b/hf-repo/graphgen/models/storage/networkx_storage.py
new file mode 100644
index 0000000000000000000000000000000000000000..92643760708d6c62c86896baee8b4d3d7c9fe3e8
--- /dev/null
+++ b/hf-repo/graphgen/models/storage/networkx_storage.py
@@ -0,0 +1,159 @@
+import os
+import html
+from typing import Any, Union, cast, Optional
+from dataclasses import dataclass
+import networkx as nx
+
+from graphgen.utils import logger
+from .base_storage import BaseGraphStorage
+
+@dataclass
+class NetworkXStorage(BaseGraphStorage):
+    """Graph storage backed by an in-memory networkx graph persisted as GraphML."""
+
+    @staticmethod
+    def load_nx_graph(file_name) -> Optional[nx.Graph]:
+        """Read a GraphML file, or return None when the file does not exist."""
+        if os.path.exists(file_name):
+            return nx.read_graphml(file_name)
+        return None
+
+    @staticmethod
+    def write_nx_graph(graph: nx.Graph, file_name):
+        """Write the graph to ``file_name`` in GraphML format."""
+        logger.info("Writing graph with %d nodes, %d edges", graph.number_of_nodes(), graph.number_of_edges())
+        nx.write_graphml(graph, file_name)
+
+    @staticmethod
+    def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
+        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
+        Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
+        """
+        # Imported lazily so graspologic is only needed when this helper is called.
+        from graspologic.utils import largest_connected_component
+
+        graph = graph.copy()
+        graph = cast(nx.Graph, largest_connected_component(graph))
+        # Normalise node names: upper-case, trimmed, HTML entities unescaped.
+        node_mapping = {
+            node: html.unescape(node.upper().strip()) for node in graph.nodes()
+        }  # type: ignore
+        graph = nx.relabel_nodes(graph, node_mapping)
+        return NetworkXStorage._stabilize_graph(graph)
+
+    @staticmethod
+    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
+        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
+        Ensure an undirected graph with the same relationships will always be read the same way.
+        This is achieved by sorting the nodes and edges.
+        """
+        fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
+
+        sorted_nodes = graph.nodes(data=True)
+        sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
+
+        fixed_graph.add_nodes_from(sorted_nodes)
+        edges = list(graph.edges(data=True))
+
+        if not graph.is_directed():
+
+            def _sort_source_target(edge):
+                # Orient each undirected edge so that source <= target.
+                source, target, edge_data = edge
+                if source > target:
+                    source, target = target, source
+                return source, target, edge_data
+
+            edges = [_sort_source_target(edge) for edge in edges]
+
+        def _get_edge_key(source: Any, target: Any) -> str:
+            return f"{source} -> {target}"
+
+        edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
+
+        fixed_graph.add_edges_from(edges)
+        return fixed_graph
+
+    def __post_init__(self):
+        """
+        Load the graph file if it exists; otherwise create a new graph.
+        """
+        self._graphml_xml_file = os.path.join(
+            self.working_dir, f"{self.namespace}.graphml"
+        )
+        preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
+        if preloaded_graph is not None:
+            logger.info(
+                "Loaded graph from %s with %d nodes, %d edges", self._graphml_xml_file,
+                preloaded_graph.number_of_nodes(), preloaded_graph.number_of_edges()
+            )
+        # Compare against None explicitly: a graph with zero nodes is falsy, so
+        # `preloaded_graph or nx.Graph()` would discard a loaded empty graph.
+        self._graph = preloaded_graph if preloaded_graph is not None else nx.Graph()
+
+    async def index_done_callback(self):
+        """Persist the current graph to the GraphML file."""
+        NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
+
+    async def has_node(self, node_id: str) -> bool:
+        """Return whether ``node_id`` is in the graph."""
+        return self._graph.has_node(node_id)
+
+    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
+        """Return whether the graph has an edge between the two nodes."""
+        return self._graph.has_edge(source_node_id, target_node_id)
+
+    async def get_node(self, node_id: str) -> Union[dict, None]:
+        """Return the attribute dict of ``node_id``, or None if the node is absent."""
+        return self._graph.nodes.get(node_id)
+
+    async def get_all_nodes(self) -> Union[list[dict], None]:
+        """Return the networkx view of ``(node_id, data)`` pairs."""
+        return self._graph.nodes(data=True)
+
+    async def node_degree(self, node_id: str) -> int:
+        """Return the degree of ``node_id``."""
+        return self._graph.degree(node_id)
+
+    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
+        """Return the sum of the degrees of both endpoints."""
+        return self._graph.degree(src_id) + self._graph.degree(tgt_id)
+
+    async def get_edge(
+        self, source_node_id: str, target_node_id: str
+    ) -> Union[dict, None]:
+        """Return the attribute dict of the edge, or None if the edge is absent."""
+        return self._graph.edges.get((source_node_id, target_node_id))
+
+    async def get_all_edges(self) -> Union[list[dict], None]:
+        """Return the networkx view of ``(source, target, data)`` triples."""
+        return self._graph.edges(data=True)
+
+    async def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
+        """Return the edges of ``source_node_id`` including data, or None if the node is absent."""
+        if self._graph.has_node(source_node_id):
+            return list(self._graph.edges(source_node_id, data=True))
+        return None
+
+    async def get_graph(self) -> nx.Graph:
+        """Return the underlying networkx graph (not a copy)."""
+        return self._graph
+
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
+        """Add the node, or merge ``node_data`` into it if it already exists."""
+        self._graph.add_node(node_id, **node_data)
+
+    async def update_node(self, node_id: str, node_data: dict[str, str]):
+        """Merge ``node_data`` into an existing node; log a warning if it is absent."""
+        if self._graph.has_node(node_id):
+            self._graph.nodes[node_id].update(node_data)
+        else:
+            logger.warning("Node %s not found in the graph for update.", node_id)
+
+    async def upsert_edge(
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ):
+        """Add the edge, or merge ``edge_data`` into it if it already exists."""
+        self._graph.add_edge(source_node_id, target_node_id, **edge_data)
+
+    async def update_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
+        """Merge ``edge_data`` into an existing edge; log a warning if it is absent."""
+        if self._graph.has_edge(source_node_id, target_node_id):
+            self._graph.edges[(source_node_id, target_node_id)].update(edge_data)
+        else:
+            logger.warning("Edge %s -> %s not found in the graph for update.", source_node_id, target_node_id)
+
+    async def delete_node(self, node_id: str):
+        """
+        Delete a node from the graph based on the specified node_id.
+
+        :param node_id: The node_id to delete
+        """
+        if self._graph.has_node(node_id):
+            self._graph.remove_node(node_id)
+            logger.info("Node %s deleted from the graph.", node_id)
+        else:
+            logger.warning("Node %s not found in the graph for deletion.", node_id)
+
+    async def clear(self):
+        """
+        Clear the graph by removing all nodes and edges.
+        """
+        self._graph.clear()
+        logger.info("Graph %s cleared.", self.namespace)
diff --git a/hf-repo/graphgen/models/strategy/__init__.py b/hf-repo/graphgen/models/strategy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hf-repo/graphgen/models/strategy/base_strategy.py b/hf-repo/graphgen/models/strategy/base_strategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..70e0cc54af1b152d28f226c697e5a805a57018ea
--- /dev/null
+++ b/hf-repo/graphgen/models/strategy/base_strategy.py
@@ -0,0 +1,5 @@
+from dataclasses import dataclass
+
+@dataclass
+class BaseStrategy:
+    """Empty dataclass serving as the common base for strategy dataclasses."""
+    pass
diff --git a/hf-repo/graphgen/models/strategy/travserse_strategy.py b/hf-repo/graphgen/models/strategy/travserse_strategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..06882c5f882d1bb152cc22625b185494100c2fc3
--- /dev/null
+++ b/hf-repo/graphgen/models/strategy/travserse_strategy.py
@@ -0,0 +1,30 @@
+from dataclasses import dataclass, fields
+
+from graphgen.models.strategy.base_strategy import BaseStrategy
+
+
+@dataclass
+class TraverseStrategy(BaseStrategy):
+    """Configuration values controlling graph traversal."""
+
+    # Form of the generated QA: atomic, multi-hop or aggregated
+    qa_form: str = "atomic"  # "atomic" or "multi_hop" or "aggregated"
+    # Only one of the two methods (max edge count or max token count) takes effect
+    expand_method: str = "max_tokens"  # "max_width" or "max_tokens"
+    # Whether to expand in one direction or in both directions
+    bidirectional: bool = True
+    # Maximum number of edges to expand in each direction
+    max_extra_edges: int = 5
+    # Maximum number of tokens
+    max_tokens: int = 256
+    # Maximum expansion depth in each direction
+    max_depth: int = 2
+    # Edge selection strategy within one layer (for bidirectional expansion,
+    # one layer means the set of edges connected on both sides)
+    edge_sampling: str = "max_loss"  # "max_loss" or "min_loss" or "random"
+    # Strategy for handling isolated nodes
+    isolated_node_strategy: str = "add"  # "add" or "ignore"
+    loss_strategy: str = "only_edge"  # only_edge, both
+
+    def to_yaml(self):
+        """Return all fields as a dict nested under the "traverse_strategy" key."""
+        return {
+            "traverse_strategy": {
+                spec.name: getattr(self, spec.name) for spec in fields(self)
+            }
+        }
diff --git a/hf-repo/graphgen/models/text/__init__.py b/hf-repo/graphgen/models/text/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hf-repo/graphgen/models/text/chunk.py b/hf-repo/graphgen/models/text/chunk.py
new file mode 100644
index 0000000000000000000000000000000000000000..9678949fe170d9a1588f2b0911701c703a062b55
--- /dev/null
+++ b/hf-repo/graphgen/models/text/chunk.py
@@ -0,0 +1,7 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class Chunk:
+    """A piece of text together with its identifier."""
+
+    id: str
+    content: str
diff --git a/hf-repo/graphgen/models/text/text_pair.py b/hf-repo/graphgen/models/text/text_pair.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a971f1ce9fc2d1ecdb82f61eae3711d8b76ebc
--- /dev/null
+++ b/hf-repo/graphgen/models/text/text_pair.py
@@ -0,0 +1,9 @@
+from dataclasses import dataclass
+
+@dataclass
+class TextPair:
+    """
+    A pair of input data: a question and its answer, both plain strings.
+    """
+    question: str
+    answer: str
diff --git a/hf-repo/graphgen/operators/__init__.py b/hf-repo/graphgen/operators/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f74e013a0ebbded2e2d8f94d0e97249bf4744b6c
--- /dev/null
+++ b/hf-repo/graphgen/operators/__init__.py
@@ -0,0 +1,22 @@
+from graphgen.operators.generate.generate_cot import generate_cot
+from graphgen.operators.kg.extract_kg import extract_kg
+from graphgen.operators.search.search_all import search_all
+
+from .judge import judge_statement
+from .quiz import quiz
+from .traverse_graph import (
+ traverse_graph_atomically,
+ traverse_graph_by_edge,
+ traverse_graph_for_multi_hop,
+)
+
+__all__ = [
+ "extract_kg",
+ "quiz",
+ "judge_statement",
+ "search_all",
+ "traverse_graph_by_edge",
+ "traverse_graph_atomically",
+ "traverse_graph_for_multi_hop",
+ "generate_cot",
+]
diff --git a/hf-repo/graphgen/operators/judge.py b/hf-repo/graphgen/operators/judge.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e9d33ebdd88936d06fdb69d08e52611a1fb647
--- /dev/null
+++ b/hf-repo/graphgen/operators/judge.py
@@ -0,0 +1,149 @@
+import asyncio
+import math
+
+from tqdm.asyncio import tqdm as tqdm_async
+
+from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIModel
+from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
+from graphgen.utils import logger, yes_no_loss_entropy
+
+
+async def judge_statement(  # pylint: disable=too-many-statements
+    trainee_llm_client: OpenAIModel,
+    graph_storage: NetworkXStorage,
+    rephrase_storage: JsonKVStorage,
+    re_judge: bool = False,
+    max_concurrent: int = 1000,
+) -> NetworkXStorage:
+    """
+    Get all edges and nodes and judge them
+
+    For every edge and node, the statements stored in ``rephrase_storage``
+    under its description are judged by the trainee model and the resulting
+    loss is written back to the graph as the ``loss`` attribute. All edges are
+    processed before any node.
+
+    :param trainee_llm_client: judge the statements to get comprehension loss
+    :param graph_storage: graph storage instance
+    :param rephrase_storage: rephrase storage instance
+    :param re_judge: re-judge the relations
+    :param max_concurrent: max concurrent
+    :return: the same graph storage instance, with losses updated
+    """
+
+    semaphore = asyncio.Semaphore(max_concurrent)
+    # Fallback used when judging fails: the loss of a statement answered
+    # correctly with probability 0.1.
+    default_loss = -math.log(0.1)
+
+    async def _compute_loss(description: str) -> float:
+        """Judge every stored statement for ``description`` and return the loss."""
+        statements = await rephrase_storage.get_by_id(description)
+        assert statements is not None
+
+        judgements = []
+        gts = [gt for _, gt in statements]
+        # Use a dedicated loop variable so the caller's description is not
+        # shadowed (it is logged afterwards).
+        for statement, _ in statements:
+            judgement = await trainee_llm_client.generate_topk_per_token(
+                STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=statement)
+            )
+            judgements.append(judgement[0].top_candidates)
+
+        return yes_no_loss_entropy(judgements, gts)
+
+    async def _judge_single_relation(
+        edge: tuple,
+    ):
+        async with semaphore:
+            source_id = edge[0]
+            target_id = edge[1]
+            edge_data = edge[2]
+
+            if (not re_judge) and edge_data.get("loss") is not None:
+                logger.info(
+                    "Edge %s -> %s already judged, loss: %s, skip",
+                    source_id,
+                    target_id,
+                    edge_data["loss"],
+                )
+                return source_id, target_id, edge_data
+
+            description = edge_data["description"]
+
+            try:
+                loss = await _compute_loss(description)
+
+                logger.info(
+                    "Edge %s -> %s description: %s loss: %s",
+                    source_id,
+                    target_id,
+                    description,
+                    loss,
+                )
+
+                edge_data["loss"] = loss
+            except Exception as e:  # pylint: disable=broad-except
+                logger.error(
+                    "Error in judging relation %s -> %s: %s", source_id, target_id, e
+                )
+                logger.info("Use default loss %s", default_loss)
+                edge_data["loss"] = default_loss
+
+            await graph_storage.update_edge(source_id, target_id, edge_data)
+            return source_id, target_id, edge_data
+
+    edges = await graph_storage.get_all_edges()
+
+    for result in tqdm_async(
+        asyncio.as_completed([_judge_single_relation(edge) for edge in edges]),
+        total=len(edges),
+        desc="Judging relations",
+    ):
+        await result
+
+    async def _judge_single_entity(
+        node: tuple,
+    ):
+        async with semaphore:
+            node_id = node[0]
+            node_data = node[1]
+
+            if (not re_judge) and node_data.get("loss") is not None:
+                logger.info(
+                    "Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
+                )
+                return node_id, node_data
+
+            description = node_data["description"]
+
+            try:
+                loss = await _compute_loss(description)
+
+                logger.info(
+                    "Node %s description: %s loss: %s", node_id, description, loss
+                )
+
+                node_data["loss"] = loss
+            except Exception as e:  # pylint: disable=broad-except
+                logger.error("Error in judging entity %s: %s", node_id, e)
+                logger.info("Use default loss %s", default_loss)
+                node_data["loss"] = default_loss
+
+            await graph_storage.update_node(node_id, node_data)
+            return node_id, node_data
+
+    nodes = await graph_storage.get_all_nodes()
+
+    for result in tqdm_async(
+        asyncio.as_completed([_judge_single_entity(node) for node in nodes]),
+        total=len(nodes),
+        desc="Judging entities",
+    ):
+        await result
+
+    return graph_storage
diff --git a/hf-repo/graphgen/operators/quiz.py b/hf-repo/graphgen/operators/quiz.py
new file mode 100644
index 0000000000000000000000000000000000000000..36edddb100c1ccca7b764a199d48e539fec3b757
--- /dev/null
+++ b/hf-repo/graphgen/operators/quiz.py
@@ -0,0 +1,109 @@
+import asyncio
+from collections import defaultdict
+
+from tqdm.asyncio import tqdm as tqdm_async
+from graphgen.models import JsonKVStorage, OpenAIModel, NetworkXStorage
+from graphgen.utils import logger, detect_main_language
+from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT
+
+
+async def quiz(
+        synth_llm_client: OpenAIModel,
+        graph_storage: NetworkXStorage,
+        rephrase_storage: JsonKVStorage,
+        max_samples: int = 1,
+        max_concurrent: int = 1000) -> JsonKVStorage:
+    """
+    Get all edges and nodes and quiz their descriptions
+
+    For each description the original text is recorded with ground truth
+    'yes'. Every sample adds an ANTI_TEMPLATE rephrasing with ground truth
+    'no'; every sample after the first also adds a TEMPLATE rephrasing with
+    ground truth 'yes'. The collected statements are upserted into
+    ``rephrase_storage`` keyed by the original description.
+
+    :param synth_llm_client: generate statements
+    :param graph_storage: graph storage instance
+    :param rephrase_storage: rephrase storage instance
+    :param max_samples: max samples for each edge
+    :param max_concurrent: max concurrent
+    :return: the same rephrase storage instance
+    """
+
+    semaphore = asyncio.Semaphore(max_concurrent)
+
+    async def _process_single_quiz(
+        des: str,
+        prompt: str,
+        gt: str
+    ):
+        async with semaphore:
+            try:
+                # Skip descriptions that already exist in rephrase_storage.
+                descriptions = await rephrase_storage.get_by_id(des)
+                if descriptions:
+                    return None
+
+                new_description = await synth_llm_client.generate_answer(
+                    prompt,
+                    temperature=1
+                )
+                return {des: [(new_description, gt)]}
+
+            except Exception as e:  # pylint: disable=broad-except
+                logger.error("Error when quizzing description %s: %s", des, e)
+                return None
+
+    edges = await graph_storage.get_all_edges()
+    nodes = await graph_storage.get_all_nodes()
+
+    results = defaultdict(list)
+    tasks = []
+
+    def _schedule(description: str):
+        """Record the original statement and queue the rephrasing tasks for it."""
+        language = "English" if detect_main_language(description) == "en" else "Chinese"
+
+        results[description] = [(description, 'yes')]
+
+        for i in range(max_samples):
+            if i > 0:
+                tasks.append(
+                    _process_single_quiz(description,
+                                         DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format(
+                                             input_sentence=description), 'yes')
+                )
+            tasks.append(_process_single_quiz(description,
+                                              DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format(
+                                                  input_sentence=description), 'no'))
+
+    for edge in edges:
+        _schedule(edge[2]["description"])
+
+    for node in nodes:
+        _schedule(node[1]["description"])
+
+    for result in tqdm_async(
+        asyncio.as_completed(tasks),
+        total=len(tasks),
+        desc="Quizzing descriptions"
+    ):
+        new_result = await result
+        if new_result:
+            for key, value in new_result.items():
+                results[key].extend(value)
+
+    # Deduplicate while preserving insertion order so persisted data is
+    # deterministic (a set would reorder statements between runs), and build
+    # a new mapping instead of reassigning entries of the dict being iterated.
+    deduplicated = {key: list(dict.fromkeys(value)) for key, value in results.items()}
+    for key, value in deduplicated.items():
+        await rephrase_storage.upsert({key: value})
+
+    return rephrase_storage
diff --git a/hf-repo/graphgen/operators/traverse_graph.py b/hf-repo/graphgen/operators/traverse_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..da1b668544806ee2faffb49c97fd72d90a74935f
--- /dev/null
+++ b/hf-repo/graphgen/operators/traverse_graph.py
@@ -0,0 +1,536 @@
+import asyncio
+
+import gradio as gr
+from tqdm.asyncio import tqdm as tqdm_async
+
+from graphgen.models import (
+ JsonKVStorage,
+ NetworkXStorage,
+ OpenAIModel,
+ Tokenizer,
+ TraverseStrategy,
+)
+from graphgen.operators.kg.split_kg import get_batches_with_strategy
+from graphgen.templates import (
+ ANSWER_REPHRASING_PROMPT,
+ MULTI_HOP_GENERATION_PROMPT,
+ QUESTION_GENERATION_PROMPT,
+)
+from graphgen.utils import compute_content_hash, detect_main_language, logger
+
+
+async def _pre_tokenize(
+    graph_storage: NetworkXStorage, tokenizer: Tokenizer, edges: list, nodes: list
+) -> tuple:
+    """
+    Make sure every edge and node carries a cached ``length`` attribute.
+
+    An item whose data dict has no ``"length"`` key gets one, set to the number
+    of tokens in its ``"description"``. Every item, changed or not, is written
+    back through ``graph_storage.update_edge`` / ``update_node``, and
+    ``graph_storage.index_done_callback`` is awaited once at the end.
+
+    :param graph_storage: graph store that receives the updated items
+    :param tokenizer: object exposing ``encode_string(text)``
+    :param edges: ``(source, target, data)`` tuples
+    :param nodes: ``(node_id, data)`` tuples
+    :return: ``(edges, nodes)`` as lists in completion order, not input order
+    """
+
+    sem = asyncio.Semaphore(1000)
+    # We are inside a coroutine, so the running loop is the one to use;
+    # asyncio.get_event_loop() is the legacy spelling for this situation.
+    loop = asyncio.get_running_loop()
+
+    async def _ensure_length(data: dict) -> None:
+        """Tokenize ``data["description"]`` unless a length is already cached."""
+        async with sem:
+            if "length" not in data:
+                # encode_string is a blocking call, so run it off the event loop.
+                tokens = await loop.run_in_executor(
+                    None, tokenizer.encode_string, data["description"]
+                )
+                data["length"] = len(tokens)
+
+    async def handle_edge(edge: tuple) -> tuple:
+        await _ensure_length(edge[2])
+        return edge
+
+    async def handle_node(node: tuple) -> tuple:
+        await _ensure_length(node[1])
+        return node
+
+    new_edges = []
+    new_nodes = []
+
+    for result in tqdm_async(
+        asyncio.as_completed([handle_edge(edge) for edge in edges]),
+        total=len(edges),
+        desc="Pre-tokenizing edges",
+    ):
+        new_edge = await result
+        await graph_storage.update_edge(new_edge[0], new_edge[1], new_edge[2])
+        new_edges.append(new_edge)
+
+    for result in tqdm_async(
+        asyncio.as_completed([handle_node(node) for node in nodes]),
+        total=len(nodes),
+        desc="Pre-tokenizing nodes",
+    ):
+        new_node = await result
+        await graph_storage.update_node(new_node[0], new_node[1])
+        new_nodes.append(new_node)
+
+    await graph_storage.index_done_callback()
+    return new_edges, new_nodes
+
+
+async def _construct_rephrasing_prompt(
+    _process_nodes: list,
+    _process_edges: list,
+    text_chunks_storage: JsonKVStorage,
+    add_context: bool = False,
+) -> str:
+    """
+    Build the answer-rephrasing prompt for one batch of nodes and edges.
+
+    Entities and relations are rendered as numbered lists. The prompt language
+    is Chinese when ``detect_main_language`` returns ``"zh"`` for the combined
+    lists, English otherwise.
+
+    :param _process_nodes: node dicts with ``node_id`` and ``description``
+        (plus ``source_id`` when ``add_context`` is true)
+    :param _process_edges: ``(source, target, data)`` tuples whose data has
+        ``description`` (plus ``source_id`` when ``add_context`` is true)
+    :param text_chunks_storage: store queried with ``get_by_ids`` for the
+        original chunks; only used when ``add_context`` is true
+    :param add_context: when true, include the original chunk text and use the
+        ``CONTEXT_TEMPLATE`` variant
+    :return: the formatted prompt
+    """
+    entities = [
+        f"{_process_node['node_id']}: {_process_node['description']}"
+        for _process_node in _process_nodes
+    ]
+    relations = [
+        f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
+        for _process_edge in _process_edges
+    ]
+
+    entities_str = "\n".join(
+        f"{index + 1}. {entity}" for index, entity in enumerate(entities)
+    )
+    relations_str = "\n".join(
+        f"{index + 1}. {relation}" for index, relation in enumerate(relations)
+    )
+    language = (
+        "Chinese"
+        if detect_main_language(entities_str + relations_str) == "zh"
+        else "English"
+    )
+
+    if not add_context:
+        return ANSWER_REPHRASING_PROMPT[language]["TEMPLATE"].format(
+            language=language, entities=entities_str, relationships=relations_str
+        )
+
+    # A source_id may hold several chunk ids joined by a separator; only the
+    # first id is looked up. The separator literal was empty here, and
+    # str.split("") always raises ValueError.
+    # NOTE(review): "<SEP>" is presumably the delimiter written by the KG merge
+    # step -- confirm against that code.
+    original_ids = [
+        node["source_id"].split("<SEP>")[0] for node in _process_nodes
+    ] + [edge[2]["source_id"].split("<SEP>")[0] for edge in _process_edges]
+
+    original_ids = list(set(original_ids))
+    original_text = await text_chunks_storage.get_by_ids(original_ids)
+    original_text = "\n".join(
+        f"{index + 1}. {text['content']}"
+        for index, text in enumerate(original_text)
+    )
+
+    return ANSWER_REPHRASING_PROMPT[language]["CONTEXT_TEMPLATE"].format(
+        language=language,
+        original_text=original_text,
+        entities=entities_str,
+        relationships=relations_str,
+    )
+
+
+def get_average_loss(batch: tuple, loss_strategy: str) -> float:
+    """
+    Average the ``loss`` values of a batch.
+
+    :param batch: ``(nodes, edges)``; nodes are dicts with a ``loss`` key,
+        edges are ``(source, target, data)`` tuples whose data has ``loss``
+    :param loss_strategy: ``"only_edge"`` averages over edges only,
+        ``"both"`` averages over nodes and edges together
+    :return: the mean loss, or ``-1.0`` when it cannot be computed (unknown
+        strategy, empty batch, missing ``loss`` key); the error is logged
+    """
+    try:
+        nodes, edges = batch[0], batch[1]
+        if loss_strategy == "only_edge":
+            return sum(edge[2]["loss"] for edge in edges) / len(edges)
+        if loss_strategy == "both":
+            # Sum first, then divide: previously only the node sum was divided
+            # by the item count because "/" binds tighter than "+".
+            total_loss = sum(edge[2]["loss"] for edge in edges) + sum(
+                node["loss"] for node in nodes
+            )
+            return total_loss / (len(nodes) + len(edges))
+        raise ValueError("Invalid loss strategy")
+    except Exception as e:  # pylint: disable=broad-except
+        logger.error("Error calculating average loss: %s", e)
+        return -1.0
+
+
+def _post_process_synthetic_data(data):
+    """
+    Extract question/answer pairs from generated text.
+
+    The text is split on blank lines. A chunk contributes one pair when it
+    contains both markers of a supported pair; marker pairs are tried in order
+    and the first match wins. Chunks without a complete pair are ignored.
+
+    :param data: raw generated text
+    :return: list of ``{"question": ..., "answer": ...}`` dicts
+    """
+    marker_pairs = (
+        ("Question:", "Answer:"),
+        ("问题:", "答案:"),
+        ("问题:", "回答:"),
+    )
+    qas = []
+    for chunk in data.split("\n\n"):
+        for question_marker, answer_marker in marker_pairs:
+            if question_marker in chunk and answer_marker in chunk:
+                after_question = chunk.split(question_marker)[1]
+                qas.append(
+                    {
+                        "question": after_question.split(answer_marker)[0].strip(),
+                        "answer": chunk.split(answer_marker)[1].strip(),
+                    }
+                )
+                break
+    return qas
+
+
+async def traverse_graph_by_edge(
+    llm_client: OpenAIModel,
+    tokenizer: Tokenizer,
+    graph_storage: NetworkXStorage,
+    traverse_strategy: TraverseStrategy,
+    text_chunks_storage: JsonKVStorage,
+    progress_bar: gr.Progress = None,
+    max_concurrent: int = 1000,
+) -> dict:
+    """
+    Generate question/answer pairs from batches of graph nodes and edges.
+
+    Each batch is rephrased into a passage by the LLM, and a question is then
+    generated for that passage.
+
+    :param llm_client: client whose ``generate_answer(prompt)`` is awaited
+    :param tokenizer: used to pre-compute token lengths of descriptions
+    :param graph_storage: graph store supplying nodes and edges
+    :param traverse_strategy: batching strategy; its ``loss_strategy`` selects
+        how the batch loss is averaged
+    :param text_chunks_storage: forwarded to the prompt builder (not read,
+        because context is disabled here)
+    :param progress_bar: optional callable ``progress_bar(fraction, desc=...)``
+    :param max_concurrent: maximum number of batches processed at once
+    :return: dict mapping a content hash to
+        ``{"question": ..., "answer": ..., "loss": ...}``
+    """
+
+    semaphore = asyncio.Semaphore(max_concurrent)
+
+    async def _process_nodes_and_edges(
+        _process_nodes: list,
+        _process_edges: list,
+    ) -> str:
+        """Ask the LLM to rephrase the batch and strip the echoed label."""
+        prompt = await _construct_rephrasing_prompt(
+            _process_nodes, _process_edges, text_chunks_storage, add_context=False
+        )
+        context = await llm_client.generate_answer(prompt)
+
+        # The prompt ends with a label; drop it if the model repeats it.
+        if context.startswith("Rephrased Text:"):
+            context = context[len("Rephrased Text:") :].strip()
+        elif context.startswith("重述文本:"):
+            context = context[len("重述文本:") :].strip()
+
+        return context
+
+    async def _process_single_batch(
+        _process_batch: tuple, question_type: str = "single"
+    ) -> dict:
+        """Turn one ``(nodes, edges)`` batch into one or more QA entries."""
+        async with semaphore:
+            context = await _process_nodes_and_edges(
+                _process_batch[0],
+                _process_batch[1],
+            )
+
+            language = "Chinese" if detect_main_language(context) == "zh" else "English"
+            pre_length = sum(node["length"] for node in _process_batch[0]) + sum(
+                edge[2]["length"] for edge in _process_batch[1]
+            )
+
+            if question_type == "single":
+                # One question whose answer is the whole rephrased passage.
+                question = await llm_client.generate_answer(
+                    QUESTION_GENERATION_PROMPT[language]["SINGLE_TEMPLATE"].format(
+                        answer=context
+                    )
+                )
+                if question.startswith("Question:"):
+                    question = question[len("Question:") :].strip()
+                elif question.startswith("问题:"):
+                    question = question[len("问题:") :].strip()
+
+                logger.info(
+                    "%d nodes and %d edges processed",
+                    len(_process_batch[0]),
+                    len(_process_batch[1]),
+                )
+                logger.info("Pre-length: %s", pre_length)
+                logger.info("Question: %s", question)
+                logger.info("Answer: %s", context)
+
+                return {
+                    compute_content_hash(context): {
+                        "question": question,
+                        "answer": context,
+                        "loss": get_average_loss(
+                            _process_batch, traverse_strategy.loss_strategy
+                        ),
+                    }
+                }
+
+            # Any other question_type: several QA pairs parsed from one reply.
+            content = await llm_client.generate_answer(
+                QUESTION_GENERATION_PROMPT[language]["MULTI_TEMPLATE"].format(
+                    doc=context
+                )
+            )
+            qas = _post_process_synthetic_data(content)
+
+            if len(qas) == 0:
+                # Keep the raw reply in the log rather than on stdout.
+                logger.error("Unparseable QA generation output: %s", content)
+                logger.error(
+                    "Error occurred while processing batch, question or answer is None"
+                )
+                return {}
+
+            final_results = {}
+            logger.info(
+                "%d nodes and %d edges processed",
+                len(_process_batch[0]),
+                len(_process_batch[1]),
+            )
+            logger.info("Pre-length: %s", pre_length)
+            for qa in qas:
+                logger.info("Question: %s", qa["question"])
+                logger.info("Answer: %s", qa["answer"])
+                final_results[compute_content_hash(qa["question"])] = {
+                    "question": qa["question"],
+                    "answer": qa["answer"],
+                    "loss": get_average_loss(
+                        _process_batch, traverse_strategy.loss_strategy
+                    ),
+                }
+            return final_results
+
+    results = {}
+    edges = list(await graph_storage.get_all_edges())
+    nodes = list(await graph_storage.get_all_nodes())
+
+    edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
+
+    processing_batches = await get_batches_with_strategy(
+        nodes, edges, graph_storage, traverse_strategy
+    )
+
+    total_batches = len(processing_batches)
+    # Progress counts finished batches, not len(results): a batch can yield
+    # zero or several entries, or fail, and the bar must still reach 100%.
+    for finished, result in enumerate(
+        tqdm_async(
+            asyncio.as_completed(
+                [_process_single_batch(batch) for batch in processing_batches]
+            ),
+            total=total_batches,
+            desc="[4/4]Generating QAs",
+        ),
+        start=1,
+    ):
+        try:
+            results.update(await result)
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error("Error occurred while generating QA: %s", e)
+        if progress_bar is not None:
+            progress_bar(finished / total_batches, desc="[4/4]Generating QAs")
+
+    return results
+
+
+async def traverse_graph_atomically(
+    llm_client: OpenAIModel,
+    tokenizer: Tokenizer,
+    graph_storage: NetworkXStorage,
+    traverse_strategy: TraverseStrategy,
+    text_chunks_storage: JsonKVStorage,
+    progress_bar: gr.Progress = None,
+    max_concurrent: int = 1000,
+) -> dict:
+    """
+    Traverse the graph atomically: one QA pair per single description.
+
+    Every node and edge description becomes its own generation task. A
+    description made of several parts joined by ``<SEP>`` is split so that each
+    part gets its own task, inheriting the item's ``loss``.
+
+    :param llm_client: client whose ``generate_answer(prompt)`` is awaited
+    :param tokenizer: used to pre-compute token lengths of descriptions
+    :param graph_storage: graph store supplying nodes and edges
+    :param traverse_strategy: must have ``qa_form == "atomic"``
+    :param text_chunks_storage: accepted for signature parity; not used here
+    :param progress_bar: optional callable ``progress_bar(fraction, desc=...)``
+    :param max_concurrent: maximum number of concurrent generation tasks
+    :return: dict mapping the question's content hash to
+        ``{"question": ..., "answer": ..., "loss": ...}``
+    """
+    assert traverse_strategy.qa_form == "atomic"
+
+    semaphore = asyncio.Semaphore(max_concurrent)
+
+    async def _generate_question(node_or_edge: tuple):
+        """Generate one QA pair for a node (2-tuple) or an edge (3-tuple)."""
+        if len(node_or_edge) == 2:
+            des = node_or_edge[0] + ": " + node_or_edge[1]["description"]
+            loss = node_or_edge[1]["loss"]
+        else:
+            des = node_or_edge[2]["description"]
+            loss = node_or_edge[2]["loss"]
+
+        async with semaphore:
+            try:
+                language = "Chinese" if detect_main_language(des) == "zh" else "English"
+
+                qa = await llm_client.generate_answer(
+                    QUESTION_GENERATION_PROMPT[language]["SINGLE_QA_TEMPLATE"].format(
+                        doc=des
+                    )
+                )
+
+                if "Question:" in qa and "Answer:" in qa:
+                    question = qa.split("Question:")[1].split("Answer:")[0].strip()
+                    answer = qa.split("Answer:")[1].strip()
+                elif "问题:" in qa and "答案:" in qa:
+                    question = qa.split("问题:")[1].split("答案:")[0].strip()
+                    answer = qa.split("答案:")[1].strip()
+                else:
+                    # Reply has no recognizable QA markers; drop it.
+                    return {}
+
+                question = question.strip('"')
+                answer = answer.strip('"')
+
+                logger.info("Question: %s", question)
+                logger.info("Answer: %s", answer)
+                return {
+                    compute_content_hash(question): {
+                        "question": question,
+                        "answer": answer,
+                        "loss": loss,
+                    }
+                }
+            except Exception as e:  # pylint: disable=broad-except
+                logger.error("Error occurred while generating question: %s", e)
+                return {}
+
+    results = {}
+    edges = list(await graph_storage.get_all_edges())
+    nodes = list(await graph_storage.get_all_nodes())
+
+    edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
+
+    # The separator literal was empty here: `"" in text` is always true and
+    # text.split("") raises ValueError, so no task list could ever be built.
+    # NOTE(review): "<SEP>" is presumably the delimiter written by the KG merge
+    # step -- confirm against that code.
+    tasks = []
+    for node in nodes:
+        if "<SEP>" in node[1]["description"]:
+            for item in node[1]["description"].split("<SEP>"):
+                tasks.append((node[0], {"description": item, "loss": node[1]["loss"]}))
+        else:
+            tasks.append((node[0], node[1]))
+    for edge in edges:
+        if "<SEP>" in edge[2]["description"]:
+            for item in edge[2]["description"].split("<SEP>"):
+                tasks.append(
+                    (edge[0], edge[1], {"description": item, "loss": edge[2]["loss"]})
+                )
+        else:
+            tasks.append((edge[0], edge[1], edge[2]))
+
+    total_tasks = len(tasks)
+    # Progress counts finished tasks, not len(results): a task may return
+    # nothing or collide on its hash, and the bar must still reach 100%.
+    for finished, result in enumerate(
+        tqdm_async(
+            asyncio.as_completed([_generate_question(task) for task in tasks]),
+            total=total_tasks,
+            desc="[4/4]Generating QAs",
+        ),
+        start=1,
+    ):
+        try:
+            results.update(await result)
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error("Error occurred while generating QA: %s", e)
+        if progress_bar is not None:
+            progress_bar(finished / total_tasks, desc="[4/4]Generating QAs")
+    return results
+
+
+async def traverse_graph_for_multi_hop(
+    llm_client: OpenAIModel,
+    tokenizer: Tokenizer,
+    graph_storage: NetworkXStorage,
+    traverse_strategy: TraverseStrategy,
+    text_chunks_storage: JsonKVStorage,
+    progress_bar: gr.Progress = None,
+    max_concurrent: int = 1000,
+) -> dict:
+    """
+    Traverse the graph and generate one multi-hop QA pair per batch.
+
+    :param llm_client: client whose ``generate_answer(prompt)`` is awaited
+    :param tokenizer: used to pre-compute token lengths of descriptions
+    :param graph_storage: graph store supplying nodes and edges
+    :param traverse_strategy: batching strategy; its ``loss_strategy`` selects
+        how the batch loss is averaged
+    :param text_chunks_storage: accepted for signature parity; not used here
+    :param progress_bar: optional callable ``progress_bar(fraction, desc=...)``
+    :param max_concurrent: maximum number of batches processed at once
+    :return: dict mapping the question's content hash to
+        ``{"question": ..., "answer": ..., "loss": ...}``
+    """
+    semaphore = asyncio.Semaphore(max_concurrent)
+
+    results = {}
+    edges = list(await graph_storage.get_all_edges())
+    nodes = list(await graph_storage.get_all_nodes())
+
+    edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
+
+    processing_batches = await get_batches_with_strategy(
+        nodes, edges, graph_storage, traverse_strategy
+    )
+
+    async def _process_single_batch(_process_batch: tuple) -> dict:
+        """Build the multi-hop prompt for one batch and parse the reply."""
+        async with semaphore:
+            try:
+                # Language is taken from the first node only. A batch with no
+                # nodes raises IndexError here, which is logged below and the
+                # batch is dropped.
+                language = (
+                    "Chinese"
+                    if detect_main_language(_process_batch[0][0]["description"]) == "zh"
+                    else "English"
+                )
+
+                _process_nodes = _process_batch[0]
+                _process_edges = _process_batch[1]
+
+                entities = [
+                    f"{_process_node['node_id']}: {_process_node['description']}"
+                    for _process_node in _process_nodes
+                ]
+
+                relations = [
+                    f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
+                    for _process_edge in _process_edges
+                ]
+
+                entities_str = "\n".join(
+                    f"{index + 1}. {entity}" for index, entity in enumerate(entities)
+                )
+                relations_str = "\n".join(
+                    f"{index + 1}. {relation}"
+                    for index, relation in enumerate(relations)
+                )
+
+                prompt = MULTI_HOP_GENERATION_PROMPT[language].format(
+                    entities=entities_str, relationships=relations_str
+                )
+
+                context = await llm_client.generate_answer(prompt)
+
+                # Split the reply into question and answer by its markers.
+                if "Question:" in context and "Answer:" in context:
+                    question = context.split("Question:")[1].split("Answer:")[0].strip()
+                    answer = context.split("Answer:")[1].strip()
+                elif "问题:" in context and "答案:" in context:
+                    question = context.split("问题:")[1].split("答案:")[0].strip()
+                    answer = context.split("答案:")[1].strip()
+                else:
+                    return {}
+
+                question = question.strip('"')
+                answer = answer.strip('"')
+
+                logger.info("Question: %s", question)
+                logger.info("Answer: %s", answer)
+
+                return {
+                    compute_content_hash(question): {
+                        "question": question,
+                        "answer": answer,
+                        "loss": get_average_loss(
+                            _process_batch, traverse_strategy.loss_strategy
+                        ),
+                    }
+                }
+
+            except Exception as e:  # pylint: disable=broad-except
+                logger.error("Error occurred while processing batch: %s", e)
+                return {}
+
+    total_batches = len(processing_batches)
+    # Plain `for`, matching the sibling traversals in this module. Progress
+    # counts finished batches, not len(results): a batch may return nothing or
+    # collide on its hash, and the bar must still reach 100%.
+    for finished, result in enumerate(
+        tqdm_async(
+            asyncio.as_completed(
+                [_process_single_batch(batch) for batch in processing_batches]
+            ),
+            total=total_batches,
+            desc="[4/4]Generating QAs",
+        ),
+        start=1,
+    ):
+        try:
+            results.update(await result)
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error("Error occurred while generating QA: %s", e)
+        if progress_bar is not None:
+            progress_bar(finished / total_batches, desc="[4/4]Generating QAs")
+    return results
diff --git a/hf-repo/graphgen/templates/__init__.py b/hf-repo/graphgen/templates/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3d1e9ed5dfd20f0f08cb6c39f40bb1794b80ca4
--- /dev/null
+++ b/hf-repo/graphgen/templates/__init__.py
@@ -0,0 +1,10 @@
+from .answer_rephrasing import ANSWER_REPHRASING_PROMPT
+from .community import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
+from .coreference_resolution import COREFERENCE_RESOLUTION_PROMPT
+from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT
+from .kg_extraction import KG_EXTRACTION_PROMPT
+from .kg_summarization import KG_SUMMARIZATION_PROMPT
+from .multi_hop_generation import MULTI_HOP_GENERATION_PROMPT
+from .question_generation import QUESTION_GENERATION_PROMPT
+from .search_judgement import SEARCH_JUDGEMENT_PROMPT
+from .statement_judgement import STATEMENT_JUDGEMENT_PROMPT
diff --git a/hf-repo/graphgen/templates/answer_rephrasing.py b/hf-repo/graphgen/templates/answer_rephrasing.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc988fa25edeedca98674268c9403c50f2ebb995
--- /dev/null
+++ b/hf-repo/graphgen/templates/answer_rephrasing.py
@@ -0,0 +1,213 @@
+TEMPLATE_CONTEXT_EN: str = """---Role---
+You are an NLP expert responsible for generating a logically structured and coherent rephrased version of the TEXT based on ENTITIES and RELATIONSHIPS provided below. You may refer to the original text to assist in generating the rephrased version, but ensure that the final output text meets the requirements.
+Use {language} as output language.
+
+---Goal---
+To generate a version of the text that is rephrased and conveys the same meaning as the original entity and relationship descriptions, while:
+1. Following a clear logical flow and structure
+2. Establishing proper cause-and-effect relationships
+3. Ensuring temporal and sequential consistency
+4. Creating smooth transitions between ideas using conjunctions and appropriate linking words like "firstly," "however," "therefore," etc.
+
+---Instructions---
+1. Analyze the provided ENTITIES and RELATIONSHIPS carefully to identify:
+ - Key concepts and their hierarchies
+ - Temporal sequences and chronological order
+ - Cause-and-effect relationships
+ - Dependencies between different elements
+
+2. Organize the information in a logical sequence by:
+ - Starting with foundational concepts
+ - Building up to more complex relationships
+ - Grouping related ideas together
+ - Creating clear transitions between sections
+
+3. Rephrase the text while maintaining:
+ - Logical flow and progression
+ - Clear connections between ideas
+ - Proper context and background
+ - Coherent narrative structure
+
+4. Review and refine the text to ensure:
+ - Logical consistency throughout
+ - Clear cause-and-effect relationships
+
+################
+-ORIGINAL TEXT-
+################
+{original_text}
+
+################
+-ENTITIES-
+################
+{entities}
+
+################
+-RELATIONSHIPS-
+################
+{relationships}
+
+"""
+
+TEMPLATE_CONTEXT_ZH: str = """---角色---
+你是一位NLP专家,负责根据下面提供的实体和关系生成逻辑结构清晰且连贯的文本重述版本。你可以参考原始文本辅助生成,但需要确保最终输出的文本符合要求。
+使用{language}作为输出语言。
+
+---目标---
+生成文本的重述版本,使其传达与原始实体和关系描述相同的含义,同时:
+1. 遵循清晰的逻辑流和结构
+2. 建立适当的因果关系
+3. 确保时间和顺序的一致性
+4. 使用连词和适当的连接词(如"首先"、"然而"、"因此"等)创造流畅的过渡
+
+---说明---
+1. 仔细分析提供的实体和关系,以识别:
+ - 关键概念及其层级关系
+ - 时间序列和时间顺序
+ - 因果关系
+ - 不同元素之间的依赖关系
+2. 通过以下方式将信息组织成逻辑顺序:
+ - 从基础概念开始
+ - 逐步建立更复杂的关系
+ - 将相关的想法分组在一起
+ - 在各部分之间创建清晰的过渡
+3. 重述文本时保持:
+ - 逻辑流畅
+ - 概念之间的清晰联系
+ - 适当的上下文和背景
+ - 连贯的叙述结构
+4. 检查和完善文本以确保:
+ - 整体逻辑一致性
+ - 清晰的因果关系
+
+################
+-原始文本-
+################
+{original_text}
+
+################
+-实体-
+################
+{entities}
+
+################
+-关系-
+################
+{relationships}
+
+"""
+
+TEMPLATE_EN: str = """---Role---
+You are an NLP expert responsible for generating a logically structured and coherent rephrased version of the TEXT based on ENTITIES and RELATIONSHIPS provided below.
+Use {language} as output language.
+
+---Goal---
+To generate a version of the text that is rephrased and conveys the same meaning as the original entity and relationship descriptions, while:
+1. Following a clear logical flow and structure
+2. Establishing proper cause-and-effect relationships
+3. Ensuring temporal and sequential consistency
+4. Creating smooth transitions between ideas using conjunctions and appropriate linking words like "firstly," "however," "therefore," etc.
+
+---Instructions---
+1. Analyze the provided ENTITIES and RELATIONSHIPS carefully to identify:
+ - Key concepts and their hierarchies
+ - Temporal sequences and chronological order
+ - Cause-and-effect relationships
+ - Dependencies between different elements
+
+2. Organize the information in a logical sequence by:
+ - Starting with foundational concepts
+ - Building up to more complex relationships
+ - Grouping related ideas together
+ - Creating clear transitions between sections
+
+3. Rephrase the text while maintaining:
+ - Logical flow and progression
+ - Clear connections between ideas
+ - Proper context and background
+ - Coherent narrative structure
+
+4. Review and refine the text to ensure:
+ - Logical consistency throughout
+ - Clear cause-and-effect relationships
+
+################
+-ENTITIES-
+################
+{entities}
+
+################
+-RELATIONSHIPS-
+################
+{relationships}
+
+"""
+
+TEMPLATE_ZH: str = """---角色---
+你是一位NLP专家,负责根据下面提供的实体和关系生成逻辑结构清晰且连贯的文本重述版本。
+使用{language}作为输出语言。
+
+---目标---
+生成文本的重述版本,使其传达与原始实体和关系描述相同的含义,同时:
+1. 遵循清晰的逻辑流和结构
+2. 建立适当的因果关系
+3. 确保时间和顺序的一致性
+4. 使用连词和适当的连接词(如"首先"、"然而"、"因此"等)创造流畅的过渡
+
+---说明---
+1. 仔细分析提供的实体和关系,以识别:
+ - 关键概念及其层级关系
+ - 时间序列和时间顺序
+ - 因果关系
+ - 不同元素之间的依赖关系
+2. 通过以下方式将信息组织成逻辑顺序:
+ - 从基础概念开始
+ - 逐步建立更复杂的关系
+ - 将相关的想法分组在一起
+ - 在各部分之间创建清晰的过渡
+3. 重述文本时保持:
+ - 逻辑流畅
+ - 概念之间的清晰联系
+ - 适当的上下文和背景
+ - 连贯的叙述结构
+4. 检查和完善文本以确保:
+ - 整体逻辑一致性
+ - 清晰的因果关系
+
+################
+-实体-
+################
+{entities}
+
+################
+-关系-
+################
+{relationships}
+
+"""
+
+REQUIREMENT_ZH = """
+################
+请在下方直接输出连贯的重述文本,不要输出任何额外的内容。
+
+重述文本:
+"""
+
+REQUIREMENT_EN = """
+################
+Please directly output the coherent rephrased text below, without any additional content.
+
+Rephrased Text:
+"""
+
+
+# Prompt table: language name -> template variant -> prompt text.
+# Each entry is a template body followed by that language's output requirement.
+# "TEMPLATE" takes {language}, {entities} and {relationships};
+# "CONTEXT_TEMPLATE" additionally takes {original_text}.
+ANSWER_REPHRASING_PROMPT = {
+ "English": {
+ "TEMPLATE": TEMPLATE_EN + REQUIREMENT_EN,
+ "CONTEXT_TEMPLATE": TEMPLATE_CONTEXT_EN + REQUIREMENT_EN,
+ },
+ "Chinese": {
+ "TEMPLATE": TEMPLATE_ZH + REQUIREMENT_ZH,
+ "CONTEXT_TEMPLATE": TEMPLATE_CONTEXT_ZH + REQUIREMENT_ZH,
+ },
+}
diff --git a/hf-repo/graphgen/templates/coreference_resolution.py b/hf-repo/graphgen/templates/coreference_resolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc03e671411b6c1abd9b75ea9c1dbb976d7d331e
--- /dev/null
+++ b/hf-repo/graphgen/templates/coreference_resolution.py
@@ -0,0 +1,36 @@
+TEMPLATE_ZH: str = """请根据参考文本识别并消解文本中的指代词,明确每个代词所指代的具体实体,并直接输出消解后的文本。
+
+-示例-
+输入:
+小明和小红一起去公园。她们玩得很开心。之后,他们去吃冰淇淋。
+输出:
+小明和小红一起去公园。小明和小红玩得很开心。之后,小明和小红去吃冰淇淋。
+
+-真实数据-
+参考文本:
+{reference}
+输入:
+{input_sentence}
+请直接输出改写后的句子,不要输出任何额外信息。
+输出:
+"""
+
+TEMPLATE_EN: str = """Based on the reference text, please identify and resolve the pronouns in the input text, \
+specify the specific entities referred to by each pronoun, and directly output the resolved text.
+
+-Example-
+Input:
+John and Mary went to the park. They had a great time. Later, they went to eat ice cream.
+Output:
+John and Mary went to the park. John and Mary had a great time. Later, John and Mary went to eat ice cream.
+
+-Real Data-
+Reference text:
+{reference}
+Input:
+{input_sentence}
+Please directly output the rewritten sentence without any additional information.
+Output:
+"""
+
+# Prompt table keyed by language code. Both templates take {reference} and {input_sentence}.
+# NOTE(review): keys are "en"/"zh", while ANSWER_REPHRASING_PROMPT and
+# DESCRIPTION_REPHRASING_PROMPT are keyed by "English"/"Chinese" -- callers must use the matching form.
+COREFERENCE_RESOLUTION_PROMPT = {"en": TEMPLATE_EN, "zh": TEMPLATE_ZH}
diff --git a/hf-repo/graphgen/templates/description_rephrasing.py b/hf-repo/graphgen/templates/description_rephrasing.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0e38012347f8d24237afda47497faed34ad5e67
--- /dev/null
+++ b/hf-repo/graphgen/templates/description_rephrasing.py
@@ -0,0 +1,121 @@
+ANTI_TEMPLATE_EN: str = """-Goal-
+Transform the input sentence into its opposite meaning while:
+
+1. Preserving most of the original sentence structure
+2. Changing only key words that affect the core meaning
+3. Maintaining the same tone and style
+4. The input sentence provided is a right description, and the output sentence should be a wrong description
+5. The output sentence should be fluent and grammatically correct
+
+################
+-Examples-
+################
+Input:
+The bright sunshine made everyone feel energetic and happy.
+
+Output:
+The bright sunshine made everyone feel tired and sad.
+
+################
+-Real Data-
+################
+Input:
+{input_sentence}
+################
+Please directly output the rewritten sentence without any additional information.
+Output:
+"""
+
+ANTI_TEMPLATE_ZH: str = """-目标-
+将输入句子转换为相反含义的句子,同时:
+
+1. 保留大部分原始句子结构
+2. 仅更改影响核心含义的关键词
+3. 保持相同的语气和风格
+4. 提供的输入句子是一个正确的描述,输出句子应该是一个错误的描述
+5. 输出句子应该流畅且语法正确
+
+################
+-示例-
+################
+输入:
+明亮的阳光让每个人都感到充满活力和快乐。
+
+输出:
+明亮的阳光让每个人都感到疲惫和悲伤。
+
+################
+-真实数据-
+################
+输入:
+{input_sentence}
+################
+请直接输出改写后的句子,不要输出任何额外信息。
+输出:
+"""
+
+TEMPLATE_ZH: str = """-目标-
+将输入句子转换为相同含义的句子,同时:
+
+1. 保留大部分原始句子结构
+2. 仅更改不影响核心含义的词语
+3. 保持相同的语气和风格
+4. 输出句子应该流畅且语法正确
+
+################
+-示例-
+################
+输入:
+明亮的阳光让每个人都感到充满活力和快乐。
+
+输出:
+明媚的阳光让每个人都感受到活力与快乐。
+
+################
+-真实数据-
+################
+输入:
+{input_sentence}
+################
+请直接输出改写后的句子,不要输出任何额外信息。
+输出:
+"""
+
+TEMPLATE_EN: str = """-Goal-
+Transform the input sentence into a sentence with the same meaning while:
+
+1. Preserving most of the original sentence structure
+2. Changing only wording that does not affect the core meaning
+3. Maintaining the same tone and style
+4. The output sentence should be fluent and grammatically correct
+
+################
+-Examples-
+################
+Input:
+The bright sunshine made everyone feel energetic and happy.
+
+Output:
+The bright sunshine made everyone feel energetic and joyful.
+
+################
+-Real Data-
+################
+Input:
+{input_sentence}
+################
+Please directly output the rewritten sentence without any additional information.
+Output:
+"""
+
+
+# Prompt table: language name -> template variant -> prompt text.
+# "TEMPLATE" asks for a same-meaning rewrite, "ANTI_TEMPLATE" for an
+# opposite-meaning one; every template takes {input_sentence}.
+DESCRIPTION_REPHRASING_PROMPT = {
+    "English": {"ANTI_TEMPLATE": ANTI_TEMPLATE_EN, "TEMPLATE": TEMPLATE_EN},
+    "Chinese": {"ANTI_TEMPLATE": ANTI_TEMPLATE_ZH, "TEMPLATE": TEMPLATE_ZH},
+}
diff --git a/hf-repo/graphgen/templates/kg_extraction.py b/hf-repo/graphgen/templates/kg_extraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d98bb9593c1dcc6b5ddba7db06ff397dd16fb9c
--- /dev/null
+++ b/hf-repo/graphgen/templates/kg_extraction.py
@@ -0,0 +1,210 @@
+# pylint: disable=C0301
+
+TEMPLATE_EN: str = """You are an NLP expert, skilled at analyzing text to extract named entities and their relationships.
+
+-Goal-
+Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
+Use {language} as output language.
+
+-Steps-
+1. Identify all entities. For each identified entity, extract the following information:
+- entity_name: Name of the entity, use same language as input text. If English, capitalize the name.
+- entity_type: One of the following types: [{entity_types}]
+- entity_summary: Comprehensive summary of the entity's attributes and activities
+Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_summary>)
+
+2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
+For each pair of related entities, extract the following information:
+- source_entity: name of the source entity, as identified in step 1
+- target_entity: name of the target entity, as identified in step 1
+- relationship_summary: explanation as to why you think the source entity and the target entity are related to each other
+Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_summary>)
+
+3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
+Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
+
+4. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
+
+5. When finished, output {completion_delimiter}
+
+################
+-Examples-
+################
+-Example 1-
+Text:
+################
+In the second century of the Christian Era, the empire of Rome comprehended the fairest part of the earth, and the most civilized portion of mankind. The frontiers of that extensive monarchy were guarded by ancient renown and disciplined valor. The gentle but powerful influence of laws and manners had gradually cemented the union of the provinces. Their peaceful inhabitants enjoyed and abused the advantages of wealth and luxury. The image of a free constitution was preserved with decent reverence: the Roman senate appeared to possess the sovereign authority, and devolved on the emperors all the executive powers of government. During a happy period of more than fourscore years, the public administration was conducted by the virtue and abilities of Nerva, Trajan, Hadrian, and the two Antonines.
+################
+Output:
+("entity"{tuple_delimiter}"Roman Empire"{tuple_delimiter}"organization"{tuple_delimiter}"The dominant empire of the second century CE, encompassing the most developed regions of the known world."){record_delimiter}
+("entity"{tuple_delimiter}"Second Century CE"{tuple_delimiter}"date"{tuple_delimiter}"Time period of the Christian Era when the Roman Empire was at its height."){record_delimiter}
+("entity"{tuple_delimiter}"Rome"{tuple_delimiter}"location"{tuple_delimiter}"The capital and heart of the Roman Empire."){record_delimiter}
+("entity"{tuple_delimiter}"Roman Senate"{tuple_delimiter}"organization"{tuple_delimiter}"Legislative body that appeared to hold sovereign authority in Rome."){record_delimiter}
+("entity"{tuple_delimiter}"Nerva"{tuple_delimiter}"person"{tuple_delimiter}"Roman emperor who contributed to the public administration during a prosperous period."){record_delimiter}
+("entity"{tuple_delimiter}"Trajan"{tuple_delimiter}"person"{tuple_delimiter}"Roman emperor known for his virtue and administrative abilities."){record_delimiter}
+("entity"{tuple_delimiter}"Hadrian"{tuple_delimiter}"person"{tuple_delimiter}"Roman emperor who governed during the empire's peaceful period."){record_delimiter}
+("entity"{tuple_delimiter}"Antonines"{tuple_delimiter}"person"{tuple_delimiter}"Two Roman emperors who ruled during a period of prosperity and good governance."){record_delimiter}
+("entity"{tuple_delimiter}"Roman Law"{tuple_delimiter}"concept"{tuple_delimiter}"System of laws and manners that unified the provinces of the Roman Empire."){record_delimiter}
+("relationship"{tuple_delimiter}"Roman Empire"{tuple_delimiter}"Roman Law"{tuple_delimiter}"The empire was unified and maintained through the influence of its laws and customs."){record_delimiter}
+("relationship"{tuple_delimiter}"Roman Senate"{tuple_delimiter}"Roman Empire"{tuple_delimiter}"The Senate appeared to possess sovereign authority while delegating executive powers to emperors."){record_delimiter}
+("relationship"{tuple_delimiter}"Nerva"{tuple_delimiter}"Roman Empire"{tuple_delimiter}"Nerva was one of the emperors who contributed to the empire's successful administration."){record_delimiter}
+("relationship"{tuple_delimiter}"Trajan"{tuple_delimiter}"Roman Empire"{tuple_delimiter}"Trajan was one of the emperors who governed during the empire's prosperous period."){record_delimiter}
+("relationship"{tuple_delimiter}"Hadrian"{tuple_delimiter}"Roman Empire"{tuple_delimiter}"Hadrian was one of the emperors who managed the empire's administration effectively."){record_delimiter}
+("relationship"{tuple_delimiter}"Antonines"{tuple_delimiter}"Roman Empire"{tuple_delimiter}"The Antonines were emperors who helped maintain the empire's prosperity through their governance."){record_delimiter}
+("content_keywords"{tuple_delimiter}"Roman governance, imperial prosperity, law and order, civilized society"){completion_delimiter}
+
+-Example 2-
+Text:
+#############
+Overall, the analysis of the OsDT11 sequence demonstrated that this protein belongs to the CRP family. Since OsDT11 is predicted to be a secreted protein, the subcellular localization of OsDT11 was determined by fusing the OsDT11 ORF to RFP in a p35S::RFP vector by in vivo protein targeting in NB epidermal cells by performing an Agrobacterium tumefaciens-mediated transient assay. After incubation for 48 h, the RFP signals were mainly detected in the cell-wall of OsDT11-RFP transformed cells, while the control cells (transformed with the RFP construct) displayed ubiquitous RFP signals, demonstrating that OsDT11 is a secreted signal peptide. Moreover, when the infiltrated leaf sections were plasmolyzed, the OsDT11-RFP fusion proteins were located on the cell wall.
+#############
+Output:
+("entity"{tuple_delimiter}"OsDT11"{tuple_delimiter}"gene"{tuple_delimiter}"A protein sequence belonging to the CRP family, demonstrated to be a secreted signal peptide that localizes to cell walls."){record_delimiter}
+("entity"{tuple_delimiter}"CRP family"{tuple_delimiter}"science"{tuple_delimiter}"A protein family to which OsDT11 belongs, characterized by specific structural and functional properties."){record_delimiter}
+("entity"{tuple_delimiter}"RFP"{tuple_delimiter}"technology"{tuple_delimiter}"Red Fluorescent Protein, used as a fusion marker to track protein localization in cells."){record_delimiter}
+("entity"{tuple_delimiter}"p35S::RFP vector"{tuple_delimiter}"technology"{tuple_delimiter}"A genetic construct used for protein expression and visualization studies, containing the 35S promoter and RFP marker."){record_delimiter}
+("entity"{tuple_delimiter}"NB epidermal cells"{tuple_delimiter}"nature"{tuple_delimiter}"Plant epidermal cells used as the experimental system for protein localization studies."){record_delimiter}
+("entity"{tuple_delimiter}"Agrobacterium tumefaciens"{tuple_delimiter}"nature"{tuple_delimiter}"A bacteria species used for transferring genetic material into plant cells in laboratory experiments."){record_delimiter}
+("relationship"{tuple_delimiter}"OsDT11"{tuple_delimiter}"CRP family"{tuple_delimiter}"OsDT11 is identified as a member of the CRP family through sequence analysis."){record_delimiter}
+("relationship"{tuple_delimiter}"OsDT11"{tuple_delimiter}"RFP"{tuple_delimiter}"OsDT11 was fused to RFP to study its cellular localization."){record_delimiter}
+("relationship"{tuple_delimiter}"Agrobacterium tumefaciens"{tuple_delimiter}"NB epidermal cells"{tuple_delimiter}"Agrobacterium tumefaciens was used to transfer genetic material into NB epidermal cells through a transient assay."){record_delimiter}
+("relationship"{tuple_delimiter}"OsDT11"{tuple_delimiter}"NB epidermal cells"{tuple_delimiter}"OsDT11's subcellular localization was studied in NB epidermal cells, showing cell wall targeting."){record_delimiter}
+("content_keywords"{tuple_delimiter}"protein localization, gene expression, cellular biology, molecular techniques"){completion_delimiter}
+
+################
+-Real Data-
+################
+Entity_types: {entity_types}
+Text: {input_text}
+################
+Output:
+"""
+
+
+TEMPLATE_ZH: str = """你是一个NLP专家,擅长分析文本提取命名实体和关系。
+
+-目标-
+给定一个实体类型列表和可能与列表相关的文本,从文本中识别所有这些类型的实体,以及这些实体之间所有的关系。
+使用{language}作为输出语言。
+
+-步骤-
+1. 识别所有实体。对于每个识别的实体,提取以下信息:
+ - entity_name:实体的名称,首字母大写
+ - entity_type:以下类型之一:[{entity_types}]
+ - entity_summary:实体的属性与活动的全面总结
+ 将每个实体格式化为("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_summary>)
+
+2. 从步骤1中识别的实体中,识别所有(源实体,目标实体)对,这些实体彼此之间*明显相关*。
+ 对于每对相关的实体,提取以下信息:
+ - source_entity:步骤1中识别的源实体名称
+ - target_entity:步骤1中识别的目标实体名称
+ - relationship_summary:解释为什么你认为源实体和目标实体彼此相关
+ 将每个关系格式化为("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_summary>)
+
+3. 识别总结整个文本的主要概念、主题或话题的高级关键词。这些应该捕捉文档中存在的总体思想。
+ 将内容级关键词格式化为("content_keywords"{tuple_delimiter}<high_level_keywords>)
+
+4. 以中文返回步骤1和2中识别出的所有实体和关系的输出列表。使用**{record_delimiter}**作为列表分隔符。
+
+5. 完成后,输出{completion_delimiter}
+
+################
+-示例-
+################
+-示例 1-
+文本:
+################
+鲁镇的酒店的格局,是和别处不同的:都是当街一个曲尺形的大柜台,柜里面预备着热水,可以随时温酒。做工的人,傍午傍晚散了工,每每花四文铜钱,买一碗酒,——这是二十多年前的事,现在每碗要涨到十文,——靠柜外站着,热热的喝了休息;倘肯多花一文,便可以买一碟盐煮笋,或者茴香豆,做下酒物了,如果出到十几文,那就能买一样荤菜,但这些顾客,多是短衣帮,大抵没有这样阔绰。只有穿长衫的,才踱进店面隔壁的房子里,要酒要菜,慢慢地坐喝。
+################
+输出:
+("entity"{tuple_delimiter}"鲁镇的酒店"{tuple_delimiter}"location"{tuple_delimiter}"鲁镇的酒店是一个特定地点,其格局独特,柜台形状为曲尺形,提供热水温酒服务。"){record_delimiter}
+("entity"{tuple_delimiter}"曲尺形的大柜台"{tuple_delimiter}"keyword"{tuple_delimiter}"曲尺形的大柜台是鲁镇酒店内独特的设施,用于提供服务。"){record_delimiter}
+("entity"{tuple_delimiter}"热水温酒"{tuple_delimiter}"keyword"{tuple_delimiter}"热水温酒是鲁镇酒店提供的一项服务,顾客可以随时温酒。"){record_delimiter}
+("entity"{tuple_delimiter}"做工的人"{tuple_delimiter}"person"{tuple_delimiter}"做工的人是鲁镇酒店的常客,通常在工作结束后花四文铜钱买一碗酒,有时还会买一些下酒菜。"){record_delimiter}
+("entity"{tuple_delimiter}"二十多年前的事"{tuple_delimiter}"date"{tuple_delimiter}"二十多年前的事是指过去的时间点,当时一碗酒的价格为四文铜钱。"){record_delimiter}
+("entity"{tuple_delimiter}"现在"{tuple_delimiter}"date"{tuple_delimiter}"现在是指当前的时间点,与过去相比,一碗酒的价格涨到了十文。"){record_delimiter}
+("entity"{tuple_delimiter}"短衣帮"{tuple_delimiter}"concept"{tuple_delimiter}"短衣帮是指做工的人,他们通常穿着短衣,经济条件有限。"){record_delimiter}
+("entity"{tuple_delimiter}"穿长衫的"{tuple_delimiter}"person"{tuple_delimiter}"穿长衫的是鲁镇酒店的另一类顾客,他们经济条件较好,通常会进入店面隔壁的房间慢慢喝酒吃菜。"){record_delimiter}
+("entity"{tuple_delimiter}"盐煮笋"{tuple_delimiter}"food"{tuple_delimiter}"盐煮笋是鲁镇酒店提供的一种下酒菜,顾客可以花一文铜钱购买。"){record_delimiter}
+("entity"{tuple_delimiter}"茴香豆"{tuple_delimiter}"food"{tuple_delimiter}"茴香豆是鲁镇酒店提供的另一种下酒菜,顾客可以花一文铜钱购买。"){record_delimiter}
+("entity"{tuple_delimiter}"荤菜"{tuple_delimiter}"food"{tuple_delimiter}"荤菜是鲁镇酒店提供的较为昂贵的菜品,顾客需要花十几文铜钱购买。"){record_delimiter}
+("relationship"{tuple_delimiter}"鲁镇的酒店"{tuple_delimiter}"曲尺形的大柜台"{tuple_delimiter}"鲁镇的酒店内设有一个曲尺形的大柜台,用于提供服务。"){record_delimiter}
+("relationship"{tuple_delimiter}"鲁镇的酒店"{tuple_delimiter}"热水温酒"{tuple_delimiter}"鲁镇的酒店提供热水温酒服务,顾客可以随时温酒。"){record_delimiter}
+("relationship"{tuple_delimiter}"做工的人"{tuple_delimiter}"二十多年前的事"{tuple_delimiter}"做工的人在二十多年前花四文铜钱买一碗酒,反映了当时的生活成本。"){record_delimiter}
+("relationship"{tuple_delimiter}"做工的人"{tuple_delimiter}"现在"{tuple_delimiter}"现在做工的人需要花十文铜钱买一碗酒,反映了物价的上涨。"){record_delimiter}
+("relationship"{tuple_delimiter}"做工的人"{tuple_delimiter}"短衣帮"{tuple_delimiter}"做工的人属于短衣帮,通常经济条件有限。"){record_delimiter}
+("relationship"{tuple_delimiter}"做工的人"{tuple_delimiter}"穿长衫的"{tuple_delimiter}"做工的人与穿长衫的形成对比,反映了社会阶层的差异。"){record_delimiter}
+("relationship"{tuple_delimiter}"穿长衫的"{tuple_delimiter}"鲁镇的酒店"{tuple_delimiter}"穿长衫的顾客通常会进入鲁镇酒店的房间慢慢喝酒吃菜,享受更高级的服务。"){record_delimiter}
+("content_keywords"{tuple_delimiter}"社会分层, 经济差距, 服务, 生活成本, 历史背景"){completion_delimiter}
+
+-示例 2-
+文本:
+################
+黄华占是感温型常规稻品种,2016—2017 年在铅山县汪二镇作中稻示范种植综合表现优良。结合示范情况,对黄华占的特征特性作简单总结,在此基础上提出高产栽培技术,以期为该品种的推广种植提供参考。近年来,铅山县粮食生产紧紧围绕“稳产、优质、增效”的总体要求、大力实施优质稻推广,积极引导粮食生产由增产转向提质。我国杂交水稻技术世界领先、优质稻品种众多,在市场走势方面(尤其稻米行情清淡期),人们习惯性地北涨看长粒香、南涨看黄华占。黄华占是广东省农业科学院水稻研究所以黄新占/丰华占为亲本选育而成,分别通过粤、湘、鄂、浙、桂、琼等省审定。为了更好、更快地推广黄华占水稻,铅山县分别于2016 年、2017 年在汪二镇火田村试验示范种植黄华占近 5.87 hm^2 ,综合表现优良。现将黄华占水稻的特征特性及高产栽培技术介绍如下。
+################
+输出:
+("entity"{tuple_delimiter}"黄华占"{tuple_delimiter}"work"{tuple_delimiter}"黄华占是一种感温型常规稻品种,由广东省农业科学院水稻研究所选育,通过多个省份审定,2016-2017年在铅山县汪二镇进行示范种植,表现优良。"){record_delimiter}
+("entity"{tuple_delimiter}"2016—2017年"{tuple_delimiter}"date"{tuple_delimiter}"2016—2017年是黄华占在铅山县汪二镇进行示范种植的时间段。"){record_delimiter}
+("entity"{tuple_delimiter}"铅山县"{tuple_delimiter}"location"{tuple_delimiter}"铅山县位于中国江西省,是黄华占水稻示范种植的地点之一。"){record_delimiter}
+("entity"{tuple_delimiter}"汪二镇"{tuple_delimiter}"location"{tuple_delimiter}"汪二镇是铅山县的一个镇,2016-2017年在此进行了黄华占水稻的示范种植。"){record_delimiter}
+("entity"{tuple_delimiter}"火田村"{tuple_delimiter}"location"{tuple_delimiter}"火田村是汪二镇的一个村庄,2016-2017年在此进行了黄华占水稻的试验示范种植。"){record_delimiter}
+("entity"{tuple_delimiter}"广东省农业科学院水稻研究所"{tuple_delimiter}"organization"{tuple_delimiter}"广东省农业科学院水稻研究所是中国的一个科研机构,负责黄华占水稻的选育工作。"){record_delimiter}
+("entity"{tuple_delimiter}"黄新占/丰华占"{tuple_delimiter}"work"{tuple_delimiter}"黄新占和丰华占是黄华占水稻的亲本,用于选育黄华占。"){record_delimiter}
+("entity"{tuple_delimiter}"粤、湘、鄂、浙、桂、琼等省"{tuple_delimiter}"location"{tuple_delimiter}"这些省份通过了黄华占水稻的审定,表明该品种在这些地区具有良好的适应性和推广潜力。"){record_delimiter}
+("entity"{tuple_delimiter}"高产栽培技术"{tuple_delimiter}"technology"{tuple_delimiter}"高产栽培技术是指为了提高黄华占水稻产量而采用的一系列农业技术措施。"){record_delimiter}
+("entity"{tuple_delimiter}"稳产、优质、增效"{tuple_delimiter}"concept"{tuple_delimiter}"这是铅山县粮食生产的主要目标,强调了粮食生产的稳定、质量和效益。"){record_delimiter}
+("entity"{tuple_delimiter}"优质稻推广"{tuple_delimiter}"mission"{tuple_delimiter}"优质稻推广是铅山县粮食生产的一个重要任务,旨在提高稻米的质量和市场竞争力。"){record_delimiter}
+("entity"{tuple_delimiter}"杂交水稻技术"{tuple_delimiter}"technology"{tuple_delimiter}"杂交水稻技术是中国领先的世界级农业技术,用于提高水稻的产量和质量。"){record_delimiter}
+("entity"{tuple_delimiter}"北涨看长粒香、南涨看黄华占"{tuple_delimiter}"concept"{tuple_delimiter}"这是市场对不同地区优质稻品种的习惯性关注点,北方面对长粒香,南方面对黄华占。"){record_delimiter}
+("relationship"{tuple_delimiter}"黄华占"{tuple_delimiter}"2016—2017年"{tuple_delimiter}"黄华占在2016—2017年期间在铅山县进行了示范种植,展示了其优良的特性。"){record_delimiter}
+("relationship"{tuple_delimiter}"黄华占"{tuple_delimiter}"铅山县"{tuple_delimiter}"黄华占在铅山县进行了示范种植,表现出了优良的适应性和产量。"){record_delimiter}
+("relationship"{tuple_delimiter}"黄华占"{tuple_delimiter}"汪二镇"{tuple_delimiter}"黄华占在汪二镇进行了示范种植,这是其在铅山县示范种植的一部分。"){record_delimiter}
+("relationship"{tuple_delimiter}"黄华占"{tuple_delimiter}"火田村"{tuple_delimiter}"黄华占在火田村进行了试验示范种植,这是其在汪二镇示范种植的一部分。"){record_delimiter}
+("relationship"{tuple_delimiter}"黄华占"{tuple_delimiter}"广东省农业科学院水稻研究所"{tuple_delimiter}"黄华占是由广东省农业科学院水稻研究所选育的,该研究所负责其研发工作。"){record_delimiter}
+("relationship"{tuple_delimiter}"黄华占"{tuple_delimiter}"黄新占/丰华占"{tuple_delimiter}"黄华占的亲本是黄新占和丰华占,这些亲本用于选育黄华占。"){record_delimiter}
+("relationship"{tuple_delimiter}"黄华占"{tuple_delimiter}"粤、湘、鄂、浙、桂、琼等省"{tuple_delimiter}"黄华占通过了这些省份的审定,表明其在这些地区的适应性和推广潜力。"){record_delimiter}
+("relationship"{tuple_delimiter}"黄华占"{tuple_delimiter}"高产栽培技术"{tuple_delimiter}"高产栽培技术是为了提高黄华占水稻产量而开发的技术措施。"){record_delimiter}
+("relationship"{tuple_delimiter}"铅山县"{tuple_delimiter}"稳产、优质、增效"{tuple_delimiter}"铅山县的粮食生产目标是稳产、优质、增效,这些目标指导了黄华占的示范种植。"){record_delimiter}
+("relationship"{tuple_delimiter}"铅山县"{tuple_delimiter}"优质稻推广"{tuple_delimiter}"铅山县实施了优质稻推广计划,黄华占是该计划的一部分。"){record_delimiter}
+("relationship"{tuple_delimiter}"杂交水稻技术"{tuple_delimiter}"北涨看长粒香、南涨看黄华占"{tuple_delimiter}"杂交水稻技术的发展使得黄华占等优质稻品种在市场中受到关注。"){record_delimiter}
+("content_keywords"{tuple_delimiter}"黄华占, 水稻种植, 高产栽培技术, 优质稻推广, 地区适应性, 市场趋势, 技术影响"){completion_delimiter}
+
+-真实数据-
+实体类型:{entity_types}
+文本:{input_text}
+################
+输出:
+"""
+
+CONTINUE_EN: str = """MANY entities and relationships were missed in the last extraction. \
+Add them below using the same format:
+"""
+
+CONTINUE_ZH: str = """很多实体和关系在上一次的提取中可能被遗漏了。请在下面使用相同的格式添加它们:"""
+
+IF_LOOP_EN: str = """It appears some entities and relationships may have still been missed. \
+Answer YES | NO if there are still entities and relationships that need to be added.
+"""
+
+IF_LOOP_ZH: str = """看起来可能仍然遗漏了一些实体和关系。如果仍有实体和关系需要添加,请回答YES | NO。"""
+
+KG_EXTRACTION_PROMPT: dict = {
+ "English": {
+ "TEMPLATE": TEMPLATE_EN,
+ "CONTINUE": CONTINUE_EN,
+ "IF_LOOP": IF_LOOP_EN,
+ },
+ "Chinese": {
+ "TEMPLATE": TEMPLATE_ZH,
+ "CONTINUE": CONTINUE_ZH,
+ "IF_LOOP": IF_LOOP_ZH,
+ },
+ "FORMAT": {
+ "tuple_delimiter": "<|>",
+ "record_delimiter": "##",
+ "completion_delimiter": "<|COMPLETE|>",
+ "entity_types": "concept, date, location, keyword, organization, person, event, work, nature, artificial, \
+science, technology, mission, gene",
+ "language": "English",
+ },
+}
diff --git a/hf-repo/graphgen/templates/kg_summarization.py b/hf-repo/graphgen/templates/kg_summarization.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cf751801f6d113527d2390a88f30121325cdc75
--- /dev/null
+++ b/hf-repo/graphgen/templates/kg_summarization.py
@@ -0,0 +1,45 @@
+TEMPLATE_EN = """You are an NLP expert responsible for generating a comprehensive summary of the data provided below.
+Given one entity or relationship, and a list of descriptions, all related to the same entity or relationship.
+Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
+If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
+Make sure it is written in third person, and include the entity names so we have the full context.
+Use {language} as output language.
+
+#######
+-Data-
+Entities: {entity_name}
+Description List: {description_list}
+#######
+Output:
+"""
+
+TEMPLATE_ZH = """你是一个NLP专家,负责根据以下提供的数据生成综合摘要。
+给定一个实体或关系,以及一系列描述,所有描述都与同一实体或关系相关。
+请将所有这些描述整合成一个综合描述。确保包含所有描述中收集的信息。
+如果提供的描述是矛盾的,请解决这些矛盾并提供一个连贯的总结。
+确保以第三人称写作,并包含实体名称,以便我们有完整的上下文。
+使用{language}作为输出语言。
+
+#######
+-数据-
+实体:{entity_name}
+描述列表:{description_list}
+#######
+输出:
+"""
+
+
+KG_SUMMARIZATION_PROMPT = {
+ "Chinese": {
+ "TEMPLATE": TEMPLATE_ZH
+ },
+ "English": {
+ "TEMPLATE": TEMPLATE_EN
+ },
+ "FORMAT": {
+ "language": "English",
+ "tuple_delimiter": "<|>",
+ "record_delimiter": "##",
+ "completion_delimiter": "<|COMPLETE|>",
+ },
+}
diff --git a/hf-repo/graphgen/templates/multi_hop_generation.py b/hf-repo/graphgen/templates/multi_hop_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..dad2ee36204f8eae483a99d39d55c2b04ba879b9
--- /dev/null
+++ b/hf-repo/graphgen/templates/multi_hop_generation.py
@@ -0,0 +1,60 @@
+# pylint: disable=C0301
+
+TEMPLATE_ZH: str = """请基于以下知识子图生成多跳推理问题和答案。你将获得一个知识子图,其中包含一系列实体、关系和事实。你的任务是提出一个问题,该问题需要经过多次推理才能回答。问题的答案应该是从给定的知识子图中推断出来的。确保问题的难度适中,需要多步推理才能回答。
+
+例如:
+########
+--实体--
+1. 苹果
+2. 水果
+3. 维生素C
+########
+--关系--
+1. 苹果-水果:苹果是一种水果
+2. 水果-维生素C:水果中富含维生素C
+########
+问题:通过吃苹果补充的什么物质,有助于维持健康?
+答案:维生素C
+########
+
+#########
+--实体--
+{entities}
+#########
+--关系--
+{relationships}
+#########
+直接输出生成的问题和答案,请不要直接复制示例问题和答案,不要输出无关内容。
+"""
+
+TEMPLATE_EN: str = """Please generate a multi-hop reasoning question and answer based on the following knowledge subgraph. You will be provided with a knowledge subgraph that contains a series of entities, relations, and facts. Your task is to generate a question that requires multiple steps of reasoning to answer. The answer to the question should be inferred from the given knowledge subgraph. Ensure that the question is of moderate difficulty and requires multiple steps of reasoning to answer.
+
+For example:
+########
+--Entities--
+1. Apple
+2. Fruit
+3. Vitamin C
+########
+--Relations--
+1. Apple-Fruit: Apple is a type of fruit
+2. Fruit-Vitamin C: Fruits are rich in Vitamin C
+########
+Question: What substance, obtained through eating apples, helps maintain health?
+Answer: Vitamin C
+########
+
+########
+--Entities--
+{entities}
+########
+--Relations--
+{relationships}
+########
+Output the generated question and answer directly, please do not copy the example question and answer directly, and do not provide irrelevant information.
+"""
+
+MULTI_HOP_GENERATION_PROMPT = {
+ "English": TEMPLATE_EN,
+ "Chinese": TEMPLATE_ZH
+}
diff --git a/hf-repo/graphgen/templates/question_generation.py b/hf-repo/graphgen/templates/question_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9ca9128fad65c127f985080f887351cf2efe68e
--- /dev/null
+++ b/hf-repo/graphgen/templates/question_generation.py
@@ -0,0 +1,78 @@
+# pylint: disable=C0301
+TEMPLATE_SINGLE_EN: str = """The answer to a question is provided. Please generate a question that corresponds to the answer.
+
+################
+Answer:
+{answer}
+################
+Question:
+"""
+
+TEMPLATE_SINGLE_ZH: str = """下面提供了一个问题的答案,请生成一个与答案对应的问题。
+
+################
+答案:
+{answer}
+################
+问题:
+"""
+
+TEMPLATE_SINGLE_QA_EN: str = """You are given a text passage. Your task is to generate a question and answer (QA) pair based on the content of that text.
+The answer should be accurate and directly derived from the text. Make sure the QA pair is relevant to the main theme or important details of the given text.
+For example:
+Question: What is the effect of overexpressing the BG1 gene on grain size and development?
+Answer: Overexpression of the BG1 gene leads to significantly increased grain size, demonstrating its role in grain development.
+
+Question: What role does TAC4 play in the gravitropism of rice shoots?
+Answer: TAC4 is a key regulator of gravitropism in rice shoots, promoting the bending of shoots towards the gravity vector.
+
+Here is the text passage you need to generate a QA pair for:
+{doc}
+"""
+
+TEMPLATE_SINGLE_QA_ZH: str = """给定一个文本段落。你的任务是根据该文本的内容生成一个问答(QA)对。
+答案应准确且直接从文本中得出。确保QA对与给定文本的主题或重要细节相关。
+例如:
+问题:过表达BG1基因对谷粒大小和发育有什么影响?
+答案:BG1基因的过表达显著增加了谷粒大小,表明其在谷物发育中的作用。
+
+问题:TAC4在水稻茎的重力性状中扮演什么角色?
+答案:TAC4是水稻茎重力性状的关键调节因子,促进茎向重力矢量弯曲。
+
+以下是你需要为其生成QA对的文本段落:
+{doc}
+"""
+
+# TODO: revise the prompts below
+TEMPLATE_MULTI_EN = """You are an assistant to help read an article and then rephrase it in a question answering format. The user will provide you with an article with its content. You need to generate a paraphrase of the same article in question and answer format with one tag of "Question: ..." followed by "Answer: ...". Remember to keep the meaning and every content of the article intact.
+
+Here is the format you should follow for your response:
+Question: <question>
+Answer: <answer>
+
+Here is the article you need to rephrase:
+{doc}
+"""
+
+TEMPLATE_MULTI_ZH = """你是一位助手,帮助阅读一篇文章,然后以问答格式重述它。用户将为您提供一篇带有内容的文章。你需要以一个标签"问题:..."为开头,接着是"答案:...",生成一篇与原文章相同的问答格式的重述。请确保保持文章的意义和每个内容不变。
+
+以下是你应该遵循的响应格式:
+问题: <问题>
+答案: <答案>
+
+以下是你需要重述的文章:
+{doc}
+"""
+
+QUESTION_GENERATION_PROMPT = {
+ "English": {
+ "SINGLE_TEMPLATE": TEMPLATE_SINGLE_EN,
+ "SINGLE_QA_TEMPLATE": TEMPLATE_SINGLE_QA_EN,
+ "MULTI_TEMPLATE": TEMPLATE_MULTI_EN
+ },
+ "Chinese": {
+ "SINGLE_TEMPLATE": TEMPLATE_SINGLE_ZH,
+ "SINGLE_QA_TEMPLATE": TEMPLATE_SINGLE_QA_ZH,
+ "MULTI_TEMPLATE": TEMPLATE_MULTI_ZH
+ }
+}
diff --git a/hf-repo/graphgen/templates/search_judgement.py b/hf-repo/graphgen/templates/search_judgement.py
new file mode 100644
index 0000000000000000000000000000000000000000..e85b00974990959b07db5d2e44fdc209f5435667
--- /dev/null
+++ b/hf-repo/graphgen/templates/search_judgement.py
@@ -0,0 +1,67 @@
+# pylint: disable=C0301
+
+TEMPLATE: str = """-Goal-
+Please select the most relevant search result for the given entity.
+The name and description of the entity are provided. The search results are provided as a list.
+Please select the most relevant search result from the list. If none of the search results are relevant, please select 'None of the above'.
+
+Steps:
+1. Read the name and description of the entity.
+
+2. Read the search results. For each search result, compare it with the entity name and description to determine if it is relevant.
+
+3. Select the most relevant search result from the list. If none of the search results are relevant, select 'None of the above'.
+
+4. Output your selection directly, please do not provide any additional information.
+
+################
+-Examples-
+################
+{input_examples}
+
+################
+-Real Data-
+################
+Entity_name: {entity_name}
+Description: {description}
+Search Results:
+{search_results}
+################
+Output:
+"""
+
+EXAMPLES = [
+ """Example 1:
+################
+Entity_name: Java
+Description: Java is a high-level programming language developed by Sun Microsystems. It is used to create web applications, mobile applications, and enterprise software.
+Search Results:
+1. Java (programming language)
+2. Java (island)
+3. Java (software platform)
+4. Java (drink)
+5. Java (disambiguation)
+6. None of the above
+################
+Output:
+1
+################""",
+ """Example 2:
+################
+Entity_name: Apple
+Description: Apple Inc. is an American multinational technology company that designs, manufactures, and sells consumer electronics, computer software, and online services.
+Search Results:
+1. Apple (fruit)
+2. Apple Inc.
+3. Apple (disambiguation)
+4. None of the above
+################
+Output:
+2
+################""",
+]
+
+SEARCH_JUDGEMENT_PROMPT = {
+ "TEMPLATE": TEMPLATE,
+ "EXAMPLES": EXAMPLES,
+}
diff --git a/hf-repo/graphgen/templates/statement_judgement.py b/hf-repo/graphgen/templates/statement_judgement.py
new file mode 100644
index 0000000000000000000000000000000000000000..5af4468dd2611b08a539651a6d59ea51de770349
--- /dev/null
+++ b/hf-repo/graphgen/templates/statement_judgement.py
@@ -0,0 +1,13 @@
+TEMPLATE: str = """Please determine if the following statement is correct.
+
+Note:
+1. If the statement is correct, please reply with 'yes', otherwise reply with 'no'.
+2. The answer should be either 'yes' or 'no', do not output any other content.
+
+Statement:
+{statement}
+Judgement: """
+
+STATEMENT_JUDGEMENT_PROMPT = {
+ "TEMPLATE": TEMPLATE
+}
diff --git a/hf-repo/graphgen/utils/__init__.py b/hf-repo/graphgen/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3c8e1e646661739ae64e8accedb8f515a4b220c
--- /dev/null
+++ b/hf-repo/graphgen/utils/__init__.py
@@ -0,0 +1,16 @@
+from .calculate_confidence import yes_no_loss_entropy
+from .detect_lang import detect_if_chinese, detect_main_language
+from .file import read_file
+from .format import (
+ format_generation_results,
+ handle_single_entity_extraction,
+ handle_single_relationship_extraction,
+ load_json,
+ pack_history_conversations,
+ split_string_by_multi_markers,
+ write_json,
+)
+from .hash import compute_args_hash, compute_content_hash
+from .help_nltk import NLTKHelper
+from .log import logger, parse_log, set_logger
+from .loop import create_event_loop
diff --git a/hf-repo/graphgen/utils/calculate_confidence.py b/hf-repo/graphgen/utils/calculate_confidence.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b596d910c93f35ed0b133607c5e209201bfc05a
--- /dev/null
+++ b/hf-repo/graphgen/utils/calculate_confidence.py
@@ -0,0 +1,54 @@
+import math
+from typing import List
+from graphgen.models.llm.topk_token_model import Token
+
+def preprocess_tokens(tokens: List[Token]) -> List[Token]:
+    """Return a new list holding only the tokens whose probability is strictly positive."""
+    positive = list(filter(lambda tok: tok.prob > 0, tokens))
+    return positive
+
+def joint_probability(tokens: List[Token]) -> float:
+    """Return exp of the mean logprob over the tokens with prob > 0; raises ZeroDivisionError if none remain."""
+    kept = preprocess_tokens(tokens)
+    mean_logprob = sum(tok.logprob for tok in kept) / len(kept)
+    return math.exp(mean_logprob)
+
+def min_prob(tokens: List[Token]) -> float:
+    """Return the smallest probability among the tokens with prob > 0; raises ValueError if none remain."""
+    kept = preprocess_tokens(tokens)
+    return min(kept, key=lambda tok: tok.prob).prob
+
+def average_prob(tokens: List[Token]) -> float:
+    """Return the arithmetic mean probability of the tokens with prob > 0; raises ZeroDivisionError if none remain."""
+    kept = preprocess_tokens(tokens)
+    return sum([tok.prob for tok in kept]) / len(kept)
+
+def average_confidence(tokens: List[Token]) -> float:
+    """Mean, over the tokens with prob > 0, of prob divided by the summed prob of that token's first five top_candidates."""
+    kept = preprocess_tokens(tokens)
+    ratios = [tok.prob / sum(cand.prob for cand in tok.top_candidates[:5]) for tok in kept]
+    return sum(ratios) / len(kept)
+
+def yes_no_loss(tokens_list: List[List[Token]], ground_truth: List[str]) -> float:
+    """Mean loss over yes/no answers: 1 - prob of each first token when it matches ground_truth[i] (case-insensitively), prob otherwise."""
+    losses = []
+    for i, tokens in enumerate(tokens_list):
+        token = tokens[0]
+        assert token.text.lower() in ["yes", "no"]
+        if token.text.lower() == ground_truth[i].lower():  # fold case exactly like the assert, so "Yes" matches "yes"
+            losses.append(1 - token.prob)
+        else:
+            losses.append(token.prob)
+    return sum(losses) / len(losses)
+
+def yes_no_loss_entropy(tokens_list: List[List[Token]], ground_truth: List[str]) -> float:
+    """Mean cross-entropy over yes/no answers, matched case-insensitively; probabilities are floored at 1e-12 so log() never sees 0."""
+    losses = []
+    for i, tokens in enumerate(tokens_list):
+        token = tokens[0]
+        assert token.text.lower() in ["yes", "no"]
+        if token.text.lower() == ground_truth[i].lower():  # fold case exactly like the assert, so "Yes" matches "yes"
+            losses.append(-math.log(max(token.prob, 1e-12)))
+        else:
+            losses.append(-math.log(max(1 - token.prob, 1e-12)))  # prob == 1.0 would otherwise raise a math domain error
+    return sum(losses) / len(losses)
diff --git a/hf-repo/graphgen/utils/detect_lang.py b/hf-repo/graphgen/utils/detect_lang.py
new file mode 100644
index 0000000000000000000000000000000000000000..c34ddac48e318925b3b6a04258824d0dc9a4ce63
--- /dev/null
+++ b/hf-repo/graphgen/utils/detect_lang.py
@@ -0,0 +1,40 @@
+def detect_main_language(text):
+    """
+    Identify the dominant language of a text.
+
+    :param text: the string to inspect (must be a ``str``)
+    :return: 'zh' if characters in U+4E00-U+9FFF make up at least half of the counted letters, otherwise 'en'
+    """
+    assert isinstance(text, str)
+
+    zh_total = 0
+    en_total = 0
+
+    # Whitespace is skipped; only U+4E00-U+9FFF characters and ASCII letters are counted.
+    for ch in text:
+        if not ch.strip():
+            continue
+        if '\u4e00' <= ch <= '\u9fff':
+            zh_total += 1
+        elif ch.isascii() and ch.isalpha():
+            en_total += 1
+
+    counted = zh_total + en_total
+    # Nothing countable (empty text, digits or punctuation only): default to English.
+    if counted == 0:
+        return 'en'
+
+    if zh_total / counted >= 0.5:
+        return 'zh'
+    return 'en'
+
+def detect_if_chinese(text):
+    """
+    Tell whether the text contains at least one character in the range U+4E00-U+9FFF.
+
+    :param text: the string to inspect (must be a ``str``)
+    :return: True if any character lies in that range, False otherwise
+    """
+    assert isinstance(text, str)
+    in_range_flags = map(lambda ch: '\u4e00' <= ch <= '\u9fff', text)
+    return True in in_range_flags
diff --git a/hf-repo/graphgen/utils/format.py b/hf-repo/graphgen/utils/format.py
new file mode 100644
index 0000000000000000000000000000000000000000..abc34c874a5b413a478e513d9f5109241f36c8a8
--- /dev/null
+++ b/hf-repo/graphgen/utils/format.py
@@ -0,0 +1,134 @@
+import html
+import json
+import os
+import re
+from typing import Any
+
+from .log import logger
+
+
+def pack_history_conversations(*args: str):
+    """Turn alternating utterances into a list of chat messages.
+
+    Roles are assigned by position: even indices become "user", odd indices
+    become "assistant".
+    """
+    messages = []
+    for position, content in enumerate(args):
+        role = "assistant" if position % 2 else "user"
+        messages.append({"role": role, "content": content})
+    return messages
+
+
+def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
+    """Split ``content`` on any of ``markers`` and return the stripped, non-blank pieces.
+
+    With no markers the content is returned unchanged as a one-element list.
+    """
+    if markers:
+        # Markers are matched literally, so escape them before building the
+        # alternation pattern.
+        pattern = "|".join(map(re.escape, markers))
+        pieces = []
+        for part in re.split(pattern, content):
+            stripped = part.strip()
+            if stripped:
+                pieces.append(stripped)
+        return pieces
+    return [content]
+
+
+# Refer the utils functions of the official GraphRAG implementation:
+# https://github.com/microsoft/graphrag
+def clean_str(input: Any) -> str:
+    """Strip a string, decode its HTML entities, and drop control characters.
+
+    Non-string inputs are handed back unchanged.
+    """
+    if isinstance(input, str):
+        unescaped = html.unescape(input.strip())
+        # Remove the characters in \x00-\x1f and \x7f-\x9f, see
+        # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
+        control_chars = re.compile(r"[\x00-\x1f\x7f-\x9f]")
+        return control_chars.sub("", unescaped)
+    return input
+
+
+async def handle_single_entity_extraction(
+    record_attributes: list[str],
+    chunk_key: str,
+):
+    """Build an entity (node) record from one parsed extraction tuple.
+
+    ``record_attributes`` is expected to look like
+    ``['"entity"', name, type, description, ...]``. Returns None when the
+    tuple has fewer than four items, is not tagged ``"entity"``, or has a
+    blank name; otherwise a dict with the upper-cased, cleaned name and type,
+    the cleaned description, and ``chunk_key`` as ``source_id``.
+    """
+    is_entity = len(record_attributes) >= 4 and record_attributes[0] == '"entity"'
+    if not is_entity:
+        return None
+
+    name = clean_str(record_attributes[1].upper())
+    if not name.strip():
+        return None
+
+    return {
+        "entity_name": name,
+        "entity_type": clean_str(record_attributes[2].upper()),
+        "description": clean_str(record_attributes[3]),
+        "source_id": chunk_key,
+    }
+
+
+def is_float_regex(value):
+    """Return True if ``value`` is a plain decimal number with an optional sign.
+
+    Forms such as "3", "-2.5" and "+.75" are accepted; exponents and a
+    trailing bare dot ("1.") are rejected.
+    """
+    float_pattern = re.compile(r"^[-+]?[0-9]*\.?[0-9]+$")
+    return float_pattern.match(value) is not None
+
+
+async def handle_single_relationship_extraction(
+    record_attributes: list[str],
+    chunk_key: str,
+):
+    """Build a relationship (edge) record from one parsed extraction tuple.
+
+    ``record_attributes`` is expected to look like
+    ``['"relationship"', source, target, description, ...]``. Returns None
+    when the tuple has fewer than four items or is not tagged
+    ``"relationship"``; otherwise a dict with the upper-cased, cleaned
+    endpoints, the cleaned description, and ``chunk_key`` as ``source_id``.
+    """
+    is_relationship = (
+        len(record_attributes) >= 4 and record_attributes[0] == '"relationship"'
+    )
+    if not is_relationship:
+        return None
+
+    src = clean_str(record_attributes[1].upper())
+    tgt = clean_str(record_attributes[2].upper())
+    description = clean_str(record_attributes[3])
+    return {
+        "src_id": src,
+        "tgt_id": tgt,
+        "description": description,
+        "source_id": chunk_key,
+    }
+
+
+def load_json(file_name):
+    """Read and parse a UTF-8 JSON file, or return None if the path does not exist."""
+    if os.path.exists(file_name):
+        with open(file_name, encoding="utf-8") as f:
+            return json.load(f)
+    return None
+
+
+def write_json(json_obj, file_name):
+    """Serialize ``json_obj`` to ``file_name`` as 4-space-indented UTF-8 JSON.
+
+    Missing parent directories are created first. Non-ASCII characters are
+    written as-is rather than as ``\\uXXXX`` escapes.
+    """
+    parent_dir = os.path.dirname(file_name)
+    # A bare file name has an empty dirname, and os.makedirs("") raises
+    # FileNotFoundError, so only create directories when there is one.
+    if parent_dir:
+        os.makedirs(parent_dir, exist_ok=True)
+    with open(file_name, "w", encoding="utf-8") as f:
+        json.dump(json_obj, f, indent=4, ensure_ascii=False)
+
+
+def format_generation_results(
+    results: dict[str, Any], output_data_format: str
+) -> list[dict[str, Any]]:
+    """Convert generated QA pairs into a fine-tuning dataset layout.
+
+    Every value of ``results`` must provide "question" and "answer".
+    Supported ``output_data_format`` values are "Alpaca", "Sharegpt" and
+    "ChatML"; any other value raises ValueError. The chosen format is logged
+    at INFO level before conversion.
+    """
+
+    def to_alpaca(item):
+        return {
+            "instruction": item["question"],
+            "input": "",
+            "output": item["answer"],
+        }
+
+    def to_sharegpt(item):
+        return {
+            "conversations": [
+                {"from": "human", "value": item["question"]},
+                {"from": "gpt", "value": item["answer"]},
+            ]
+        }
+
+    def to_chatml(item):
+        return {
+            "messages": [
+                {"role": "user", "content": item["question"]},
+                {"role": "assistant", "content": item["answer"]},
+            ]
+        }
+
+    # (format name, log message, per-item converter)
+    layouts = (
+        ("Alpaca", "Output data format: Alpaca", to_alpaca),
+        ("Sharegpt", "Output data format: Sharegpt", to_sharegpt),
+        ("ChatML", "Output data format: ChatML", to_chatml),
+    )
+    for format_name, log_message, convert in layouts:
+        if output_data_format == format_name:
+            logger.info(log_message)
+            return [convert(item) for item in list(results.values())]
+    raise ValueError(f"Unknown output data format: {output_data_format}")
diff --git a/hf-repo/graphgen/utils/hash.py b/hf-repo/graphgen/utils/hash.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf93ec5fee27bd9f83bba1c7575a5cc8602b57e7
--- /dev/null
+++ b/hf-repo/graphgen/utils/hash.py
@@ -0,0 +1,7 @@
+from hashlib import md5
+
+def compute_args_hash(*args):
+    """Return the MD5 hex digest of ``str()`` applied to the positional-argument tuple."""
+    serialized = str(args)
+    digest = md5(serialized.encode())
+    return digest.hexdigest()
+
+def compute_content_hash(content, prefix: str = ""):
+    """Return ``prefix`` followed by the MD5 hex digest of the encoded ``content`` string."""
+    digest = md5(content.encode()).hexdigest()
+    return prefix + digest
diff --git a/hf-repo/graphgen/utils/help_nltk.py b/hf-repo/graphgen/utils/help_nltk.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d2610ba63c0af563b1b83f22a19bfdacca3957e
--- /dev/null
+++ b/hf-repo/graphgen/utils/help_nltk.py
@@ -0,0 +1,39 @@
+import os
+from typing import Dict, List, Optional
+import nltk
+import jieba
+
+resource_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "resources")
+
+
+class NLTKHelper:
+    """Helper around NLTK stopwords and NLTK/jieba word tokenization.
+
+    NLTK resources are looked up in ``<resource_path>/nltk_data`` and are
+    downloaded into that directory when ``nltk.data.find`` cannot locate them.
+    """
+
+    # Per-language stopword cache stored on the class; None = not loaded yet.
+    _stopwords: Dict[str, Optional[List[str]]] = {
+        "english": None,
+        "chinese": None,
+    }
+
+    def __init__(self):
+        jieba.initialize()
+
+    @staticmethod
+    def _nltk_data_dir() -> str:
+        """Return the bundled nltk_data directory, registering it with NLTK once."""
+        data_dir = os.path.join(resource_path, "nltk_data")
+        # Register only once: appending on every call would make
+        # nltk.data.path grow without bound.
+        if data_dir not in nltk.data.path:
+            nltk.data.path.append(data_dir)
+        return data_dir
+
+    def get_stopwords(self, lang: str) -> List[str]:
+        """Return the NLTK stopword list for ``lang``, loading it on first use.
+
+        ``lang`` must be a key of the cache ("english" or "chinese"); any
+        other value raises KeyError.
+        """
+        data_dir = self._nltk_data_dir()
+        if self._stopwords[lang] is None:
+            try:
+                nltk.data.find("corpora/stopwords")
+            except LookupError:
+                nltk.download("stopwords", download_dir=data_dir)
+
+            self._stopwords[lang] = nltk.corpus.stopwords.words(lang)
+        return self._stopwords[lang]
+
+    @staticmethod
+    def word_tokenize(text: str, lang: str) -> List[str]:
+        """Tokenize ``text`` with jieba when ``lang`` is "zh", otherwise with NLTK."""
+        if lang == "zh":
+            return jieba.lcut(text)
+        data_dir = NLTKHelper._nltk_data_dir()
+        try:
+            nltk.data.find("tokenizers/punkt_tab")
+        except LookupError:
+            nltk.download("punkt_tab", download_dir=data_dir)
+
+        return nltk.word_tokenize(text)
diff --git a/hf-repo/graphgen/utils/log.py b/hf-repo/graphgen/utils/log.py
new file mode 100644
index 0000000000000000000000000000000000000000..32b9bac6b43524c13ba53fa43c641ed87a23ebcc
--- /dev/null
+++ b/hf-repo/graphgen/utils/log.py
@@ -0,0 +1,32 @@
+import logging
+
+logger = logging.getLogger("graphgen")
+
+def set_logger(log_file: str, log_level: int = logging.INFO, if_stream: bool = True):
+    """Configure the shared "graphgen" logger.
+
+    The logger level is set on every call. Handlers are attached only when
+    the logger has none yet, so repeated calls never duplicate output; in
+    that case the handlers built by this call are closed again instead of
+    being left open.
+
+    Args:
+        log_file: Path of the log file; it is opened in write mode, which
+            truncates it, on every call.
+        log_level: Level applied to the logger and to the new handlers.
+        if_stream: When True, also build a ``logging.StreamHandler``.
+    """
+    logger.setLevel(log_level)
+
+    formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    )
+
+    file_handler = logging.FileHandler(log_file, mode='w')
+    file_handler.setLevel(log_level)
+    file_handler.setFormatter(formatter)
+
+    stream_handler = None
+
+    if if_stream:
+        stream_handler = logging.StreamHandler()
+        stream_handler.setLevel(log_level)
+        stream_handler.setFormatter(formatter)
+
+    if not logger.handlers:
+        logger.addHandler(file_handler)
+        if if_stream and stream_handler:
+            logger.addHandler(stream_handler)
+    else:
+        # The logger is already configured: release the handlers built above,
+        # otherwise the file opened by FileHandler would never be closed.
+        file_handler.close()
+        if stream_handler is not None:
+            stream_handler.close()
+
+
+def parse_log(log_file: str):
+    """Return the lines of a UTF-8 log file as a list, line endings included."""
+    with open(log_file, "r", encoding='utf-8') as f:
+        return list(f)
diff --git a/hf-repo/graphgen/utils/loop.py b/hf-repo/graphgen/utils/loop.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f12fa5bb61d01ce1f3fa7e813b22cfb3bd93795
--- /dev/null
+++ b/hf-repo/graphgen/utils/loop.py
@@ -0,0 +1,28 @@
+import asyncio
+
+from .log import logger
+
+
+def create_event_loop() -> asyncio.AbstractEventLoop:
+    """
+    Return a usable asyncio event loop, creating one when necessary.
+
+    The loop from ``asyncio.get_event_loop()`` is returned if that call
+    succeeds and the loop is not closed. Otherwise a new loop is created,
+    installed as the current event loop, and returned.
+
+    Returns:
+        asyncio.AbstractEventLoop: The existing open loop or the new one.
+    """
+    try:
+        existing_loop = asyncio.get_event_loop()
+    except RuntimeError:
+        # get_event_loop() could not provide a loop
+        existing_loop = None
+
+    if existing_loop is not None and not existing_loop.is_closed():
+        return existing_loop
+
+    # Either there was no loop or it has been closed: start a fresh one
+    logger.info("Creating a new event loop in main thread.")
+    fresh_loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(fresh_loop)
+    return fresh_loop
diff --git a/hf-repo/graphgen/version.py b/hf-repo/graphgen/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..73315e64e6b9b14bcd5527b691c7f824dd0c9543
--- /dev/null
+++ b/hf-repo/graphgen/version.py
@@ -0,0 +1,28 @@
+
+from typing import Tuple
+
+__version__ = '20250416'
+short_version = __version__
+
+
+def parse_version_info(version_str: str) -> Tuple:
+    """Split a dotted version string into a tuple of components.
+
+    Args:
+        version_str (str): Version text such as ``'1.2.3'`` or ``'1.2.0rc1'``.
+
+    Returns:
+        tuple: An integer for each purely numeric part; a part containing
+        ``rc`` yields the integer before ``rc`` followed by the string
+        ``'rc<suffix>'``. Parts matching neither form are dropped.
+    """
+    components = []
+    for piece in version_str.split('.'):
+        if piece.isdigit():
+            components.append(int(piece))
+            continue
+        if 'rc' in piece:
+            fields = piece.split('rc')
+            components.extend([int(fields[0]), f'rc{fields[1]}'])
+    return tuple(components)
+
+
+version_info = parse_version_info(__version__)
diff --git a/hf-repo/requirements.txt b/hf-repo/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cf0674d86f18b8944d65abca883ca0467a626e24
--- /dev/null
+++ b/hf-repo/requirements.txt
@@ -0,0 +1,29 @@
+tqdm
+openai
+python-dotenv
+numpy
+networkx
+graspologic
+tiktoken
+pyecharts
+wikipedia
+tenacity
+nltk
+jieba
+plotly
+pandas
+gradio>=5.25.0
+gradio-i18n==0.3.0
+kaleido
+pyyaml
+langcodes
+requests
+fastapi
+trafilatura
+
+leidenalg
+igraph
+python-louvain
+
+# For visualization
+matplotlib
\ No newline at end of file
diff --git a/hf-repo/webui/__init__.py b/hf-repo/webui/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hf-repo/webui/base.py b/hf-repo/webui/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..32f3ed1094c036108e320d720b6d2ff5ff74daa8
--- /dev/null
+++ b/hf-repo/webui/base.py
@@ -0,0 +1,31 @@
+from dataclasses import dataclass
+from typing import Any
+
+@dataclass
+class GraphGenParams:
+    """
+    GraphGen parameters: the settings for one GraphGen run, as read by
+    run_graphgen in webui/app.py.
+    """
+    # When true, run_graphgen also tests the trainee API and runs quiz/judge.
+    if_trainee_model: bool
+    # Path of the input file; run_graphgen also accepts a list and uses item 0.
+    input_file: str
+    # Tokenizer name handed to Tokenizer.
+    tokenizer: str
+    # --- options forwarded under the "traverse_strategy" config key ---
+    qa_form: str
+    bidirectional: bool
+    expand_method: str
+    max_extra_edges: int
+    max_tokens: int
+    max_depth: int
+    edge_sampling: str
+    isolated_node_strategy: str
+    loss_strategy: str
+    # --- model endpoints; api_key is used as the synthesizer API key ---
+    synthesizer_url: str
+    synthesizer_model: str
+    trainee_model: str
+    api_key: str
+    # Chunk length in characters used when splitting .txt input.
+    chunk_size: int
+    # Values passed to the RPM / TPM limiters of the LLM clients.
+    rpm: int
+    tpm: int
+    # Passed as max_samples to the quiz step.
+    quiz_samples: int
+    trainee_url: str
+    trainee_api_key: str
+    # Token statistics table; presumably a pandas DataFrame (run_graphgen reads
+    # .iloc and .columns from it) — confirm against the caller.
+    token_counter: Any
diff --git a/hf-repo/webui/cache_utils.py b/hf-repo/webui/cache_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..96c7d4d9d414d7c2487f0ebf4f75e12729de5cc0
--- /dev/null
+++ b/hf-repo/webui/cache_utils.py
@@ -0,0 +1,21 @@
+import os
+import uuid
+import shutil
+
+def setup_workspace(folder):
+    """Create a per-request workspace under ``folder``.
+
+    A fresh UUID names both the working directory ``<folder>/<uuid>`` and the
+    log file path ``<folder>/logs/<uuid>.log``. The directories are created;
+    the log file itself is not.
+
+    :param folder: base directory of all workspaces
+    :return: tuple ``(log_file, working_dir)``
+    """
+    request_id = str(uuid.uuid4())
+    working_dir = os.path.join(folder, request_id)
+    log_dir = os.path.join(folder, "logs")
+
+    for directory in (folder, working_dir, log_dir):
+        os.makedirs(directory, exist_ok=True)
+
+    return os.path.join(log_dir, f"{request_id}.log"), working_dir
+
+
+def cleanup_workspace(folder):
+    """Recursively delete ``folder``; do nothing when the path does not exist."""
+    if not os.path.exists(folder):
+        return
+    shutil.rmtree(folder)
diff --git a/hf-repo/webui/count_tokens.py b/hf-repo/webui/count_tokens.py
new file mode 100644
index 0000000000000000000000000000000000000000..53bed59a38dbc65f086d04df6dab376e13683412
--- /dev/null
+++ b/hf-repo/webui/count_tokens.py
@@ -0,0 +1,60 @@
+import os
+import sys
+import json
+import pandas as pd
+
+# pylint: disable=wrong-import-position
+root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(root_dir)
+from graphgen.models import Tokenizer
+
+def count_tokens(file, tokenizer_name, data_frame):
+    """Count the tokens of an input file and rebuild the token-stats table.
+
+    Supported inputs: ``.jsonl`` (one JSON value per line, blank lines
+    skipped), ``.json`` (a list of lists, flattened one level) and ``.txt``
+    (split into 512-character chunks). Dict items contribute their "content"
+    value (empty string when absent); other items are encoded as-is.
+
+    :param file: path of the input file; a falsy or non-existent path leaves
+        ``data_frame`` untouched
+    :param tokenizer_name: name passed to ``Tokenizer``
+    :param data_frame: current stats table; only its columns are reused
+    :return: a one-row DataFrame ``[token count, token count * 50, "N/A"]``
+        (the two counts as strings), or the original ``data_frame`` if
+        building the new one fails
+    :raises ValueError: if the file extension is not one of the above
+    """
+    if not file or not os.path.exists(file):
+        return data_frame
+
+    if file.endswith(".jsonl"):
+        with open(file, "r", encoding='utf-8') as f:
+            # Skip blank lines: json.loads on an empty or whitespace-only
+            # line raises JSONDecodeError.
+            data = [json.loads(line) for line in f if line.strip()]
+    elif file.endswith(".json"):
+        with open(file, "r", encoding='utf-8') as f:
+            data = json.load(f)
+        data = [item for sublist in data for item in sublist]
+    elif file.endswith(".txt"):
+        with open(file, "r", encoding='utf-8') as f:
+            data = f.read()
+        chunks = [
+            data[i:i + 512] for i in range(0, len(data), 512)
+        ]
+        data = [{"content": chunk} for chunk in chunks]
+    else:
+        raise ValueError(f"Unsupported file type: {file}")
+
+    tokenizer = Tokenizer(tokenizer_name)
+
+    # Count tokens
+    token_count = 0
+
+    for item in data:
+        if isinstance(item, dict):
+            content = item.get("content", "")
+        else:
+            content = item
+        token_count += len(tokenizer.encode_string(content))
+
+    # Row layout: source token count, token count * 50, and a "N/A"
+    # placeholder in the third column.
+    _update_data = [[
+        str(token_count),
+        str(token_count * 50),
+        "N/A"
+    ]]
+
+    try:
+        new_df = pd.DataFrame(
+            _update_data,
+            columns=data_frame.columns
+        )
+        data_frame = new_df
+
+    except Exception as e:  # pylint: disable=broad-except
+        # Best effort: keep the previous table when the new one cannot be built.
+        print("[ERROR] DataFrame操作异常:", str(e))
+
+    return data_frame
diff --git a/resources/examples/chunked_demo.json b/hf-repo/webui/examples/chunked_demo.json
similarity index 100%
rename from resources/examples/chunked_demo.json
rename to hf-repo/webui/examples/chunked_demo.json
diff --git a/resources/examples/raw_demo.jsonl b/hf-repo/webui/examples/raw_demo.jsonl
similarity index 100%
rename from resources/examples/raw_demo.jsonl
rename to hf-repo/webui/examples/raw_demo.jsonl
diff --git a/resources/examples/txt_demo.txt b/hf-repo/webui/examples/txt_demo.txt
similarity index 100%
rename from resources/examples/txt_demo.txt
rename to hf-repo/webui/examples/txt_demo.txt
diff --git a/hf-repo/webui/test_api.py b/hf-repo/webui/test_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..37727b531fe98bf4b3dff14d453b868a104c4436
--- /dev/null
+++ b/hf-repo/webui/test_api.py
@@ -0,0 +1,16 @@
+from openai import OpenAI
+import gradio as gr
+
+def test_api_connection(api_base, api_key, model_name):
+    """Check an OpenAI-compatible endpoint with a one-token chat completion.
+
+    :param api_base: base URL passed to the OpenAI client
+    :param api_key: API key passed to the OpenAI client
+    :param model_name: model to query; also used as prefix of all messages
+    :raises gr.Error: if the request fails or the response has no choice
+        with a message
+    """
+    client = OpenAI(api_key=api_key, base_url=api_base)
+    try:
+        response = client.chat.completions.create(
+            model=model_name,
+            messages=[{"role": "user", "content": "test"}],
+            max_tokens=1
+        )
+        if not response.choices or not response.choices[0].message:
+            raise gr.Error(f"{model_name}: Invalid response from API")
+        gr.Success(f"{model_name}: API connection successful")
+    except gr.Error:
+        # Already a user-facing error: re-raise it untouched so it is not
+        # wrapped a second time as "API connection failed: ...".
+        raise
+    except Exception as e:
+        raise gr.Error(f"{model_name}: API connection failed: {str(e)}") from e
diff --git a/hf-repo/webui/translation.json b/hf-repo/webui/translation.json
new file mode 100644
index 0000000000000000000000000000000000000000..fef5d57976af98effa50a7e9358210ebdcd80b1e
--- /dev/null
+++ b/hf-repo/webui/translation.json
@@ -0,0 +1,36 @@
+{
+ "en": {
+ "Title": "✨Easy-to-use LLM Training Data Generation Framework✨",
+ "Intro": "is a framework for synthetic data generation guided by knowledge graphs, designed to tackle challenges for knowledge-intensive QA generation. \n\nBy uploading your text chunks (such as knowledge in agriculture, healthcare, or marine science) and filling in the LLM API key, you can generate the training data required by **[LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)** and **[xtuner](https://github.com/InternLM/xtuner)** online. We will automatically delete user information after completion.",
+        "Use Trainee Model": "Use Trainee Model to identify knowledge blind spots; please keep it disabled when using SiliconCloud",
+ "Synthesizer URL Info": "Base URL for the Synthesizer Model API, use SiliconFlow as default",
+ "Trainee URL Info": "Base URL for the Trainee Model API, use SiliconFlow as default",
+ "Synthesizer Model Info": "Model for constructing KGs and generating QAs",
+ "Trainee Model Info": "Model for training",
+ "Model Config": "Model Configuration",
+ "Generation Config": "Generation Config",
+ "SiliconCloud Token": "SiliconCloud API Key",
+ "SiliconCloud Token for Trainee Model": "SiliconCloud API Key for Trainee Model",
+ "Test Connection": "Test Connection",
+ "Run GraphGen": "Run GraphGen",
+ "Upload File": "Upload File",
+ "Example Files": "Example Files"
+ },
+ "zh": {
+ "Title": "✨开箱即用的LLM训练数据生成框架✨",
+ "Intro": "是一个基于知识图谱的数据合成框架,旨在知识密集型任务中生成问答。\n\n 上传你的文本块(如农业、医疗、海洋知识),填写 LLM api key,即可在线生成 **[LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)**、**[xtuner](https://github.com/InternLM/xtuner)** 所需训练数据。结束后我们将自动删除用户信息。",
+ "Use Trainee Model": "使用Trainee Model来识别知识盲区,使用硅基流动时请保持禁用",
+ "Synthesizer URL Info": "调用合成模型API的URL,默认使用硅基流动",
+ "Trainee URL Info": "调用学生模型API的URL,默认使用硅基流动",
+ "Synthesizer Model Info": "用于构建知识图谱和生成问答的模型",
+ "Trainee Model Info": "用于训练的模型",
+ "Model Config": "模型配置",
+ "Generation Config": "生成配置",
+ "SiliconCloud Token": "硅基流动 API Key",
+ "SiliconCloud Token for Trainee Model": "硅基流动 API Key (学生模型)",
+ "Test Connection": "测试接口",
+ "Run GraphGen": "运行GraphGen",
+ "Upload File": "上传文件",
+ "Example Files": "示例文件"
+ }
+}
\ No newline at end of file
diff --git a/resources/nltk_data/corpora/stopwords/chinese b/nltk_data/corpora/stopwords/chinese
similarity index 100%
rename from resources/nltk_data/corpora/stopwords/chinese
rename to nltk_data/corpora/stopwords/chinese
diff --git a/resources/nltk_data/corpora/stopwords/english b/nltk_data/corpora/stopwords/english
similarity index 100%
rename from resources/nltk_data/corpora/stopwords/english
rename to nltk_data/corpora/stopwords/english
diff --git a/resources/nltk_data/tokenizers/punkt_tab/english/abbrev_types.txt b/nltk_data/tokenizers/punkt_tab/english/abbrev_types.txt
similarity index 100%
rename from resources/nltk_data/tokenizers/punkt_tab/english/abbrev_types.txt
rename to nltk_data/tokenizers/punkt_tab/english/abbrev_types.txt
diff --git a/resources/nltk_data/tokenizers/punkt_tab/english/collocations.tab b/nltk_data/tokenizers/punkt_tab/english/collocations.tab
similarity index 100%
rename from resources/nltk_data/tokenizers/punkt_tab/english/collocations.tab
rename to nltk_data/tokenizers/punkt_tab/english/collocations.tab
diff --git a/resources/nltk_data/tokenizers/punkt_tab/english/ortho_context.tab b/nltk_data/tokenizers/punkt_tab/english/ortho_context.tab
similarity index 100%
rename from resources/nltk_data/tokenizers/punkt_tab/english/ortho_context.tab
rename to nltk_data/tokenizers/punkt_tab/english/ortho_context.tab
diff --git a/resources/nltk_data/tokenizers/punkt_tab/english/sent_starters.txt b/nltk_data/tokenizers/punkt_tab/english/sent_starters.txt
similarity index 100%
rename from resources/nltk_data/tokenizers/punkt_tab/english/sent_starters.txt
rename to nltk_data/tokenizers/punkt_tab/english/sent_starters.txt
diff --git a/requirements.txt b/requirements.txt
index ab329cb596c8420378a15c85acb03fe88e4d52fb..cf0674d86f18b8944d65abca883ca0467a626e24 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,3 +17,13 @@ gradio-i18n==0.3.0
kaleido
pyyaml
langcodes
+requests
+fastapi
+trafilatura
+
+leidenalg
+igraph
+python-louvain
+
+# For visualization
+matplotlib
\ No newline at end of file
diff --git a/scripts/baselines/generate_all_baselines.sh b/scripts/baselines/generate_all_baselines.sh
deleted file mode 100644
index 8536978e1d53f9d77a383afb4cbab95820dff32b..0000000000000000000000000000000000000000
--- a/scripts/baselines/generate_all_baselines.sh
+++ /dev/null
@@ -1,7 +0,0 @@
-# generate all baselines at one go
-
-bash scripts/baselines/generate_wrap.sh
-bash scripts/baselines/generate_selfqa.sh
-bash scripts/baselines/generate_longform.sh
-bash scripts/baselines/generate_genie.sh
-bash scripts/baselines/generate_entigraph.sh
\ No newline at end of file
diff --git a/scripts/baselines/generate_entigraph.sh b/scripts/baselines/generate_entigraph.sh
deleted file mode 100644
index ce9cc991bdfbf737abf461af7b75b3d42483806c..0000000000000000000000000000000000000000
--- a/scripts/baselines/generate_entigraph.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-python3 -m baselines.EntiGraph.entigraph --input_file resources/examples/raw_demo.jsonl \
- --data_type raw \
- --output_file cache/data/entigraph.json \
diff --git a/scripts/baselines/generate_genie.sh b/scripts/baselines/generate_genie.sh
deleted file mode 100644
index 0119930dbe99639793589482e4ca07226ad58331..0000000000000000000000000000000000000000
--- a/scripts/baselines/generate_genie.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-python3 -m baselines.Genie.genie --input_file resources/examples/raw_demo.jsonl \
- --data_type raw \
- --output_file cache/data/genie.json \
diff --git a/scripts/baselines/generate_longform.sh b/scripts/baselines/generate_longform.sh
deleted file mode 100644
index d7ed70c515c22e6f0b624eace4eaf5c7320d35f4..0000000000000000000000000000000000000000
--- a/scripts/baselines/generate_longform.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-python3 -m baselines.LongForm.longform --input_file resources/examples/raw_demo.jsonl \
- --data_type raw \
- --output_file cache/data/longform.json \
diff --git a/scripts/baselines/generate_selfqa.sh b/scripts/baselines/generate_selfqa.sh
deleted file mode 100644
index 18eb7b1f9fdff1fd41857ea791218d5cba6ee3b1..0000000000000000000000000000000000000000
--- a/scripts/baselines/generate_selfqa.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-python3 -m baselines.SELF-QA.self-qa --input_file resources/examples/raw_demo.jsonl \
- --data_type raw \
- --output_file cache/data/self-qa.json \
diff --git a/scripts/baselines/generate_wrap.sh b/scripts/baselines/generate_wrap.sh
deleted file mode 100644
index f10857a58cfa8ecc6aed0e8794e8f8ccbb1fcc8b..0000000000000000000000000000000000000000
--- a/scripts/baselines/generate_wrap.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-python3 -m baselines.Wrap.wrap --input_file resources/examples/raw_demo.jsonl \
- --data_type raw \
- --output_file cache/data/wrap.json \
diff --git a/scripts/evaluate.sh b/scripts/evaluate.sh
deleted file mode 100644
index 25706d4866b9ae4c218a0fe2d1195b0d4befd075..0000000000000000000000000000000000000000
--- a/scripts/evaluate.sh
+++ /dev/null
@@ -1,4 +0,0 @@
-python3 -m graphgen.evaluate --folder cache/data \
- --output cache/output \
- --reward "OpenAssistant/reward-model-deberta-v3-large-v2,BAAI/IndustryCorpus2_DataRater" \
- --uni MingZhong/unieval-sum \
diff --git a/scripts/generate.sh b/scripts/generate.sh
deleted file mode 100644
index be0bee9b2342e178c44e6385b17d11c14a0fbdd5..0000000000000000000000000000000000000000
--- a/scripts/generate.sh
+++ /dev/null
@@ -1 +0,0 @@
-python3 -m graphgen.generate --config_file graphgen/configs/graphgen_config.yaml --output_dir cache/
diff --git a/scripts/judge.sh b/scripts/judge.sh
deleted file mode 100644
index f6fc134e4fdc15b3e55a6eb75cbf66e40eabc4d4..0000000000000000000000000000000000000000
--- a/scripts/judge.sh
+++ /dev/null
@@ -1,2 +0,0 @@
-python3 -m graphgen.judge --input cache \
- --output cache/output/new_graph.graphml \
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 3dee7f8be40c05459db9ffaed23f2022f4bbfabd..0000000000000000000000000000000000000000
--- a/setup.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import os
-
-from setuptools import find_packages, setup
-
-pwd = os.path.dirname(__file__)
-version_file = 'graphgen/version.py'
-
-
-def readme():
- with open(os.path.join(pwd, 'README.md'), encoding='utf-8') as f:
- content = f.read()
- return content
-
-
-def get_version():
- with open(os.path.join(pwd, version_file), 'r') as f:
- exec(compile(f.read(), version_file, 'exec'))
- return locals()['__version__']
-
-
-def read_requirements():
- lines = []
- with open('requirements.txt', 'r') as f:
- for line in f.readlines():
- if line.startswith('#'):
- continue
- if 'textract' in line:
- continue
- if len(line) > 0:
- lines.append(line)
- return lines
-
-
-install_packages = read_requirements()
-
-if __name__ == '__main__':
- setup(
- name='graphg',
- version=get_version(),
- url='https://github.com/open-sciencelab/GraphGen',
- description= # noqa E251
- 'GraphGen: Enhancing Supervised Fine-Tuning for LLMs with Knowledge-Driven Synthetic Data Generation', # noqa E501
- long_description=readme(),
- long_description_content_type='text/markdown',
- author='open-sciencelab',
- author_email='open-sciencelab@pjlab.org.cn',
- packages=find_packages(exclude=["models"]),
- package_data={
- 'GraphGen': ['configs/*']
- },
- include_package_data=True,
- install_requires=install_packages,
- classifiers=[
- 'Programming Language :: Python :: 3.8',
- 'Programming Language :: Python :: 3.9',
- 'Programming Language :: Python :: 3.10',
- 'Programming Language :: Python :: 3.11',
- 'Programming Language :: Python :: 3.12',
- 'Intended Audience :: Developers',
- 'Intended Audience :: Education',
- 'Intended Audience :: Science/Research',
- ],
- entry_points={'console_scripts': ['graphgen=graphgen.generate:main']},
- )
diff --git a/webui/app.py b/webui/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..10a914cefb4727756c1d93ec98fb20e3fbf0c081
--- /dev/null
+++ b/webui/app.py
@@ -0,0 +1,586 @@
+import json
+import os
+import sys
+import tempfile
+
+import gradio as gr
+import pandas as pd
+from base import GraphGenParams
+from cache_utils import cleanup_workspace, setup_workspace
+from count_tokens import count_tokens
+from gradio_i18n import Translate
+from gradio_i18n import gettext as _
+from test_api import test_api_connection
+
+# pylint: disable=wrong-import-position
+root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(root_dir)
+
+from graphgen.graphgen import GraphGen
+from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy
+from graphgen.models.llm.limitter import RPM, TPM
+from graphgen.utils import set_logger
+
+css = """
+.center-row {
+ display: flex;
+ justify-content: center;
+ align-items: center;
+}
+"""
+
+
+def init_graph_gen(config: dict, env: dict) -> GraphGen:
+    """Create a GraphGen instance with LLM clients, tokenizer and traverse strategy.
+
+    A new workspace is set up under ``<root_dir>/cache`` and the shared
+    logger is configured with the workspace log file and ``if_stream=False``.
+
+    :param config: run configuration; reads "tokenizer" (default
+        "cl100k_base") and the "traverse_strategy" sub-dict
+    :param env: model settings; reads the SYNTHESIZER_* / TRAINEE_* keys plus
+        "RPM" (default 1000) and "TPM" (default 50000)
+    :return: the configured GraphGen instance
+    """
+    log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
+    set_logger(log_file, if_stream=False)
+    graph_gen = GraphGen(working_dir=working_dir)
+
+    def build_client(model_key, url_key, api_key_key):
+        # Both clients share the same rate-limit settings but get their own
+        # RPM/TPM objects.
+        return OpenAIModel(
+            model_name=env.get(model_key, ""),
+            base_url=env.get(url_key, ""),
+            api_key=env.get(api_key_key, ""),
+            request_limit=True,
+            rpm=RPM(env.get("RPM", 1000)),
+            tpm=TPM(env.get("TPM", 50000)),
+        )
+
+    graph_gen.synthesizer_llm_client = build_client(
+        "SYNTHESIZER_MODEL", "SYNTHESIZER_BASE_URL", "SYNTHESIZER_API_KEY"
+    )
+    graph_gen.trainee_llm_client = build_client(
+        "TRAINEE_MODEL", "TRAINEE_BASE_URL", "TRAINEE_API_KEY"
+    )
+
+    graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
+
+    strategy_config = config.get("traverse_strategy", {})
+    option_names = (
+        "qa_form",
+        "expand_method",
+        "bidirectional",
+        "max_extra_edges",
+        "max_tokens",
+        "max_depth",
+        "edge_sampling",
+        "isolated_node_strategy",
+    )
+    strategy_kwargs = {name: strategy_config.get(name) for name in option_names}
+    # loss_strategy is always passed as a string
+    strategy_kwargs["loss_strategy"] = str(strategy_config.get("loss_strategy"))
+    graph_gen.traverse_strategy = TraverseStrategy(**strategy_kwargs)
+
+    return graph_gen
+
+
+# pylint: disable=too-many-statements
+def run_graphgen(params, progress=gr.Progress()):
+    """Run the GraphGen pipeline for one web-UI request.
+
+    Steps: build the config/env dictionaries from ``params``, test the API
+    connection(s), initialise GraphGen, load the input file, insert it,
+    run quiz + judge (or skip judging when no trainee model is used),
+    traverse the graph and dump the QA data to a temporary file.
+
+    :param params: object exposing the attributes read below (the fields of
+        GraphGenParams)
+    :param progress: Gradio progress object, stored on ``graph_gen.progress_bar``
+    :return: tuple ``(output_file, gr.DataFrame)`` with the path of the
+        temporary output file and the token statistics table
+    :raises gr.Error: for any exception raised inside the pipeline
+    """
+    def sum_tokens(client):
+        # Total of the "total_tokens" entries recorded on the client.
+        return sum(u["total_tokens"] for u in client.token_usage)
+
+    config = {
+        "if_trainee_model": params.if_trainee_model,
+        "input_file": params.input_file,
+        "tokenizer": params.tokenizer,
+        "quiz_samples": params.quiz_samples,
+        "traverse_strategy": {
+            "qa_form": params.qa_form,
+            "bidirectional": params.bidirectional,
+            "expand_method": params.expand_method,
+            "max_extra_edges": params.max_extra_edges,
+            "max_tokens": params.max_tokens,
+            "max_depth": params.max_depth,
+            "edge_sampling": params.edge_sampling,
+            "isolated_node_strategy": params.isolated_node_strategy,
+            "loss_strategy": params.loss_strategy,
+        },
+        "chunk_size": params.chunk_size,
+    }
+
+    env = {
+        "SYNTHESIZER_BASE_URL": params.synthesizer_url,
+        "SYNTHESIZER_MODEL": params.synthesizer_model,
+        "TRAINEE_BASE_URL": params.trainee_url,
+        "TRAINEE_MODEL": params.trainee_model,
+        "SYNTHESIZER_API_KEY": params.api_key,
+        "TRAINEE_API_KEY": params.trainee_api_key,
+        "RPM": params.rpm,
+        "TPM": params.tpm,
+    }
+
+    # Test API connection
+    test_api_connection(
+        env["SYNTHESIZER_BASE_URL"],
+        env["SYNTHESIZER_API_KEY"],
+        env["SYNTHESIZER_MODEL"],
+    )
+    if config["if_trainee_model"]:
+        test_api_connection(
+            env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]
+        )
+
+    # Initialize GraphGen
+    graph_gen = init_graph_gen(config, env)
+    graph_gen.clear()
+
+    graph_gen.progress_bar = progress
+
+    try:
+        # Load input data
+        file = config["input_file"]
+        if isinstance(file, list):
+            # Only the first entry of a list is processed.
+            file = file[0]
+
+        data = []
+
+        if file.endswith(".jsonl"):
+            data_type = "raw"
+            with open(file, "r", encoding="utf-8") as f:
+                data.extend(json.loads(line) for line in f)
+        elif file.endswith(".json"):
+            data_type = "chunked"
+            with open(file, "r", encoding="utf-8") as f:
+                data.extend(json.load(f))
+        elif file.endswith(".txt"):
+            # After reading the file, split it by chunk_size into raw-format data
+            data_type = "raw"
+            content = ""
+            with open(file, "r", encoding="utf-8") as f:
+                lines = f.readlines()
+                for line in lines:
+                    content += line.strip() + " "
+            size = int(config.get("chunk_size", 512))
+            chunks = [content[i : i + size] for i in range(0, len(content), size)]
+            data.extend([{"content": chunk} for chunk in chunks])
+        else:
+            raise ValueError(f"Unsupported file type: {file}")
+
+        # Process the data
+        graph_gen.insert(data, data_type)
+
+        if config["if_trainee_model"]:
+            # Generate quiz
+            graph_gen.quiz(max_samples=config["quiz_samples"])
+
+            # Judge statements
+            graph_gen.judge()
+        else:
+            # Without a trainee model the edge sampling is forced to "random".
+            graph_gen.traverse_strategy.edge_sampling = "random"
+            # Skip judge statements
+            graph_gen.judge(skip=True)
+
+        # Traverse graph
+        graph_gen.traverse(traverse_strategy=graph_gen.traverse_strategy)
+
+        # Save output
+        # NOTE(review): the file has a .jsonl suffix but json.dump writes the
+        # data as a single JSON document — confirm this is intended.
+        output_data = graph_gen.qa_storage.data
+        with tempfile.NamedTemporaryFile(
+            mode="w", suffix=".jsonl", delete=False, encoding="utf-8"
+        ) as tmpfile:
+            json.dump(output_data, tmpfile, ensure_ascii=False)
+            output_file = tmpfile.name
+
+        synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
+        trainee_tokens = (
+            sum_tokens(graph_gen.trainee_llm_client)
+            if config["if_trainee_model"]
+            else 0
+        )
+        total_tokens = synthesizer_tokens + trainee_tokens
+
+        data_frame = params.token_counter
+        try:
+            # Keep the first two cells of the existing row and fill the third
+            # with the number of tokens actually used.
+            _update_data = [
+                [data_frame.iloc[0, 0], data_frame.iloc[0, 1], str(total_tokens)]
+            ]
+            new_df = pd.DataFrame(_update_data, columns=data_frame.columns)
+            data_frame = new_df
+
+        except Exception as e:
+            raise gr.Error(f"DataFrame operation error: {str(e)}")
+
+        return output_file, gr.DataFrame(
+            label="Token Stats",
+            headers=["Source Text Token Count", "Expected Token Usage", "Token Used"],
+            datatype="str",
+            interactive=False,
+            value=data_frame,
+            visible=True,
+            wrap=True,
+        )
+
+    except Exception as e:  # pylint: disable=broad-except
+        raise gr.Error(f"Error occurred: {str(e)}")
+
+    finally:
+        # Clean up workspace
+        cleanup_workspace(graph_gen.working_dir)
+
+
+with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
+ # Header
+ gr.Image(
+ value=os.path.join(root_dir, "resources", "images", "logo.png"),
+ label="GraphGen Banner",
+ elem_id="banner",
+ interactive=False,
+ container=False,
+ show_download_button=False,
+ show_fullscreen_button=False,
+ )
+ lang_btn = gr.Radio(
+ choices=[
+ ("English", "en"),
+ ("简体中文", "zh"),
+ ],
+ value="en",
+ # label=_("Language"),
+ render=False,
+ container=False,
+ elem_classes=["center-row"],
+ )
+
+ gr.HTML(
+ """
+
+ """
+ )
+ with Translate(
+ os.path.join(root_dir, "webui", "translation.json"),
+ lang_btn,
+ placeholder_langs=["en", "zh"],
+ persistant=False, # True to save the language setting in the browser. Requires gradio >= 5.6.0
+ ):
+ lang_btn.render()
+
+ gr.Markdown(
+ value="# "
+ + _("Title")
+ + "\n\n"
+ + "### [GraphGen](https://github.com/open-sciencelab/GraphGen) "
+ + _("Intro")
+ )
+
+ if_trainee_model = gr.Checkbox(
+ label=_("Use Trainee Model"), value=False, interactive=True
+ )
+
+ with gr.Accordion(label=_("Model Config"), open=False):
+ synthesizer_url = gr.Textbox(
+ label="Synthesizer URL",
+ value="https://api.siliconflow.cn/v1",
+ info=_("Synthesizer URL Info"),
+ interactive=True,
+ )
+ synthesizer_model = gr.Textbox(
+ label="Synthesizer Model",
+ value="Qwen/Qwen2.5-7B-Instruct",
+ info=_("Synthesizer Model Info"),
+ interactive=True,
+ )
+ trainee_url = gr.Textbox(
+ label="Trainee URL",
+ value="https://api.siliconflow.cn/v1",
+ info=_("Trainee URL Info"),
+ interactive=True,
+ visible=if_trainee_model.value is True,
+ )
+ trainee_model = gr.Textbox(
+ label="Trainee Model",
+ value="Qwen/Qwen2.5-7B-Instruct",
+ info=_("Trainee Model Info"),
+ interactive=True,
+ visible=if_trainee_model.value is True,
+ )
+ trainee_api_key = gr.Textbox(
+ label=_("SiliconFlow Token for Trainee Model"),
+ type="password",
+ value="",
+ info="https://cloud.siliconflow.cn/account/ak",
+ visible=if_trainee_model.value is True,
+ )
+
+ with gr.Accordion(label=_("Generation Config"), open=False):
+ chunk_size = gr.Slider(
+ label="Chunk Size",
+ minimum=256,
+ maximum=4096,
+ value=512,
+ step=256,
+ interactive=True,
+ )
+ tokenizer = gr.Textbox(
+ label="Tokenizer", value="cl100k_base", interactive=True
+ )
+ qa_form = gr.Radio(
+ choices=["atomic", "multi_hop", "aggregated"],
+ label="QA Form",
+ value="aggregated",
+ interactive=True,
+ )
+ quiz_samples = gr.Number(
+ label="Quiz Samples",
+ value=2,
+ minimum=1,
+ interactive=True,
+ visible=if_trainee_model.value is True,
+ )
+ bidirectional = gr.Checkbox(
+ label="Bidirectional", value=True, interactive=True
+ )
+
+ expand_method = gr.Radio(
+ choices=["max_width", "max_tokens"],
+ label="Expand Method",
+ value="max_tokens",
+ interactive=True,
+ )
+ max_extra_edges = gr.Slider(
+ minimum=1,
+ maximum=10,
+ value=5,
+ label="Max Extra Edges",
+ step=1,
+ interactive=True,
+ visible=expand_method.value == "max_width",
+ )
+ max_tokens = gr.Slider(
+ minimum=64,
+ maximum=1024,
+ value=256,
+ label="Max Tokens",
+ step=64,
+ interactive=True,
+ visible=(expand_method.value != "max_width"),
+ )
+
+ max_depth = gr.Slider(
+ minimum=1,
+ maximum=5,
+ value=2,
+ label="Max Depth",
+ step=1,
+ interactive=True,
+ )
+ edge_sampling = gr.Radio(
+ choices=["max_loss", "min_loss", "random"],
+ label="Edge Sampling",
+ value="max_loss",
+ interactive=True,
+ visible=if_trainee_model.value is True,
+ )
+ isolated_node_strategy = gr.Radio(
+ choices=["add", "ignore"],
+ label="Isolated Node Strategy",
+ value="ignore",
+ interactive=True,
+ )
+ loss_strategy = gr.Radio(
+ choices=["only_edge", "both"],
+ label="Loss Strategy",
+ value="only_edge",
+ interactive=True,
+ )
+
+ with gr.Row(equal_height=True):
+ with gr.Column(scale=3):
+ api_key = gr.Textbox(
+ label=_("SiliconFlow Token"),
+ type="password",
+ value="",
+ info="https://cloud.siliconflow.cn/account/ak",
+ )
+ with gr.Column(scale=1):
+ test_connection_btn = gr.Button(_("Test Connection"))
+
+ with gr.Blocks():
+ with gr.Row(equal_height=True):
+ with gr.Column():
+ rpm = gr.Slider(
+ label="RPM",
+ minimum=10,
+ maximum=10000,
+ value=1000,
+ step=100,
+ interactive=True,
+ visible=True,
+ )
+ with gr.Column():
+ tpm = gr.Slider(
+ label="TPM",
+ minimum=5000,
+ maximum=5000000,
+ value=50000,
+ step=1000,
+ interactive=True,
+ visible=True,
+ )
+
+ with gr.Blocks():
+ with gr.Row(equal_height=True):
+ with gr.Column(scale=1):
+ upload_file = gr.File(
+ label=_("Upload File"),
+ file_count="single",
+ file_types=[".txt", ".json", ".jsonl"],
+ interactive=True,
+ )
+ examples_dir = os.path.join(root_dir, "webui", "examples")
+ gr.Examples(
+ examples=[
+ [os.path.join(examples_dir, "txt_demo.txt")],
+ [os.path.join(examples_dir, "raw_demo.jsonl")],
+ [os.path.join(examples_dir, "chunked_demo.json")],
+ ],
+ inputs=upload_file,
+ label=_("Example Files"),
+ examples_per_page=3,
+ )
+ with gr.Column(scale=1):
+ output = gr.File(
+ label="Output(See Github FAQ)",
+ file_count="single",
+ interactive=False,
+ )
+
+ with gr.Blocks():
+ token_counter = gr.DataFrame(
+ label="Token Stats",
+ headers=[
+ "Source Text Token Count",
+ "Estimated Token Usage",
+ "Token Used",
+ ],
+ datatype="str",
+ interactive=False,
+ visible=False,
+ wrap=True,
+ )
+
+ submit_btn = gr.Button(_("Run GraphGen"))
+
+ # Test Connection
+ test_connection_btn.click(
+ test_api_connection,
+ inputs=[synthesizer_url, api_key, synthesizer_model],
+ outputs=[],
+ )
+
+ if if_trainee_model.value:
+ test_connection_btn.click(
+ test_api_connection,
+ inputs=[trainee_url, api_key, trainee_model],
+ outputs=[],
+ )
+
+ expand_method.change(
+ lambda method: (
+ gr.update(visible=method == "max_width"),
+ gr.update(visible=method != "max_width"),
+ ),
+ inputs=expand_method,
+ outputs=[max_extra_edges, max_tokens],
+ )
+
+ if_trainee_model.change(
+ lambda use_trainee: [gr.update(visible=use_trainee)] * 5,
+ inputs=if_trainee_model,
+ outputs=[
+ trainee_url,
+ trainee_model,
+ quiz_samples,
+ edge_sampling,
+ trainee_api_key,
+ ],
+ )
+
+ upload_file.change(
+ lambda x: (gr.update(visible=True)),
+ inputs=[upload_file],
+ outputs=[token_counter],
+ ).then(
+ count_tokens,
+ inputs=[upload_file, tokenizer, token_counter],
+ outputs=[token_counter],
+ )
+
+ # run GraphGen
+ submit_btn.click(
+ lambda x: (gr.update(visible=False)),
+ inputs=[token_counter],
+ outputs=[token_counter],
+ )
+
+ submit_btn.click(
+ lambda *args: run_graphgen(
+ GraphGenParams(
+ if_trainee_model=args[0],
+ input_file=args[1],
+ tokenizer=args[2],
+ qa_form=args[3],
+ bidirectional=args[4],
+ expand_method=args[5],
+ max_extra_edges=args[6],
+ max_tokens=args[7],
+ max_depth=args[8],
+ edge_sampling=args[9],
+ isolated_node_strategy=args[10],
+ loss_strategy=args[11],
+ synthesizer_url=args[12],
+ synthesizer_model=args[13],
+ trainee_model=args[14],
+ api_key=args[15],
+ chunk_size=args[16],
+ rpm=args[17],
+ tpm=args[18],
+ quiz_samples=args[19],
+ trainee_url=args[20],
+ trainee_api_key=args[21],
+ token_counter=args[22],
+ )
+ ),
+ inputs=[
+ if_trainee_model,
+ upload_file,
+ tokenizer,
+ qa_form,
+ bidirectional,
+ expand_method,
+ max_extra_edges,
+ max_tokens,
+ max_depth,
+ edge_sampling,
+ isolated_node_strategy,
+ loss_strategy,
+ synthesizer_url,
+ synthesizer_model,
+ trainee_model,
+ api_key,
+ chunk_size,
+ rpm,
+ tpm,
+ quiz_samples,
+ trainee_url,
+ trainee_api_key,
+ token_counter,
+ ],
+ outputs=[output, token_counter],
+ )
+
+if __name__ == "__main__":  # start the demo only when this file is run as a script
+ demo.queue(api_open=False, default_concurrency_limit=2)  # configure the event queue before launching
+ demo.launch(server_name="0.0.0.0")  # NOTE(review): presumably listens on all interfaces - confirm intended
diff --git a/webui/translation.json b/webui/translation.json
index fef5d57976af98effa50a7e9358210ebdcd80b1e..14f420d0a191d02b4438ac623d06e263f9fa2b9e 100644
--- a/webui/translation.json
+++ b/webui/translation.json
@@ -1,36 +1,38 @@
{
"en": {
"Title": "✨Easy-to-use LLM Training Data Generation Framework✨",
+ "### [GraphGen](https://github.com/open-sciencelab/GraphGen) ": "### [GraphGen](https://github.com/open-sciencelab/GraphGen) ",
"Intro": "is a framework for synthetic data generation guided by knowledge graphs, designed to tackle challenges for knowledge-intensive QA generation. \n\nBy uploading your text chunks (such as knowledge in agriculture, healthcare, or marine science) and filling in the LLM API key, you can generate the training data required by **[LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)** and **[xtuner](https://github.com/InternLM/xtuner)** online. We will automatically delete user information after completion.",
"Use Trainee Model": "Use Trainee Model to identify knowledge blind spots, please keep disable for SiliconCloud",
"Synthesizer URL Info": "Base URL for the Synthesizer Model API, use SiliconFlow as default",
- "Trainee URL Info": "Base URL for the Trainee Model API, use SiliconFlow as default",
"Synthesizer Model Info": "Model for constructing KGs and generating QAs",
+ "Trainee URL Info": "Base URL for the Trainee Model API, use SiliconFlow as default",
"Trainee Model Info": "Model for training",
+ "SiliconFlow Token for Trainee Model": "SiliconFlow API Key for Trainee Model",
"Model Config": "Model Configuration",
"Generation Config": "Generation Config",
- "SiliconCloud Token": "SiliconCloud API Key",
- "SiliconCloud Token for Trainee Model": "SiliconCloud API Key for Trainee Model",
+ "SiliconFlow Token": "SiliconFlow API Key",
"Test Connection": "Test Connection",
- "Run GraphGen": "Run GraphGen",
"Upload File": "Upload File",
- "Example Files": "Example Files"
+ "Example Files": "Example Files",
+ "Run GraphGen": "Run GraphGen"
},
"zh": {
"Title": "✨开箱即用的LLM训练数据生成框架✨",
+ "### [GraphGen](https://github.com/open-sciencelab/GraphGen) ": "### [GraphGen](https://github.com/open-sciencelab/GraphGen) ",
"Intro": "是一个基于知识图谱的数据合成框架,旨在知识密集型任务中生成问答。\n\n 上传你的文本块(如农业、医疗、海洋知识),填写 LLM api key,即可在线生成 **[LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)**、**[xtuner](https://github.com/InternLM/xtuner)** 所需训练数据。结束后我们将自动删除用户信息。",
"Use Trainee Model": "使用Trainee Model来识别知识盲区,使用硅基流动时请保持禁用",
"Synthesizer URL Info": "调用合成模型API的URL,默认使用硅基流动",
- "Trainee URL Info": "调用学生模型API的URL,默认使用硅基流动",
"Synthesizer Model Info": "用于构建知识图谱和生成问答的模型",
+ "Trainee URL Info": "调用学生模型API的URL,默认使用硅基流动",
"Trainee Model Info": "用于训练的模型",
+ "SiliconFlow Token for Trainee Model": "硅基流动 API Key (学生模型)",
"Model Config": "模型配置",
"Generation Config": "生成配置",
- "SiliconCloud Token": "硅基流动 API Key",
- "SiliconCloud Token for Trainee Model": "硅基流动 API Key (学生模型)",
+ "SiliconFlow Token": "硅基流动 API Key",
"Test Connection": "测试接口",
- "Run GraphGen": "运行GraphGen",
"Upload File": "上传文件",
- "Example Files": "示例文件"
+ "Example Files": "示例文件",
+ "Run GraphGen": "运行GraphGen"
}
}
\ No newline at end of file