Spaces:
Runtime error
Runtime error
github-actions[bot]
commited on
Commit
·
fb9c306
1
Parent(s):
e453a65
Auto-sync from demo at Thu Aug 28 09:22:58 UTC 2025
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .env.example +0 -6
- .gitattributes +0 -35
- .gitignore +0 -179
- README.md +0 -14
- app.py +281 -214
- graphgen/configs/README.md +1 -0
- graphgen/configs/aggregated_config.yaml +21 -0
- graphgen/configs/atomic_config.yaml +21 -0
- graphgen/configs/config.yaml.example +0 -16
- graphgen/configs/cot_config.yaml +13 -0
- graphgen/configs/graphgen_config.yaml +0 -16
- graphgen/configs/multi_hop_config.yaml +21 -0
- graphgen/generate.py +64 -62
- graphgen/graphgen.py +232 -97
- graphgen/models/__init__.py +18 -14
- graphgen/models/community/__init__.py +0 -0
- graphgen/models/community/community_detector.py +95 -0
- graphgen/models/llm/openai_model.py +58 -33
- graphgen/models/search/db/__init__.py +0 -0
- graphgen/models/search/db/uniprot_search.py +64 -0
- graphgen/models/search/kg/__init__.py +0 -0
- graphgen/models/search/{wiki_search.py → kg/wiki_search.py} +4 -3
- graphgen/models/search/web/__init__.py +0 -0
- graphgen/models/search/web/bing_search.py +43 -0
- graphgen/models/search/web/google_search.py +45 -0
- graphgen/models/storage/base_storage.py +25 -4
- graphgen/models/storage/json_storage.py +39 -3
- graphgen/models/vis/__init__.py +0 -0
- graphgen/models/vis/community_visualizer.py +48 -0
- graphgen/operators/__init__.py +13 -7
- graphgen/operators/generate/__init__.py +0 -0
- graphgen/operators/generate/generate_cot.py +117 -0
- graphgen/operators/judge.py +48 -87
- graphgen/operators/kg/__init__.py +0 -0
- graphgen/operators/{extract_kg.py → kg/extract_kg.py} +48 -29
- graphgen/operators/{merge_kg.py → kg/merge_kg.py} +38 -41
- graphgen/operators/{split_graph.py → kg/split_kg.py} +92 -44
- graphgen/operators/preprocess/__init__.py +0 -0
- graphgen/operators/{resolute_coreference.py → preprocess/resolute_coreference.py} +8 -8
- graphgen/operators/search/__init__.py +0 -0
- graphgen/operators/search/db/__init__.py +0 -0
- graphgen/operators/search/db/search_uniprot.py +0 -0
- graphgen/operators/search/kg/__init__.py +0 -0
- graphgen/operators/search/kg/search_wikipedia.py +58 -0
- graphgen/operators/search/search_all.py +82 -0
- graphgen/operators/search/web/__init__.py +0 -0
- graphgen/operators/search/web/search_bing.py +53 -0
- graphgen/operators/search/web/search_google.py +49 -0
- graphgen/operators/search_wikipedia.py +0 -71
- graphgen/operators/traverse_graph.py +199 -148
.env.example
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
SYNTHESIZER_MODEL=
|
2 |
-
SYNTHESIZER_BASE_URL=
|
3 |
-
SYNTHESIZER_API_KEY=
|
4 |
-
TRAINEE_MODEL=
|
5 |
-
TRAINEE_BASE_URL=
|
6 |
-
TRAINEE_API_KEY=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitattributes
DELETED
@@ -1,35 +0,0 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
DELETED
@@ -1,179 +0,0 @@
|
|
1 |
-
# Byte-compiled / optimized / DLL files
|
2 |
-
__pycache__/
|
3 |
-
*.py[cod]
|
4 |
-
*$py.class
|
5 |
-
|
6 |
-
# C extensions
|
7 |
-
*.so
|
8 |
-
|
9 |
-
# Distribution / packaging
|
10 |
-
.Python
|
11 |
-
build/
|
12 |
-
develop-eggs/
|
13 |
-
dist/
|
14 |
-
downloads/
|
15 |
-
eggs/
|
16 |
-
.eggs/
|
17 |
-
lib/
|
18 |
-
lib64/
|
19 |
-
parts/
|
20 |
-
sdist/
|
21 |
-
var/
|
22 |
-
wheels/
|
23 |
-
share/python-wheels/
|
24 |
-
*.egg-info/
|
25 |
-
.installed.cfg
|
26 |
-
*.egg
|
27 |
-
MANIFEST
|
28 |
-
|
29 |
-
# PyInstaller
|
30 |
-
# Usually these files are written by a python script from a template
|
31 |
-
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
-
*.manifest
|
33 |
-
*.spec
|
34 |
-
|
35 |
-
# Installer logs
|
36 |
-
pip-log.txt
|
37 |
-
pip-delete-this-directory.txt
|
38 |
-
|
39 |
-
# Unit test / coverage reports
|
40 |
-
htmlcov/
|
41 |
-
.tox/
|
42 |
-
.nox/
|
43 |
-
.coverage
|
44 |
-
.coverage.*
|
45 |
-
.cache
|
46 |
-
nosetests.xml
|
47 |
-
coverage.xml
|
48 |
-
*.cover
|
49 |
-
*.py,cover
|
50 |
-
.hypothesis/
|
51 |
-
.pytest_cache/
|
52 |
-
cover/
|
53 |
-
|
54 |
-
# Translations
|
55 |
-
*.mo
|
56 |
-
*.pot
|
57 |
-
|
58 |
-
# Django stuff:
|
59 |
-
*.log
|
60 |
-
local_settings.py
|
61 |
-
db.sqlite3
|
62 |
-
db.sqlite3-journal
|
63 |
-
|
64 |
-
# Flask stuff:
|
65 |
-
instance/
|
66 |
-
.webassets-cache
|
67 |
-
|
68 |
-
# Scrapy stuff:
|
69 |
-
.scrapy
|
70 |
-
|
71 |
-
# Sphinx documentation
|
72 |
-
docs/_build/
|
73 |
-
|
74 |
-
# PyBuilder
|
75 |
-
.pybuilder/
|
76 |
-
target/
|
77 |
-
|
78 |
-
# Jupyter Notebook
|
79 |
-
.ipynb_checkpoints
|
80 |
-
|
81 |
-
# IPython
|
82 |
-
profile_default/
|
83 |
-
ipython_config.py
|
84 |
-
|
85 |
-
# pyenv
|
86 |
-
# For a library or package, you might want to ignore these files since the code is
|
87 |
-
# intended to run in multiple environments; otherwise, check them in:
|
88 |
-
# .python-version
|
89 |
-
|
90 |
-
# pipenv
|
91 |
-
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
-
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
-
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
-
# install all needed dependencies.
|
95 |
-
#Pipfile.lock
|
96 |
-
|
97 |
-
# UV
|
98 |
-
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
99 |
-
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
-
# commonly ignored for libraries.
|
101 |
-
#uv.lock
|
102 |
-
|
103 |
-
# poetry
|
104 |
-
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
105 |
-
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
106 |
-
# commonly ignored for libraries.
|
107 |
-
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
108 |
-
#poetry.lock
|
109 |
-
|
110 |
-
# pdm
|
111 |
-
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
112 |
-
#pdm.lock
|
113 |
-
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
114 |
-
# in version control.
|
115 |
-
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
116 |
-
.pdm.toml
|
117 |
-
.pdm-python
|
118 |
-
.pdm-build/
|
119 |
-
|
120 |
-
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
121 |
-
__pypackages__/
|
122 |
-
|
123 |
-
# Celery stuff
|
124 |
-
celerybeat-schedule
|
125 |
-
celerybeat.pid
|
126 |
-
|
127 |
-
# SageMath parsed files
|
128 |
-
*.sage.py
|
129 |
-
|
130 |
-
# Environments
|
131 |
-
.env
|
132 |
-
.venv
|
133 |
-
env/
|
134 |
-
venv/
|
135 |
-
ENV/
|
136 |
-
env.bak/
|
137 |
-
venv.bak/
|
138 |
-
|
139 |
-
# Spyder project settings
|
140 |
-
.spyderproject
|
141 |
-
.spyproject
|
142 |
-
|
143 |
-
# Rope project settings
|
144 |
-
.ropeproject
|
145 |
-
|
146 |
-
# mkdocs documentation
|
147 |
-
/site
|
148 |
-
|
149 |
-
# mypy
|
150 |
-
.mypy_cache/
|
151 |
-
.dmypy.json
|
152 |
-
dmypy.json
|
153 |
-
|
154 |
-
# Pyre type checker
|
155 |
-
.pyre/
|
156 |
-
|
157 |
-
# pytype static type analyzer
|
158 |
-
.pytype/
|
159 |
-
|
160 |
-
# Cython debug symbols
|
161 |
-
cython_debug/
|
162 |
-
|
163 |
-
# PyCharm
|
164 |
-
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
165 |
-
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
166 |
-
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
167 |
-
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
168 |
-
.idea/
|
169 |
-
|
170 |
-
# Ruff stuff:
|
171 |
-
.ruff_cache/
|
172 |
-
|
173 |
-
# PyPI configuration file
|
174 |
-
.pypirc
|
175 |
-
|
176 |
-
cache
|
177 |
-
*.pyc
|
178 |
-
*.html
|
179 |
-
.gradio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
DELETED
@@ -1,14 +0,0 @@
|
|
1 |
-
---
|
2 |
-
title: GraphGen
|
3 |
-
emoji: 🐠
|
4 |
-
colorFrom: gray
|
5 |
-
colorTo: blue
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 5.32.1
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: apache-2.0
|
11 |
-
short_description: A framework for synthetic data generation based on KG.
|
12 |
-
---
|
13 |
-
|
14 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -1,20 +1,19 @@
|
|
|
|
1 |
import os
|
2 |
import sys
|
3 |
-
import json
|
4 |
import tempfile
|
5 |
|
6 |
-
import pandas as pd
|
7 |
import gradio as gr
|
8 |
-
|
9 |
-
from
|
10 |
-
|
11 |
-
from
|
12 |
-
from
|
13 |
-
from
|
14 |
-
from
|
15 |
|
16 |
# pylint: disable=wrong-import-position
|
17 |
-
root_dir = os.path.dirname(os.path.abspath(__file__))
|
18 |
sys.path.append(root_dir)
|
19 |
|
20 |
from graphgen.graphgen import GraphGen
|
@@ -22,7 +21,6 @@ from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy
|
|
22 |
from graphgen.models.llm.limitter import RPM, TPM
|
23 |
from graphgen.utils import set_logger
|
24 |
|
25 |
-
|
26 |
css = """
|
27 |
.center-row {
|
28 |
display: flex;
|
@@ -37,9 +35,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
|
|
37 |
log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
|
38 |
|
39 |
set_logger(log_file, if_stream=False)
|
40 |
-
graph_gen = GraphGen(
|
41 |
-
working_dir=working_dir
|
42 |
-
)
|
43 |
|
44 |
# Set up LLM clients
|
45 |
graph_gen.synthesizer_llm_client = OpenAIModel(
|
@@ -47,8 +43,8 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
|
|
47 |
base_url=env.get("SYNTHESIZER_BASE_URL", ""),
|
48 |
api_key=env.get("SYNTHESIZER_API_KEY", ""),
|
49 |
request_limit=True,
|
50 |
-
rpm=
|
51 |
-
tpm=
|
52 |
)
|
53 |
|
54 |
graph_gen.trainee_llm_client = OpenAIModel(
|
@@ -56,16 +52,15 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
|
|
56 |
base_url=env.get("TRAINEE_BASE_URL", ""),
|
57 |
api_key=env.get("TRAINEE_API_KEY", ""),
|
58 |
request_limit=True,
|
59 |
-
rpm=
|
60 |
-
tpm=
|
61 |
)
|
62 |
|
63 |
-
graph_gen.tokenizer_instance = Tokenizer(
|
64 |
-
config.get("tokenizer", "cl100k_base"))
|
65 |
|
66 |
strategy_config = config.get("traverse_strategy", {})
|
67 |
graph_gen.traverse_strategy = TraverseStrategy(
|
68 |
-
qa_form=
|
69 |
expand_method=strategy_config.get("expand_method"),
|
70 |
bidirectional=strategy_config.get("bidirectional"),
|
71 |
max_extra_edges=strategy_config.get("max_extra_edges"),
|
@@ -73,11 +68,12 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
|
|
73 |
max_depth=strategy_config.get("max_depth"),
|
74 |
edge_sampling=strategy_config.get("edge_sampling"),
|
75 |
isolated_node_strategy=strategy_config.get("isolated_node_strategy"),
|
76 |
-
loss_strategy=str(strategy_config.get("loss_strategy"))
|
77 |
)
|
78 |
|
79 |
return graph_gen
|
80 |
|
|
|
81 |
# pylint: disable=too-many-statements
|
82 |
def run_graphgen(params, progress=gr.Progress()):
|
83 |
def sum_tokens(client):
|
@@ -87,10 +83,9 @@ def run_graphgen(params, progress=gr.Progress()):
|
|
87 |
"if_trainee_model": params.if_trainee_model,
|
88 |
"input_file": params.input_file,
|
89 |
"tokenizer": params.tokenizer,
|
90 |
-
"qa_form": params.qa_form,
|
91 |
-
"web_search": False,
|
92 |
"quiz_samples": params.quiz_samples,
|
93 |
"traverse_strategy": {
|
|
|
94 |
"bidirectional": params.bidirectional,
|
95 |
"expand_method": params.expand_method,
|
96 |
"max_extra_edges": params.max_extra_edges,
|
@@ -98,7 +93,7 @@ def run_graphgen(params, progress=gr.Progress()):
|
|
98 |
"max_depth": params.max_depth,
|
99 |
"edge_sampling": params.edge_sampling,
|
100 |
"isolated_node_strategy": params.isolated_node_strategy,
|
101 |
-
"loss_strategy": params.loss_strategy
|
102 |
},
|
103 |
"chunk_size": params.chunk_size,
|
104 |
}
|
@@ -115,11 +110,15 @@ def run_graphgen(params, progress=gr.Progress()):
|
|
115 |
}
|
116 |
|
117 |
# Test API connection
|
118 |
-
test_api_connection(
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
|
|
123 |
|
124 |
# Initialize GraphGen
|
125 |
graph_gen = init_graph_gen(config, env)
|
@@ -129,7 +128,7 @@ def run_graphgen(params, progress=gr.Progress()):
|
|
129 |
|
130 |
try:
|
131 |
# Load input data
|
132 |
-
file = config[
|
133 |
if isinstance(file, list):
|
134 |
file = file[0]
|
135 |
|
@@ -137,24 +136,22 @@ def run_graphgen(params, progress=gr.Progress()):
|
|
137 |
|
138 |
if file.endswith(".jsonl"):
|
139 |
data_type = "raw"
|
140 |
-
with open(file, "r", encoding=
|
141 |
data.extend(json.loads(line) for line in f)
|
142 |
elif file.endswith(".json"):
|
143 |
data_type = "chunked"
|
144 |
-
with open(file, "r", encoding=
|
145 |
data.extend(json.load(f))
|
146 |
elif file.endswith(".txt"):
|
147 |
# 读取文件后根据chunk_size转成raw格式的数据
|
148 |
data_type = "raw"
|
149 |
content = ""
|
150 |
-
with open(file, "r", encoding=
|
151 |
lines = f.readlines()
|
152 |
for line in lines:
|
153 |
content += line.strip() + " "
|
154 |
size = int(config.get("chunk_size", 512))
|
155 |
-
chunks = [
|
156 |
-
content[i:i + size] for i in range(0, len(content), size)
|
157 |
-
]
|
158 |
data.extend([{"content": chunk} for chunk in chunks])
|
159 |
else:
|
160 |
raise ValueError(f"Unsupported file type: {file}")
|
@@ -162,9 +159,9 @@ def run_graphgen(params, progress=gr.Progress()):
|
|
162 |
# Process the data
|
163 |
graph_gen.insert(data, data_type)
|
164 |
|
165 |
-
if config[
|
166 |
# Generate quiz
|
167 |
-
graph_gen.quiz(max_samples=config[
|
168 |
|
169 |
# Judge statements
|
170 |
graph_gen.judge()
|
@@ -174,47 +171,44 @@ def run_graphgen(params, progress=gr.Progress()):
|
|
174 |
graph_gen.judge(skip=True)
|
175 |
|
176 |
# Traverse graph
|
177 |
-
graph_gen.traverse()
|
178 |
|
179 |
# Save output
|
180 |
output_data = graph_gen.qa_storage.data
|
181 |
with tempfile.NamedTemporaryFile(
|
182 |
-
|
183 |
-
|
184 |
-
delete=False,
|
185 |
-
encoding="utf-8") as tmpfile:
|
186 |
json.dump(output_data, tmpfile, ensure_ascii=False)
|
187 |
output_file = tmpfile.name
|
188 |
|
189 |
synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
|
190 |
-
trainee_tokens =
|
|
|
|
|
|
|
|
|
191 |
total_tokens = synthesizer_tokens + trainee_tokens
|
192 |
|
193 |
data_frame = params.token_counter
|
194 |
try:
|
195 |
_update_data = [
|
196 |
-
[
|
197 |
-
data_frame.iloc[0, 0],
|
198 |
-
data_frame.iloc[0, 1],
|
199 |
-
str(total_tokens)
|
200 |
-
]
|
201 |
]
|
202 |
-
new_df = pd.DataFrame(
|
203 |
-
_update_data,
|
204 |
-
columns=data_frame.columns
|
205 |
-
)
|
206 |
data_frame = new_df
|
207 |
|
208 |
except Exception as e:
|
209 |
raise gr.Error(f"DataFrame operation error: {str(e)}")
|
210 |
|
211 |
-
return output_file, gr.DataFrame(
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
|
|
|
|
218 |
|
219 |
except Exception as e: # pylint: disable=broad-except
|
220 |
raise gr.Error(f"Error occurred: {str(e)}")
|
@@ -223,16 +217,18 @@ def run_graphgen(params, progress=gr.Progress()):
|
|
223 |
# Clean up workspace
|
224 |
cleanup_workspace(graph_gen.working_dir)
|
225 |
|
226 |
-
|
227 |
-
|
228 |
# Header
|
229 |
-
gr.Image(
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
|
|
|
|
236 |
lang_btn = gr.Radio(
|
237 |
choices=[
|
238 |
("English", "en"),
|
@@ -245,7 +241,8 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
|
|
245 |
elem_classes=["center-row"],
|
246 |
)
|
247 |
|
248 |
-
gr.HTML(
|
|
|
249 |
<div style="display: flex; gap: 8px; margin-left: auto; align-items: center; justify-content: center;">
|
250 |
<a href="https://github.com/open-sciencelab/GraphGen/releases">
|
251 |
<img src="https://img.shields.io/badge/Version-v0.1.0-blue" alt="Version">
|
@@ -260,80 +257,98 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
|
|
260 |
<img src="https://img.shields.io/badge/arXiv-pdf-yellow" alt="arXiv">
|
261 |
</a>
|
262 |
</div>
|
263 |
-
"""
|
|
|
264 |
with Translate(
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
False, # True to save the language setting in the browser. Requires gradio >= 5.6.0
|
270 |
):
|
271 |
lang_btn.render()
|
272 |
|
273 |
gr.Markdown(
|
274 |
-
value
|
275 |
-
|
|
|
|
|
|
|
276 |
)
|
277 |
|
278 |
-
if_trainee_model = gr.Checkbox(
|
279 |
-
|
280 |
-
|
281 |
|
282 |
with gr.Accordion(label=_("Model Config"), open=False):
|
283 |
-
synthesizer_url = gr.Textbox(
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
trainee_model = gr.Textbox(
|
297 |
label="Trainee Model",
|
298 |
value="Qwen/Qwen2.5-7B-Instruct",
|
299 |
info=_("Trainee Model Info"),
|
300 |
interactive=True,
|
301 |
-
visible=if_trainee_model.value is True
|
|
|
302 |
trainee_api_key = gr.Textbox(
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
|
310 |
with gr.Accordion(label=_("Generation Config"), open=False):
|
311 |
-
chunk_size = gr.Slider(
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
337 |
max_extra_edges = gr.Slider(
|
338 |
minimum=1,
|
339 |
maximum=10,
|
@@ -341,44 +356,54 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
|
|
341 |
label="Max Extra Edges",
|
342 |
step=1,
|
343 |
interactive=True,
|
344 |
-
visible=expand_method.value == "max_width"
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
|
|
|
|
|
|
|
|
360 |
edge_sampling = gr.Radio(
|
361 |
choices=["max_loss", "min_loss", "random"],
|
362 |
label="Edge Sampling",
|
363 |
value="max_loss",
|
364 |
interactive=True,
|
365 |
-
visible=if_trainee_model.value is True
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
|
|
|
|
|
|
|
|
|
|
374 |
|
375 |
with gr.Row(equal_height=True):
|
376 |
with gr.Column(scale=3):
|
377 |
api_key = gr.Textbox(
|
378 |
-
label=_("
|
379 |
type="password",
|
380 |
value="",
|
381 |
-
info="https://cloud.siliconflow.cn/account/ak"
|
|
|
382 |
with gr.Column(scale=1):
|
383 |
test_connection_btn = gr.Button(_("Test Connection"))
|
384 |
|
@@ -392,7 +417,8 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
|
|
392 |
value=1000,
|
393 |
step=100,
|
394 |
interactive=True,
|
395 |
-
visible=True
|
|
|
396 |
with gr.Column():
|
397 |
tpm = gr.Slider(
|
398 |
label="TPM",
|
@@ -401,8 +427,8 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
|
|
401 |
value=50000,
|
402 |
step=1000,
|
403 |
interactive=True,
|
404 |
-
visible=True
|
405 |
-
|
406 |
|
407 |
with gr.Blocks():
|
408 |
with gr.Row(equal_height=True):
|
@@ -413,15 +439,17 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
|
|
413 |
file_types=[".txt", ".json", ".jsonl"],
|
414 |
interactive=True,
|
415 |
)
|
416 |
-
examples_dir = os.path.join(root_dir,
|
417 |
-
gr.Examples(
|
418 |
-
[
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
|
|
|
|
425 |
with gr.Column(scale=1):
|
426 |
output = gr.File(
|
427 |
label="Output(See Github FAQ)",
|
@@ -430,12 +458,18 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
|
|
430 |
)
|
431 |
|
432 |
with gr.Blocks():
|
433 |
-
token_counter = gr.DataFrame(
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
439 |
|
440 |
submit_btn = gr.Button(_("Run GraphGen"))
|
441 |
|
@@ -443,23 +477,36 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
|
|
443 |
test_connection_btn.click(
|
444 |
test_api_connection,
|
445 |
inputs=[synthesizer_url, api_key, synthesizer_model],
|
446 |
-
outputs=[]
|
|
|
447 |
|
448 |
if if_trainee_model.value:
|
449 |
-
test_connection_btn.click(
|
450 |
-
|
451 |
-
|
|
|
|
|
452 |
|
453 |
-
expand_method.change(
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
|
|
|
|
|
|
458 |
|
459 |
if_trainee_model.change(
|
460 |
lambda use_trainee: [gr.update(visible=use_trainee)] * 5,
|
461 |
inputs=if_trainee_model,
|
462 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
463 |
|
464 |
upload_file.change(
|
465 |
lambda x: (gr.update(visible=True)),
|
@@ -479,41 +526,61 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
|
|
479 |
)
|
480 |
|
481 |
submit_btn.click(
|
482 |
-
lambda *args: run_graphgen(
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
506 |
-
|
|
|
|
|
507 |
inputs=[
|
508 |
-
if_trainee_model,
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
513 |
],
|
514 |
outputs=[output, token_counter],
|
515 |
)
|
516 |
|
517 |
if __name__ == "__main__":
|
518 |
demo.queue(api_open=False, default_concurrency_limit=2)
|
519 |
-
demo.launch(server_name=
|
|
|
1 |
+
import json
|
2 |
import os
|
3 |
import sys
|
|
|
4 |
import tempfile
|
5 |
|
|
|
6 |
import gradio as gr
|
7 |
+
import pandas as pd
|
8 |
+
from base import GraphGenParams
|
9 |
+
from cache_utils import cleanup_workspace, setup_workspace
|
10 |
+
from count_tokens import count_tokens
|
11 |
+
from gradio_i18n import Translate
|
12 |
+
from gradio_i18n import gettext as _
|
13 |
+
from test_api import test_api_connection
|
14 |
|
15 |
# pylint: disable=wrong-import-position
|
16 |
+
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
17 |
sys.path.append(root_dir)
|
18 |
|
19 |
from graphgen.graphgen import GraphGen
|
|
|
21 |
from graphgen.models.llm.limitter import RPM, TPM
|
22 |
from graphgen.utils import set_logger
|
23 |
|
|
|
24 |
css = """
|
25 |
.center-row {
|
26 |
display: flex;
|
|
|
35 |
log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
|
36 |
|
37 |
set_logger(log_file, if_stream=False)
|
38 |
+
graph_gen = GraphGen(working_dir=working_dir)
|
|
|
|
|
39 |
|
40 |
# Set up LLM clients
|
41 |
graph_gen.synthesizer_llm_client = OpenAIModel(
|
|
|
43 |
base_url=env.get("SYNTHESIZER_BASE_URL", ""),
|
44 |
api_key=env.get("SYNTHESIZER_API_KEY", ""),
|
45 |
request_limit=True,
|
46 |
+
rpm=RPM(env.get("RPM", 1000)),
|
47 |
+
tpm=TPM(env.get("TPM", 50000)),
|
48 |
)
|
49 |
|
50 |
graph_gen.trainee_llm_client = OpenAIModel(
|
|
|
52 |
base_url=env.get("TRAINEE_BASE_URL", ""),
|
53 |
api_key=env.get("TRAINEE_API_KEY", ""),
|
54 |
request_limit=True,
|
55 |
+
rpm=RPM(env.get("RPM", 1000)),
|
56 |
+
tpm=TPM(env.get("TPM", 50000)),
|
57 |
)
|
58 |
|
59 |
+
graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
|
|
|
60 |
|
61 |
strategy_config = config.get("traverse_strategy", {})
|
62 |
graph_gen.traverse_strategy = TraverseStrategy(
|
63 |
+
qa_form=strategy_config.get("qa_form"),
|
64 |
expand_method=strategy_config.get("expand_method"),
|
65 |
bidirectional=strategy_config.get("bidirectional"),
|
66 |
max_extra_edges=strategy_config.get("max_extra_edges"),
|
|
|
68 |
max_depth=strategy_config.get("max_depth"),
|
69 |
edge_sampling=strategy_config.get("edge_sampling"),
|
70 |
isolated_node_strategy=strategy_config.get("isolated_node_strategy"),
|
71 |
+
loss_strategy=str(strategy_config.get("loss_strategy")),
|
72 |
)
|
73 |
|
74 |
return graph_gen
|
75 |
|
76 |
+
|
77 |
# pylint: disable=too-many-statements
|
78 |
def run_graphgen(params, progress=gr.Progress()):
|
79 |
def sum_tokens(client):
|
|
|
83 |
"if_trainee_model": params.if_trainee_model,
|
84 |
"input_file": params.input_file,
|
85 |
"tokenizer": params.tokenizer,
|
|
|
|
|
86 |
"quiz_samples": params.quiz_samples,
|
87 |
"traverse_strategy": {
|
88 |
+
"qa_form": params.qa_form,
|
89 |
"bidirectional": params.bidirectional,
|
90 |
"expand_method": params.expand_method,
|
91 |
"max_extra_edges": params.max_extra_edges,
|
|
|
93 |
"max_depth": params.max_depth,
|
94 |
"edge_sampling": params.edge_sampling,
|
95 |
"isolated_node_strategy": params.isolated_node_strategy,
|
96 |
+
"loss_strategy": params.loss_strategy,
|
97 |
},
|
98 |
"chunk_size": params.chunk_size,
|
99 |
}
|
|
|
110 |
}
|
111 |
|
112 |
# Test API connection
|
113 |
+
test_api_connection(
|
114 |
+
env["SYNTHESIZER_BASE_URL"],
|
115 |
+
env["SYNTHESIZER_API_KEY"],
|
116 |
+
env["SYNTHESIZER_MODEL"],
|
117 |
+
)
|
118 |
+
if config["if_trainee_model"]:
|
119 |
+
test_api_connection(
|
120 |
+
env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]
|
121 |
+
)
|
122 |
|
123 |
# Initialize GraphGen
|
124 |
graph_gen = init_graph_gen(config, env)
|
|
|
128 |
|
129 |
try:
|
130 |
# Load input data
|
131 |
+
file = config["input_file"]
|
132 |
if isinstance(file, list):
|
133 |
file = file[0]
|
134 |
|
|
|
136 |
|
137 |
if file.endswith(".jsonl"):
|
138 |
data_type = "raw"
|
139 |
+
with open(file, "r", encoding="utf-8") as f:
|
140 |
data.extend(json.loads(line) for line in f)
|
141 |
elif file.endswith(".json"):
|
142 |
data_type = "chunked"
|
143 |
+
with open(file, "r", encoding="utf-8") as f:
|
144 |
data.extend(json.load(f))
|
145 |
elif file.endswith(".txt"):
|
146 |
# 读取文件后根据chunk_size转成raw格式的数据
|
147 |
data_type = "raw"
|
148 |
content = ""
|
149 |
+
with open(file, "r", encoding="utf-8") as f:
|
150 |
lines = f.readlines()
|
151 |
for line in lines:
|
152 |
content += line.strip() + " "
|
153 |
size = int(config.get("chunk_size", 512))
|
154 |
+
chunks = [content[i : i + size] for i in range(0, len(content), size)]
|
|
|
|
|
155 |
data.extend([{"content": chunk} for chunk in chunks])
|
156 |
else:
|
157 |
raise ValueError(f"Unsupported file type: {file}")
|
|
|
159 |
# Process the data
|
160 |
graph_gen.insert(data, data_type)
|
161 |
|
162 |
+
if config["if_trainee_model"]:
|
163 |
# Generate quiz
|
164 |
+
graph_gen.quiz(max_samples=config["quiz_samples"])
|
165 |
|
166 |
# Judge statements
|
167 |
graph_gen.judge()
|
|
|
171 |
graph_gen.judge(skip=True)
|
172 |
|
173 |
# Traverse graph
|
174 |
+
graph_gen.traverse(traverse_strategy=graph_gen.traverse_strategy)
|
175 |
|
176 |
# Save output
|
177 |
output_data = graph_gen.qa_storage.data
|
178 |
with tempfile.NamedTemporaryFile(
|
179 |
+
mode="w", suffix=".jsonl", delete=False, encoding="utf-8"
|
180 |
+
) as tmpfile:
|
|
|
|
|
181 |
json.dump(output_data, tmpfile, ensure_ascii=False)
|
182 |
output_file = tmpfile.name
|
183 |
|
184 |
synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
|
185 |
+
trainee_tokens = (
|
186 |
+
sum_tokens(graph_gen.trainee_llm_client)
|
187 |
+
if config["if_trainee_model"]
|
188 |
+
else 0
|
189 |
+
)
|
190 |
total_tokens = synthesizer_tokens + trainee_tokens
|
191 |
|
192 |
data_frame = params.token_counter
|
193 |
try:
|
194 |
_update_data = [
|
195 |
+
[data_frame.iloc[0, 0], data_frame.iloc[0, 1], str(total_tokens)]
|
|
|
|
|
|
|
|
|
196 |
]
|
197 |
+
new_df = pd.DataFrame(_update_data, columns=data_frame.columns)
|
|
|
|
|
|
|
198 |
data_frame = new_df
|
199 |
|
200 |
except Exception as e:
|
201 |
raise gr.Error(f"DataFrame operation error: {str(e)}")
|
202 |
|
203 |
+
return output_file, gr.DataFrame(
|
204 |
+
label="Token Stats",
|
205 |
+
headers=["Source Text Token Count", "Expected Token Usage", "Token Used"],
|
206 |
+
datatype="str",
|
207 |
+
interactive=False,
|
208 |
+
value=data_frame,
|
209 |
+
visible=True,
|
210 |
+
wrap=True,
|
211 |
+
)
|
212 |
|
213 |
except Exception as e: # pylint: disable=broad-except
|
214 |
raise gr.Error(f"Error occurred: {str(e)}")
|
|
|
217 |
# Clean up workspace
|
218 |
cleanup_workspace(graph_gen.working_dir)
|
219 |
|
220 |
+
|
221 |
+
with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
|
222 |
# Header
|
223 |
+
gr.Image(
|
224 |
+
value=os.path.join(root_dir, "resources", "images", "logo.png"),
|
225 |
+
label="GraphGen Banner",
|
226 |
+
elem_id="banner",
|
227 |
+
interactive=False,
|
228 |
+
container=False,
|
229 |
+
show_download_button=False,
|
230 |
+
show_fullscreen_button=False,
|
231 |
+
)
|
232 |
lang_btn = gr.Radio(
|
233 |
choices=[
|
234 |
("English", "en"),
|
|
|
241 |
elem_classes=["center-row"],
|
242 |
)
|
243 |
|
244 |
+
gr.HTML(
|
245 |
+
"""
|
246 |
<div style="display: flex; gap: 8px; margin-left: auto; align-items: center; justify-content: center;">
|
247 |
<a href="https://github.com/open-sciencelab/GraphGen/releases">
|
248 |
<img src="https://img.shields.io/badge/Version-v0.1.0-blue" alt="Version">
|
|
|
257 |
<img src="https://img.shields.io/badge/arXiv-pdf-yellow" alt="arXiv">
|
258 |
</a>
|
259 |
</div>
|
260 |
+
"""
|
261 |
+
)
|
262 |
with Translate(
|
263 |
+
os.path.join(root_dir, "webui", "translation.json"),
|
264 |
+
lang_btn,
|
265 |
+
placeholder_langs=["en", "zh"],
|
266 |
+
persistant=False, # True to save the language setting in the browser. Requires gradio >= 5.6.0
|
|
|
267 |
):
|
268 |
lang_btn.render()
|
269 |
|
270 |
gr.Markdown(
|
271 |
+
value="# "
|
272 |
+
+ _("Title")
|
273 |
+
+ "\n\n"
|
274 |
+
+ "### [GraphGen](https://github.com/open-sciencelab/GraphGen) "
|
275 |
+
+ _("Intro")
|
276 |
)
|
277 |
|
278 |
+
if_trainee_model = gr.Checkbox(
|
279 |
+
label=_("Use Trainee Model"), value=False, interactive=True
|
280 |
+
)
|
281 |
|
282 |
with gr.Accordion(label=_("Model Config"), open=False):
|
283 |
+
synthesizer_url = gr.Textbox(
|
284 |
+
label="Synthesizer URL",
|
285 |
+
value="https://api.siliconflow.cn/v1",
|
286 |
+
info=_("Synthesizer URL Info"),
|
287 |
+
interactive=True,
|
288 |
+
)
|
289 |
+
synthesizer_model = gr.Textbox(
|
290 |
+
label="Synthesizer Model",
|
291 |
+
value="Qwen/Qwen2.5-7B-Instruct",
|
292 |
+
info=_("Synthesizer Model Info"),
|
293 |
+
interactive=True,
|
294 |
+
)
|
295 |
+
trainee_url = gr.Textbox(
|
296 |
+
label="Trainee URL",
|
297 |
+
value="https://api.siliconflow.cn/v1",
|
298 |
+
info=_("Trainee URL Info"),
|
299 |
+
interactive=True,
|
300 |
+
visible=if_trainee_model.value is True,
|
301 |
+
)
|
302 |
trainee_model = gr.Textbox(
|
303 |
label="Trainee Model",
|
304 |
value="Qwen/Qwen2.5-7B-Instruct",
|
305 |
info=_("Trainee Model Info"),
|
306 |
interactive=True,
|
307 |
+
visible=if_trainee_model.value is True,
|
308 |
+
)
|
309 |
trainee_api_key = gr.Textbox(
|
310 |
+
label=_("SiliconFlow Token for Trainee Model"),
|
311 |
+
type="password",
|
312 |
+
value="",
|
313 |
+
info="https://cloud.siliconflow.cn/account/ak",
|
314 |
+
visible=if_trainee_model.value is True,
|
315 |
+
)
|
316 |
|
317 |
with gr.Accordion(label=_("Generation Config"), open=False):
|
318 |
+
chunk_size = gr.Slider(
|
319 |
+
label="Chunk Size",
|
320 |
+
minimum=256,
|
321 |
+
maximum=4096,
|
322 |
+
value=512,
|
323 |
+
step=256,
|
324 |
+
interactive=True,
|
325 |
+
)
|
326 |
+
tokenizer = gr.Textbox(
|
327 |
+
label="Tokenizer", value="cl100k_base", interactive=True
|
328 |
+
)
|
329 |
+
qa_form = gr.Radio(
|
330 |
+
choices=["atomic", "multi_hop", "aggregated"],
|
331 |
+
label="QA Form",
|
332 |
+
value="aggregated",
|
333 |
+
interactive=True,
|
334 |
+
)
|
335 |
+
quiz_samples = gr.Number(
|
336 |
+
label="Quiz Samples",
|
337 |
+
value=2,
|
338 |
+
minimum=1,
|
339 |
+
interactive=True,
|
340 |
+
visible=if_trainee_model.value is True,
|
341 |
+
)
|
342 |
+
bidirectional = gr.Checkbox(
|
343 |
+
label="Bidirectional", value=True, interactive=True
|
344 |
+
)
|
345 |
+
|
346 |
+
expand_method = gr.Radio(
|
347 |
+
choices=["max_width", "max_tokens"],
|
348 |
+
label="Expand Method",
|
349 |
+
value="max_tokens",
|
350 |
+
interactive=True,
|
351 |
+
)
|
352 |
max_extra_edges = gr.Slider(
|
353 |
minimum=1,
|
354 |
maximum=10,
|
|
|
356 |
label="Max Extra Edges",
|
357 |
step=1,
|
358 |
interactive=True,
|
359 |
+
visible=expand_method.value == "max_width",
|
360 |
+
)
|
361 |
+
max_tokens = gr.Slider(
|
362 |
+
minimum=64,
|
363 |
+
maximum=1024,
|
364 |
+
value=256,
|
365 |
+
label="Max Tokens",
|
366 |
+
step=64,
|
367 |
+
interactive=True,
|
368 |
+
visible=(expand_method.value != "max_width"),
|
369 |
+
)
|
370 |
+
|
371 |
+
max_depth = gr.Slider(
|
372 |
+
minimum=1,
|
373 |
+
maximum=5,
|
374 |
+
value=2,
|
375 |
+
label="Max Depth",
|
376 |
+
step=1,
|
377 |
+
interactive=True,
|
378 |
+
)
|
379 |
edge_sampling = gr.Radio(
|
380 |
choices=["max_loss", "min_loss", "random"],
|
381 |
label="Edge Sampling",
|
382 |
value="max_loss",
|
383 |
interactive=True,
|
384 |
+
visible=if_trainee_model.value is True,
|
385 |
+
)
|
386 |
+
isolated_node_strategy = gr.Radio(
|
387 |
+
choices=["add", "ignore"],
|
388 |
+
label="Isolated Node Strategy",
|
389 |
+
value="ignore",
|
390 |
+
interactive=True,
|
391 |
+
)
|
392 |
+
loss_strategy = gr.Radio(
|
393 |
+
choices=["only_edge", "both"],
|
394 |
+
label="Loss Strategy",
|
395 |
+
value="only_edge",
|
396 |
+
interactive=True,
|
397 |
+
)
|
398 |
|
399 |
with gr.Row(equal_height=True):
|
400 |
with gr.Column(scale=3):
|
401 |
api_key = gr.Textbox(
|
402 |
+
label=_("SiliconFlow Token"),
|
403 |
type="password",
|
404 |
value="",
|
405 |
+
info="https://cloud.siliconflow.cn/account/ak",
|
406 |
+
)
|
407 |
with gr.Column(scale=1):
|
408 |
test_connection_btn = gr.Button(_("Test Connection"))
|
409 |
|
|
|
417 |
value=1000,
|
418 |
step=100,
|
419 |
interactive=True,
|
420 |
+
visible=True,
|
421 |
+
)
|
422 |
with gr.Column():
|
423 |
tpm = gr.Slider(
|
424 |
label="TPM",
|
|
|
427 |
value=50000,
|
428 |
step=1000,
|
429 |
interactive=True,
|
430 |
+
visible=True,
|
431 |
+
)
|
432 |
|
433 |
with gr.Blocks():
|
434 |
with gr.Row(equal_height=True):
|
|
|
439 |
file_types=[".txt", ".json", ".jsonl"],
|
440 |
interactive=True,
|
441 |
)
|
442 |
+
examples_dir = os.path.join(root_dir, "webui", "examples")
|
443 |
+
gr.Examples(
|
444 |
+
examples=[
|
445 |
+
[os.path.join(examples_dir, "txt_demo.txt")],
|
446 |
+
[os.path.join(examples_dir, "raw_demo.jsonl")],
|
447 |
+
[os.path.join(examples_dir, "chunked_demo.json")],
|
448 |
+
],
|
449 |
+
inputs=upload_file,
|
450 |
+
label=_("Example Files"),
|
451 |
+
examples_per_page=3,
|
452 |
+
)
|
453 |
with gr.Column(scale=1):
|
454 |
output = gr.File(
|
455 |
label="Output(See Github FAQ)",
|
|
|
458 |
)
|
459 |
|
460 |
with gr.Blocks():
|
461 |
+
token_counter = gr.DataFrame(
|
462 |
+
label="Token Stats",
|
463 |
+
headers=[
|
464 |
+
"Source Text Token Count",
|
465 |
+
"Estimated Token Usage",
|
466 |
+
"Token Used",
|
467 |
+
],
|
468 |
+
datatype="str",
|
469 |
+
interactive=False,
|
470 |
+
visible=False,
|
471 |
+
wrap=True,
|
472 |
+
)
|
473 |
|
474 |
submit_btn = gr.Button(_("Run GraphGen"))
|
475 |
|
|
|
477 |
test_connection_btn.click(
|
478 |
test_api_connection,
|
479 |
inputs=[synthesizer_url, api_key, synthesizer_model],
|
480 |
+
outputs=[],
|
481 |
+
)
|
482 |
|
483 |
if if_trainee_model.value:
|
484 |
+
test_connection_btn.click(
|
485 |
+
test_api_connection,
|
486 |
+
inputs=[trainee_url, api_key, trainee_model],
|
487 |
+
outputs=[],
|
488 |
+
)
|
489 |
|
490 |
+
expand_method.change(
|
491 |
+
lambda method: (
|
492 |
+
gr.update(visible=method == "max_width"),
|
493 |
+
gr.update(visible=method != "max_width"),
|
494 |
+
),
|
495 |
+
inputs=expand_method,
|
496 |
+
outputs=[max_extra_edges, max_tokens],
|
497 |
+
)
|
498 |
|
499 |
if_trainee_model.change(
|
500 |
lambda use_trainee: [gr.update(visible=use_trainee)] * 5,
|
501 |
inputs=if_trainee_model,
|
502 |
+
outputs=[
|
503 |
+
trainee_url,
|
504 |
+
trainee_model,
|
505 |
+
quiz_samples,
|
506 |
+
edge_sampling,
|
507 |
+
trainee_api_key,
|
508 |
+
],
|
509 |
+
)
|
510 |
|
511 |
upload_file.change(
|
512 |
lambda x: (gr.update(visible=True)),
|
|
|
526 |
)
|
527 |
|
528 |
submit_btn.click(
|
529 |
+
lambda *args: run_graphgen(
|
530 |
+
GraphGenParams(
|
531 |
+
if_trainee_model=args[0],
|
532 |
+
input_file=args[1],
|
533 |
+
tokenizer=args[2],
|
534 |
+
qa_form=args[3],
|
535 |
+
bidirectional=args[4],
|
536 |
+
expand_method=args[5],
|
537 |
+
max_extra_edges=args[6],
|
538 |
+
max_tokens=args[7],
|
539 |
+
max_depth=args[8],
|
540 |
+
edge_sampling=args[9],
|
541 |
+
isolated_node_strategy=args[10],
|
542 |
+
loss_strategy=args[11],
|
543 |
+
synthesizer_url=args[12],
|
544 |
+
synthesizer_model=args[13],
|
545 |
+
trainee_model=args[14],
|
546 |
+
api_key=args[15],
|
547 |
+
chunk_size=args[16],
|
548 |
+
rpm=args[17],
|
549 |
+
tpm=args[18],
|
550 |
+
quiz_samples=args[19],
|
551 |
+
trainee_url=args[20],
|
552 |
+
trainee_api_key=args[21],
|
553 |
+
token_counter=args[22],
|
554 |
+
)
|
555 |
+
),
|
556 |
inputs=[
|
557 |
+
if_trainee_model,
|
558 |
+
upload_file,
|
559 |
+
tokenizer,
|
560 |
+
qa_form,
|
561 |
+
bidirectional,
|
562 |
+
expand_method,
|
563 |
+
max_extra_edges,
|
564 |
+
max_tokens,
|
565 |
+
max_depth,
|
566 |
+
edge_sampling,
|
567 |
+
isolated_node_strategy,
|
568 |
+
loss_strategy,
|
569 |
+
synthesizer_url,
|
570 |
+
synthesizer_model,
|
571 |
+
trainee_model,
|
572 |
+
api_key,
|
573 |
+
chunk_size,
|
574 |
+
rpm,
|
575 |
+
tpm,
|
576 |
+
quiz_samples,
|
577 |
+
trainee_url,
|
578 |
+
trainee_api_key,
|
579 |
+
token_counter,
|
580 |
],
|
581 |
outputs=[output, token_counter],
|
582 |
)
|
583 |
|
584 |
if __name__ == "__main__":
|
585 |
demo.queue(api_open=False, default_concurrency_limit=2)
|
586 |
+
demo.launch(server_name="0.0.0.0")
|
graphgen/configs/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Configs for GraphGen
|
graphgen/configs/aggregated_config.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
input_data_type: raw # raw, chunked
|
2 |
+
input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
|
3 |
+
output_data_type: aggregated # atomic, aggregated, multi_hop, cot
|
4 |
+
output_data_format: ChatML # Alpaca, Sharegpt, ChatML
|
5 |
+
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
|
6 |
+
search: # web search configuration
|
7 |
+
enabled: false # whether to enable web search
|
8 |
+
search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
|
9 |
+
quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
|
10 |
+
enabled: true
|
11 |
+
quiz_samples: 2 # number of quiz samples to generate
|
12 |
+
re_judge: false # whether to re-judge the existing quiz samples
|
13 |
+
traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
|
14 |
+
bidirectional: true # whether to traverse the graph in both directions
|
15 |
+
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
|
16 |
+
expand_method: max_width # expand method, support: max_width, max_tokens
|
17 |
+
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
|
18 |
+
max_depth: 5 # maximum depth for graph traversal
|
19 |
+
max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
|
20 |
+
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
|
21 |
+
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
|
graphgen/configs/atomic_config.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
input_data_type: raw # raw, chunked
|
2 |
+
input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
|
3 |
+
output_data_type: atomic # atomic, aggregated, multi_hop, cot
|
4 |
+
output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
|
5 |
+
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
|
6 |
+
search: # web search configuration
|
7 |
+
enabled: false # whether to enable web search
|
8 |
+
search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
|
9 |
+
quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
|
10 |
+
enabled: true
|
11 |
+
quiz_samples: 2 # number of quiz samples to generate
|
12 |
+
re_judge: false # whether to re-judge the existing quiz samples
|
13 |
+
traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
|
14 |
+
bidirectional: true # whether to traverse the graph in both directions
|
15 |
+
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
|
16 |
+
expand_method: max_width # expand method, support: max_width, max_tokens
|
17 |
+
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
|
18 |
+
max_depth: 3 # maximum depth for graph traversal
|
19 |
+
max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
|
20 |
+
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
|
21 |
+
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
|
graphgen/configs/config.yaml.example
DELETED
@@ -1,16 +0,0 @@
|
|
1 |
-
data_type: raw
|
2 |
-
input_file: resources/examples/raw_demo.jsonl
|
3 |
-
tokenizer: cl100k_base
|
4 |
-
quiz_samples: 2
|
5 |
-
traverse_strategy:
|
6 |
-
qa_form: atomic
|
7 |
-
bidirectional: true
|
8 |
-
edge_sampling: max_loss
|
9 |
-
expand_method: max_tokens
|
10 |
-
isolated_node_strategy: add
|
11 |
-
max_depth: 2
|
12 |
-
max_extra_edges: 5
|
13 |
-
max_tokens: 256
|
14 |
-
loss_strategy: only_edge
|
15 |
-
web_search: false
|
16 |
-
re_judge: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/configs/cot_config.yaml
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
input_data_type: raw # raw, chunked
|
2 |
+
input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
|
3 |
+
output_data_type: cot # atomic, aggregated, multi_hop, cot
|
4 |
+
output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML
|
5 |
+
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
|
6 |
+
search: # web search configuration
|
7 |
+
enabled: false # whether to enable web search
|
8 |
+
search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
|
9 |
+
method_params:
|
10 |
+
method: leiden
|
11 |
+
max_size: 20 # Maximum size of communities
|
12 |
+
use_lcc: false
|
13 |
+
random_seed: 42
|
graphgen/configs/graphgen_config.yaml
DELETED
@@ -1,16 +0,0 @@
|
|
1 |
-
data_type: raw
|
2 |
-
input_file: resources/examples/raw_demo.jsonl
|
3 |
-
tokenizer: cl100k_base
|
4 |
-
quiz_samples: 2
|
5 |
-
traverse_strategy:
|
6 |
-
qa_form: aggregated
|
7 |
-
bidirectional: true
|
8 |
-
edge_sampling: max_loss
|
9 |
-
expand_method: max_width
|
10 |
-
isolated_node_strategy: ignore
|
11 |
-
max_depth: 1
|
12 |
-
max_extra_edges: 2
|
13 |
-
max_tokens: 256
|
14 |
-
loss_strategy: only_edge
|
15 |
-
web_search: false
|
16 |
-
re_judge: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/configs/multi_hop_config.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
input_data_type: raw # raw, chunked
|
2 |
+
input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
|
3 |
+
output_data_type: multi_hop # atomic, aggregated, multi_hop, cot
|
4 |
+
output_data_format: ChatML # Alpaca, Sharegpt, ChatML
|
5 |
+
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
|
6 |
+
search: # web search configuration
|
7 |
+
enabled: false # whether to enable web search
|
8 |
+
search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
|
9 |
+
quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
|
10 |
+
enabled: true
|
11 |
+
quiz_samples: 2 # number of quiz samples to generate
|
12 |
+
re_judge: false # whether to re-judge the existing quiz samples
|
13 |
+
traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
|
14 |
+
bidirectional: true # whether to traverse the graph in both directions
|
15 |
+
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
|
16 |
+
expand_method: max_width # expand method, support: max_width, max_tokens
|
17 |
+
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
|
18 |
+
max_depth: 1 # maximum depth for graph traversal
|
19 |
+
max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
|
20 |
+
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
|
21 |
+
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
|
graphgen/generate.py
CHANGED
@@ -1,101 +1,103 @@
|
|
|
|
1 |
import os
|
2 |
-
import json
|
3 |
import time
|
4 |
-
import argparse
|
5 |
from importlib.resources import files
|
|
|
6 |
import yaml
|
7 |
from dotenv import load_dotenv
|
8 |
|
9 |
from .graphgen import GraphGen
|
10 |
-
from .
|
11 |
-
from .utils import set_logger
|
12 |
|
13 |
sys_path = os.path.abspath(os.path.dirname(__file__))
|
14 |
|
15 |
load_dotenv()
|
16 |
|
|
|
17 |
def set_working_dir(folder):
|
18 |
os.makedirs(folder, exist_ok=True)
|
19 |
os.makedirs(os.path.join(folder, "data", "graphgen"), exist_ok=True)
|
20 |
os.makedirs(os.path.join(folder, "logs"), exist_ok=True)
|
21 |
|
|
|
22 |
def save_config(config_path, global_config):
|
23 |
if not os.path.exists(os.path.dirname(config_path)):
|
24 |
os.makedirs(os.path.dirname(config_path))
|
25 |
-
with open(config_path, "w", encoding=
|
26 |
-
yaml.dump(
|
|
|
|
|
|
|
27 |
|
28 |
def main():
|
29 |
parser = argparse.ArgumentParser()
|
30 |
-
parser.add_argument(
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
40 |
|
41 |
args = parser.parse_args()
|
42 |
|
43 |
working_dir = args.output_dir
|
44 |
set_working_dir(working_dir)
|
45 |
-
unique_id = int(time.time())
|
46 |
-
set_logger(os.path.join(working_dir, "logs", f"graphgen_{unique_id}.log"), if_stream=False)
|
47 |
|
48 |
-
with open(args.config_file, "r", encoding=
|
49 |
config = yaml.load(f, Loader=yaml.FullLoader)
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
data = json.load(f)
|
59 |
-
else:
|
60 |
-
raise ValueError(f"Invalid data type: {config['data_type']}")
|
61 |
-
|
62 |
-
synthesizer_llm_client = OpenAIModel(
|
63 |
-
model_name=os.getenv("SYNTHESIZER_MODEL"),
|
64 |
-
api_key=os.getenv("SYNTHESIZER_API_KEY"),
|
65 |
-
base_url=os.getenv("SYNTHESIZER_BASE_URL")
|
66 |
-
)
|
67 |
-
trainee_llm_client = OpenAIModel(
|
68 |
-
model_name=os.getenv("TRAINEE_MODEL"),
|
69 |
-
api_key=os.getenv("TRAINEE_API_KEY"),
|
70 |
-
base_url=os.getenv("TRAINEE_BASE_URL")
|
71 |
-
)
|
72 |
-
|
73 |
-
traverse_strategy = TraverseStrategy(
|
74 |
-
**config['traverse_strategy']
|
75 |
)
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
trainee_llm_client=trainee_llm_client,
|
82 |
-
if_web_search=config['web_search'],
|
83 |
-
tokenizer_instance=Tokenizer(
|
84 |
-
model_name=config['tokenizer']
|
85 |
),
|
86 |
-
traverse_strategy=traverse_strategy
|
87 |
)
|
88 |
|
89 |
-
graph_gen
|
90 |
-
|
91 |
-
graph_gen.
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
-
|
|
|
|
|
96 |
|
97 |
-
path = os.path.join(working_dir, "data", "graphgen", str(unique_id), f"config-{unique_id}.yaml")
|
98 |
-
save_config(path, config)
|
99 |
|
100 |
-
if __name__ ==
|
101 |
main()
|
|
|
1 |
+
import argparse
|
2 |
import os
|
|
|
3 |
import time
|
|
|
4 |
from importlib.resources import files
|
5 |
+
|
6 |
import yaml
|
7 |
from dotenv import load_dotenv
|
8 |
|
9 |
from .graphgen import GraphGen
|
10 |
+
from .utils import logger, set_logger
|
|
|
11 |
|
12 |
sys_path = os.path.abspath(os.path.dirname(__file__))
|
13 |
|
14 |
load_dotenv()
|
15 |
|
16 |
+
|
17 |
def set_working_dir(folder):
|
18 |
os.makedirs(folder, exist_ok=True)
|
19 |
os.makedirs(os.path.join(folder, "data", "graphgen"), exist_ok=True)
|
20 |
os.makedirs(os.path.join(folder, "logs"), exist_ok=True)
|
21 |
|
22 |
+
|
23 |
def save_config(config_path, global_config):
|
24 |
if not os.path.exists(os.path.dirname(config_path)):
|
25 |
os.makedirs(os.path.dirname(config_path))
|
26 |
+
with open(config_path, "w", encoding="utf-8") as config_file:
|
27 |
+
yaml.dump(
|
28 |
+
global_config, config_file, default_flow_style=False, allow_unicode=True
|
29 |
+
)
|
30 |
+
|
31 |
|
32 |
def main():
|
33 |
parser = argparse.ArgumentParser()
|
34 |
+
parser.add_argument(
|
35 |
+
"--config_file",
|
36 |
+
help="Config parameters for GraphGen.",
|
37 |
+
default=files("graphgen").joinpath("configs", "aggregated_config.yaml"),
|
38 |
+
type=str,
|
39 |
+
)
|
40 |
+
parser.add_argument(
|
41 |
+
"--output_dir",
|
42 |
+
help="Output directory for GraphGen.",
|
43 |
+
default=sys_path,
|
44 |
+
required=True,
|
45 |
+
type=str,
|
46 |
+
)
|
47 |
|
48 |
args = parser.parse_args()
|
49 |
|
50 |
working_dir = args.output_dir
|
51 |
set_working_dir(working_dir)
|
|
|
|
|
52 |
|
53 |
+
with open(args.config_file, "r", encoding="utf-8") as f:
|
54 |
config = yaml.load(f, Loader=yaml.FullLoader)
|
55 |
|
56 |
+
output_data_type = config["output_data_type"]
|
57 |
+
unique_id = int(time.time())
|
58 |
+
set_logger(
|
59 |
+
os.path.join(
|
60 |
+
working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log"
|
61 |
+
),
|
62 |
+
if_stream=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
)
|
64 |
+
logger.info(
|
65 |
+
"GraphGen with unique ID %s logging to %s",
|
66 |
+
unique_id,
|
67 |
+
os.path.join(
|
68 |
+
working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log"
|
|
|
|
|
|
|
|
|
69 |
),
|
|
|
70 |
)
|
71 |
|
72 |
+
graph_gen = GraphGen(working_dir=working_dir, unique_id=unique_id, config=config)
|
73 |
+
|
74 |
+
graph_gen.insert()
|
75 |
+
|
76 |
+
if config["search"]["enabled"]:
|
77 |
+
graph_gen.search()
|
78 |
+
|
79 |
+
# Use pipeline according to the output data type
|
80 |
+
if output_data_type in ["atomic", "aggregated", "multi_hop"]:
|
81 |
+
if "quiz_and_judge_strategy" in config and config[
|
82 |
+
"quiz_and_judge_strategy"
|
83 |
+
].get("enabled", False):
|
84 |
+
graph_gen.quiz()
|
85 |
+
graph_gen.judge()
|
86 |
+
else:
|
87 |
+
logger.warning(
|
88 |
+
"Quiz and Judge strategy is disabled. Edge sampling falls back to random."
|
89 |
+
)
|
90 |
+
graph_gen.traverse_strategy.edge_sampling = "random"
|
91 |
+
graph_gen.traverse()
|
92 |
+
elif output_data_type == "cot":
|
93 |
+
graph_gen.generate_reasoning(method_params=config["method_params"])
|
94 |
+
else:
|
95 |
+
raise ValueError(f"Unsupported output data type: {output_data_type}")
|
96 |
|
97 |
+
output_path = os.path.join(working_dir, "data", "graphgen", str(unique_id))
|
98 |
+
save_config(os.path.join(output_path, f"config-{unique_id}.yaml"), config)
|
99 |
+
logger.info("GraphGen completed successfully. Data saved to %s", output_path)
|
100 |
|
|
|
|
|
101 |
|
102 |
+
if __name__ == "__main__":
|
103 |
main()
|
graphgen/graphgen.py
CHANGED
@@ -1,10 +1,8 @@
|
|
1 |
-
# Adapt from https://github.com/HKUDS/LightRAG
|
2 |
-
|
3 |
import asyncio
|
4 |
import os
|
5 |
import time
|
6 |
from dataclasses import dataclass, field
|
7 |
-
from typing import List, Union, cast
|
8 |
|
9 |
import gradio as gr
|
10 |
from tqdm.asyncio import tqdm as tqdm_async
|
@@ -12,85 +10,124 @@ from tqdm.asyncio import tqdm as tqdm_async
|
|
12 |
from .models import (
|
13 |
Chunk,
|
14 |
JsonKVStorage,
|
|
|
15 |
NetworkXStorage,
|
16 |
OpenAIModel,
|
17 |
Tokenizer,
|
18 |
TraverseStrategy,
|
19 |
-
WikiSearch,
|
20 |
)
|
21 |
from .models.storage.base_storage import StorageNameSpace
|
22 |
from .operators import (
|
23 |
extract_kg,
|
|
|
24 |
judge_statement,
|
25 |
quiz,
|
26 |
-
|
27 |
-
skip_judge_statement,
|
28 |
traverse_graph_atomically,
|
29 |
traverse_graph_by_edge,
|
30 |
traverse_graph_for_multi_hop,
|
31 |
)
|
32 |
-
from .utils import
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
35 |
|
|
|
36 |
@dataclass
|
37 |
class GraphGen:
|
38 |
unique_id: int = int(time.time())
|
39 |
working_dir: str = os.path.join(sys_path, "cache")
|
40 |
-
|
41 |
-
# text chunking
|
42 |
-
chunk_size: int = 1024
|
43 |
-
chunk_overlap_size: int = 100
|
44 |
|
45 |
# llm
|
|
|
46 |
synthesizer_llm_client: OpenAIModel = None
|
47 |
trainee_llm_client: OpenAIModel = None
|
48 |
-
tokenizer_instance: Tokenizer = None
|
49 |
|
50 |
-
#
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
-
#
|
55 |
-
traverse_strategy: TraverseStrategy =
|
56 |
|
57 |
# webui
|
58 |
progress_bar: gr.Progress = None
|
59 |
|
60 |
def __post_init__(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
self.full_docs_storage: JsonKVStorage = JsonKVStorage(
|
62 |
self.working_dir, namespace="full_docs"
|
63 |
)
|
64 |
self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
|
65 |
self.working_dir, namespace="text_chunks"
|
66 |
)
|
67 |
-
self.wiki_storage: JsonKVStorage = JsonKVStorage(
|
68 |
-
self.working_dir, namespace="wiki"
|
69 |
-
)
|
70 |
self.graph_storage: NetworkXStorage = NetworkXStorage(
|
71 |
self.working_dir, namespace="graph"
|
72 |
)
|
|
|
|
|
|
|
73 |
self.rephrase_storage: JsonKVStorage = JsonKVStorage(
|
74 |
self.working_dir, namespace="rephrase"
|
75 |
)
|
76 |
-
self.qa_storage:
|
77 |
-
os.path.join(self.working_dir, "data", "graphgen", str(self.unique_id)),
|
|
|
78 |
)
|
79 |
|
80 |
-
async def async_split_chunks(
|
81 |
-
|
|
|
|
|
82 |
if len(data) == 0:
|
83 |
return {}
|
84 |
|
85 |
-
new_docs = {}
|
86 |
inserting_chunks = {}
|
87 |
if data_type == "raw":
|
88 |
assert isinstance(data, list) and isinstance(data[0], dict)
|
89 |
# compute hash for each document
|
90 |
new_docs = {
|
91 |
-
compute_content_hash(doc[
|
|
|
|
|
|
|
92 |
}
|
93 |
-
_add_doc_keys = await self.full_docs_storage.filter_keys(
|
|
|
|
|
94 |
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
|
95 |
if len(new_docs) == 0:
|
96 |
logger.warning("All docs are already in the storage")
|
@@ -100,63 +137,83 @@ class GraphGen:
|
|
100 |
cur_index = 1
|
101 |
doc_number = len(new_docs)
|
102 |
async for doc_key, doc in tqdm_async(
|
103 |
-
|
104 |
-
|
105 |
chunks = {
|
106 |
compute_content_hash(dp["content"], prefix="chunk-"): {
|
107 |
**dp,
|
108 |
-
|
109 |
-
}
|
110 |
-
|
|
|
|
|
111 |
}
|
112 |
inserting_chunks.update(chunks)
|
113 |
|
114 |
if self.progress_bar is not None:
|
115 |
-
self.progress_bar(
|
116 |
-
cur_index / doc_number, f"Chunking {doc_key}"
|
117 |
-
)
|
118 |
cur_index += 1
|
119 |
|
120 |
-
_add_chunk_keys = await self.text_chunks_storage.filter_keys(
|
121 |
-
|
|
|
|
|
|
|
|
|
122 |
elif data_type == "chunked":
|
123 |
assert isinstance(data, list) and isinstance(data[0], list)
|
124 |
new_docs = {
|
125 |
-
compute_content_hash("".join(chunk[
|
126 |
-
|
|
|
|
|
|
|
127 |
}
|
128 |
-
_add_doc_keys = await self.full_docs_storage.filter_keys(
|
|
|
|
|
129 |
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
|
130 |
if len(new_docs) == 0:
|
131 |
logger.warning("All docs are already in the storage")
|
132 |
return {}
|
133 |
logger.info("[New Docs] inserting %d docs", len(new_docs))
|
134 |
-
async for doc in tqdm_async(
|
135 |
-
|
|
|
|
|
136 |
for chunk in doc:
|
137 |
-
chunk_key = compute_content_hash(chunk[
|
138 |
inserting_chunks[chunk_key] = {
|
139 |
**chunk,
|
140 |
-
|
141 |
}
|
142 |
-
_add_chunk_keys = await self.text_chunks_storage.filter_keys(
|
143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
|
145 |
await self.full_docs_storage.upsert(new_docs)
|
146 |
await self.text_chunks_storage.upsert(inserting_chunks)
|
147 |
|
148 |
return inserting_chunks
|
149 |
|
150 |
-
def insert(self
|
151 |
loop = create_event_loop()
|
152 |
-
loop.run_until_complete(self.async_insert(
|
153 |
|
154 |
-
async def async_insert(self
|
155 |
"""
|
156 |
-
|
157 |
insert chunks into the graph
|
158 |
"""
|
159 |
|
|
|
|
|
|
|
|
|
160 |
inserting_chunks = await self.async_split_chunks(data, data_type)
|
161 |
|
162 |
if len(inserting_chunks) == 0:
|
@@ -169,52 +226,96 @@ class GraphGen:
|
|
169 |
llm_client=self.synthesizer_llm_client,
|
170 |
kg_instance=self.graph_storage,
|
171 |
tokenizer_instance=self.tokenizer_instance,
|
172 |
-
chunks=[
|
173 |
-
|
|
|
|
|
174 |
)
|
175 |
if not _add_entities_and_relations:
|
176 |
logger.warning("No entities or relations extracted")
|
177 |
return
|
178 |
|
179 |
-
logger.info("[Wiki Search] is %s", 'enabled' if self.if_web_search else 'disabled')
|
180 |
-
if self.if_web_search:
|
181 |
-
logger.info("[Wiki Search]...")
|
182 |
-
_add_wiki_data = await search_wikipedia(
|
183 |
-
llm_client= self.synthesizer_llm_client,
|
184 |
-
wiki_search_client=self.wiki_client,
|
185 |
-
knowledge_graph_instance=_add_entities_and_relations
|
186 |
-
)
|
187 |
-
await self.wiki_storage.upsert(_add_wiki_data)
|
188 |
-
|
189 |
await self._insert_done()
|
190 |
|
191 |
async def _insert_done(self):
|
192 |
tasks = []
|
193 |
-
for storage_instance in [
|
194 |
-
|
|
|
|
|
|
|
|
|
195 |
if storage_instance is None:
|
196 |
continue
|
197 |
tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
|
198 |
await asyncio.gather(*tasks)
|
199 |
|
200 |
-
def
|
201 |
loop = create_event_loop()
|
202 |
-
loop.run_until_complete(self.
|
203 |
|
204 |
-
async def
|
205 |
-
|
206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
|
208 |
-
def
|
209 |
loop = create_event_loop()
|
210 |
-
loop.run_until_complete(self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
await _update_relations.index_done_callback()
|
219 |
|
220 |
def traverse(self):
|
@@ -222,26 +323,60 @@ class GraphGen:
|
|
222 |
loop.run_until_complete(self.async_traverse())
|
223 |
|
224 |
async def async_traverse(self):
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
else:
|
244 |
-
raise ValueError(f"Unknown qa_form: {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
await self.qa_storage.upsert(results)
|
246 |
await self.qa_storage.index_done_callback()
|
247 |
|
@@ -252,7 +387,7 @@ class GraphGen:
|
|
252 |
async def async_clear(self):
|
253 |
await self.full_docs_storage.drop()
|
254 |
await self.text_chunks_storage.drop()
|
255 |
-
await self.
|
256 |
await self.graph_storage.clear()
|
257 |
await self.rephrase_storage.drop()
|
258 |
await self.qa_storage.drop()
|
|
|
|
|
|
|
1 |
import asyncio
|
2 |
import os
|
3 |
import time
|
4 |
from dataclasses import dataclass, field
|
5 |
+
from typing import Dict, List, Union, cast
|
6 |
|
7 |
import gradio as gr
|
8 |
from tqdm.asyncio import tqdm as tqdm_async
|
|
|
10 |
from .models import (
|
11 |
Chunk,
|
12 |
JsonKVStorage,
|
13 |
+
JsonListStorage,
|
14 |
NetworkXStorage,
|
15 |
OpenAIModel,
|
16 |
Tokenizer,
|
17 |
TraverseStrategy,
|
|
|
18 |
)
|
19 |
from .models.storage.base_storage import StorageNameSpace
|
20 |
from .operators import (
|
21 |
extract_kg,
|
22 |
+
generate_cot,
|
23 |
judge_statement,
|
24 |
quiz,
|
25 |
+
search_all,
|
|
|
26 |
traverse_graph_atomically,
|
27 |
traverse_graph_by_edge,
|
28 |
traverse_graph_for_multi_hop,
|
29 |
)
|
30 |
+
from .utils import (
|
31 |
+
compute_content_hash,
|
32 |
+
create_event_loop,
|
33 |
+
format_generation_results,
|
34 |
+
logger,
|
35 |
+
read_file,
|
36 |
+
)
|
37 |
|
38 |
sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
39 |
|
40 |
+
|
41 |
@dataclass
|
42 |
class GraphGen:
|
43 |
unique_id: int = int(time.time())
|
44 |
working_dir: str = os.path.join(sys_path, "cache")
|
45 |
+
config: Dict = field(default_factory=dict)
|
|
|
|
|
|
|
46 |
|
47 |
# llm
|
48 |
+
tokenizer_instance: Tokenizer = None
|
49 |
synthesizer_llm_client: OpenAIModel = None
|
50 |
trainee_llm_client: OpenAIModel = None
|
|
|
51 |
|
52 |
+
# text chunking
|
53 |
+
# TODO: make it configurable
|
54 |
+
chunk_size: int = 1024
|
55 |
+
chunk_overlap_size: int = 100
|
56 |
+
|
57 |
+
# search
|
58 |
+
search_config: dict = field(
|
59 |
+
default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}
|
60 |
+
)
|
61 |
|
62 |
+
# traversal
|
63 |
+
traverse_strategy: TraverseStrategy = None
|
64 |
|
65 |
# webui
|
66 |
progress_bar: gr.Progress = None
|
67 |
|
68 |
def __post_init__(self):
    """Wire up LLM clients, tokenizer, search config and storage backends.

    Model credentials are read from environment variables
    (SYNTHESIZER_* / TRAINEE_*); everything else comes from ``self.config``.
    """
    self.tokenizer_instance: Tokenizer = Tokenizer(
        model_name=self.config["tokenizer"]
    )
    self.synthesizer_llm_client: OpenAIModel = OpenAIModel(
        model_name=os.getenv("SYNTHESIZER_MODEL"),
        api_key=os.getenv("SYNTHESIZER_API_KEY"),
        base_url=os.getenv("SYNTHESIZER_BASE_URL"),
        tokenizer_instance=self.tokenizer_instance,
    )
    self.trainee_llm_client: OpenAIModel = OpenAIModel(
        model_name=os.getenv("TRAINEE_MODEL"),
        api_key=os.getenv("TRAINEE_API_KEY"),
        base_url=os.getenv("TRAINEE_BASE_URL"),
        tokenizer_instance=self.tokenizer_instance,
    )
    # Fall back to the dataclass default when the config omits "search",
    # mirroring the guarded handling of "traverse_strategy" below
    # (the original indexed self.config["search"] and raised KeyError).
    self.search_config = self.config.get("search", self.search_config)

    if "traverse_strategy" in self.config:
        self.traverse_strategy = TraverseStrategy(
            **self.config["traverse_strategy"]
        )

    self.full_docs_storage: JsonKVStorage = JsonKVStorage(
        self.working_dir, namespace="full_docs"
    )
    self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
        self.working_dir, namespace="text_chunks"
    )
    self.graph_storage: NetworkXStorage = NetworkXStorage(
        self.working_dir, namespace="graph"
    )
    self.search_storage: JsonKVStorage = JsonKVStorage(
        self.working_dir, namespace="search"
    )
    self.rephrase_storage: JsonKVStorage = JsonKVStorage(
        self.working_dir, namespace="rephrase"
    )
    self.qa_storage: JsonListStorage = JsonListStorage(
        os.path.join(self.working_dir, "data", "graphgen", str(self.unique_id)),
        namespace=f"qa-{self.unique_id}",
    )
|
110 |
|
111 |
+
async def async_split_chunks(
|
112 |
+
self, data: List[Union[List, Dict]], data_type: str
|
113 |
+
) -> dict:
|
114 |
+
# TODO: configurable whether to use coreference resolution
|
115 |
if len(data) == 0:
|
116 |
return {}
|
117 |
|
|
|
118 |
inserting_chunks = {}
|
119 |
if data_type == "raw":
|
120 |
assert isinstance(data, list) and isinstance(data[0], dict)
|
121 |
# compute hash for each document
|
122 |
new_docs = {
|
123 |
+
compute_content_hash(doc["content"], prefix="doc-"): {
|
124 |
+
"content": doc["content"]
|
125 |
+
}
|
126 |
+
for doc in data
|
127 |
}
|
128 |
+
_add_doc_keys = await self.full_docs_storage.filter_keys(
|
129 |
+
list(new_docs.keys())
|
130 |
+
)
|
131 |
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
|
132 |
if len(new_docs) == 0:
|
133 |
logger.warning("All docs are already in the storage")
|
|
|
137 |
cur_index = 1
|
138 |
doc_number = len(new_docs)
|
139 |
async for doc_key, doc in tqdm_async(
|
140 |
+
new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
|
141 |
+
):
|
142 |
chunks = {
|
143 |
compute_content_hash(dp["content"], prefix="chunk-"): {
|
144 |
**dp,
|
145 |
+
"full_doc_id": doc_key,
|
146 |
+
}
|
147 |
+
for dp in self.tokenizer_instance.chunk_by_token_size(
|
148 |
+
doc["content"], self.chunk_overlap_size, self.chunk_size
|
149 |
+
)
|
150 |
}
|
151 |
inserting_chunks.update(chunks)
|
152 |
|
153 |
if self.progress_bar is not None:
|
154 |
+
self.progress_bar(cur_index / doc_number, f"Chunking {doc_key}")
|
|
|
|
|
155 |
cur_index += 1
|
156 |
|
157 |
+
_add_chunk_keys = await self.text_chunks_storage.filter_keys(
|
158 |
+
list(inserting_chunks.keys())
|
159 |
+
)
|
160 |
+
inserting_chunks = {
|
161 |
+
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
|
162 |
+
}
|
163 |
elif data_type == "chunked":
|
164 |
assert isinstance(data, list) and isinstance(data[0], list)
|
165 |
new_docs = {
|
166 |
+
compute_content_hash("".join(chunk["content"]), prefix="doc-"): {
|
167 |
+
"content": "".join(chunk["content"])
|
168 |
+
}
|
169 |
+
for doc in data
|
170 |
+
for chunk in doc
|
171 |
}
|
172 |
+
_add_doc_keys = await self.full_docs_storage.filter_keys(
|
173 |
+
list(new_docs.keys())
|
174 |
+
)
|
175 |
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
|
176 |
if len(new_docs) == 0:
|
177 |
logger.warning("All docs are already in the storage")
|
178 |
return {}
|
179 |
logger.info("[New Docs] inserting %d docs", len(new_docs))
|
180 |
+
async for doc in tqdm_async(
|
181 |
+
data, desc="[1/4]Chunking documents", unit="doc"
|
182 |
+
):
|
183 |
+
doc_str = "".join([chunk["content"] for chunk in doc])
|
184 |
for chunk in doc:
|
185 |
+
chunk_key = compute_content_hash(chunk["content"], prefix="chunk-")
|
186 |
inserting_chunks[chunk_key] = {
|
187 |
**chunk,
|
188 |
+
"full_doc_id": compute_content_hash(doc_str, prefix="doc-"),
|
189 |
}
|
190 |
+
_add_chunk_keys = await self.text_chunks_storage.filter_keys(
|
191 |
+
list(inserting_chunks.keys())
|
192 |
+
)
|
193 |
+
inserting_chunks = {
|
194 |
+
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
|
195 |
+
}
|
196 |
+
else:
|
197 |
+
raise ValueError(f"Unknown data type: {data_type}")
|
198 |
|
199 |
await self.full_docs_storage.upsert(new_docs)
|
200 |
await self.text_chunks_storage.upsert(inserting_chunks)
|
201 |
|
202 |
return inserting_chunks
|
203 |
|
204 |
+
def insert(self):
    """Synchronous wrapper around :meth:`async_insert`."""
    event_loop = create_event_loop()
    event_loop.run_until_complete(self.async_insert())
|
207 |
|
208 |
+
async def async_insert(self):
|
209 |
"""
|
|
|
210 |
insert chunks into the graph
|
211 |
"""
|
212 |
|
213 |
+
input_file = self.config["input_file"]
|
214 |
+
data_type = self.config["input_data_type"]
|
215 |
+
data = read_file(input_file)
|
216 |
+
|
217 |
inserting_chunks = await self.async_split_chunks(data, data_type)
|
218 |
|
219 |
if len(inserting_chunks) == 0:
|
|
|
226 |
llm_client=self.synthesizer_llm_client,
|
227 |
kg_instance=self.graph_storage,
|
228 |
tokenizer_instance=self.tokenizer_instance,
|
229 |
+
chunks=[
|
230 |
+
Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
|
231 |
+
],
|
232 |
+
progress_bar=self.progress_bar,
|
233 |
)
|
234 |
if not _add_entities_and_relations:
|
235 |
logger.warning("No entities or relations extracted")
|
236 |
return
|
237 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
await self._insert_done()
|
239 |
|
240 |
async def _insert_done(self):
    """Flush every storage backend touched by insertion, concurrently."""
    storages = (
        self.full_docs_storage,
        self.text_chunks_storage,
        self.graph_storage,
        self.search_storage,
    )
    # Skip unset backends; run all index-done callbacks in parallel.
    callbacks = [
        cast(StorageNameSpace, storage).index_done_callback()
        for storage in storages
        if storage is not None
    ]
    await asyncio.gather(*callbacks)
|
252 |
|
253 |
+
def search(self):
    """Synchronous wrapper around :meth:`async_search`."""
    event_loop = create_event_loop()
    event_loop.run_until_complete(self.async_search())

async def async_search(self):
    """Search external sources for graph entities and ingest the results."""
    enabled = self.search_config["enabled"]
    logger.info("Search is %s", "enabled" if enabled else "disabled")
    if enabled:
        logger.info(
            "[Search] %s ...", ", ".join(self.search_config["search_types"])
        )
        all_nodes = await self.graph_storage.get_all_nodes()
        all_nodes_names = [node[0] for node in all_nodes]
        # NOTE(review): entity names are filtered against full_docs_storage;
        # confirm this is intentional rather than search_storage.
        new_search_entities = await self.full_docs_storage.filter_keys(
            all_nodes_names
        )
        logger.info(
            "[Search] Found %d entities to search", len(new_search_entities)
        )
        _add_search_data = await search_all(
            search_types=self.search_config["search_types"],
            search_entities=new_search_entities,
        )
        if _add_search_data:
            await self.search_storage.upsert(_add_search_data)
            logger.info("[Search] %d entities searched", len(_add_search_data))

        # Format search results for inserting
        search_results = []
        for _, search_data in _add_search_data.items():
            search_results.extend(
                {"content": search_data[key]} for key in list(search_data.keys())
            )
        # TODO: fix insert after search
        await self.async_insert()
|
292 |
|
293 |
+
def quiz(self):
    """Synchronous wrapper around :meth:`async_quiz`."""
    event_loop = create_event_loop()
    event_loop.run_until_complete(self.async_quiz())

async def async_quiz(self):
    """Generate rephrased quiz statements from the graph and persist them."""
    sample_budget = self.config["quiz_and_judge_strategy"]["quiz_samples"]
    await quiz(
        self.synthesizer_llm_client,
        self.graph_storage,
        self.rephrase_storage,
        sample_budget,
    )
    await self.rephrase_storage.index_done_callback()
|
306 |
|
307 |
+
def judge(self):
    """Synchronous wrapper around :meth:`async_judge`."""
    event_loop = create_event_loop()
    event_loop.run_until_complete(self.async_judge())

async def async_judge(self):
    """Let the trainee model judge quiz statements; persist relation updates."""
    re_judge = self.config["quiz_and_judge_strategy"]["re_judge"]
    updated_relations = await judge_statement(
        self.trainee_llm_client,
        self.graph_storage,
        self.rephrase_storage,
        re_judge,
    )
    await updated_relations.index_done_callback()
|
320 |
|
321 |
def traverse(self):
|
|
|
323 |
loop.run_until_complete(self.async_traverse())
|
324 |
|
325 |
async def async_traverse(self):
    """Traverse the knowledge graph and generate QA data.

    The traversal variant is selected by ``config["output_data_type"]``
    and the results are formatted per ``config["output_data_format"]``.

    Raises:
        ValueError: if ``output_data_type`` is not one of
            ``atomic`` / ``multi_hop`` / ``aggregated``.
    """
    output_data_type = self.config["output_data_type"]

    # Dispatch table replaces three duplicated if/elif branches — every
    # traversal variant takes exactly the same arguments.
    traverse_fns = {
        "atomic": traverse_graph_atomically,
        "multi_hop": traverse_graph_for_multi_hop,
        "aggregated": traverse_graph_by_edge,
    }
    if output_data_type not in traverse_fns:
        # Original message said "qa_form", which is not the config key used.
        raise ValueError(f"Unknown output_data_type: {output_data_type}")

    results = await traverse_fns[output_data_type](
        self.synthesizer_llm_client,
        self.tokenizer_instance,
        self.graph_storage,
        self.traverse_strategy,
        self.text_chunks_storage,
        self.progress_bar,
    )

    results = format_generation_results(
        results, output_data_format=self.config["output_data_format"]
    )

    await self.qa_storage.upsert(results)
    await self.qa_storage.index_done_callback()
|
364 |
+
|
365 |
+
def generate_reasoning(self, method_params):
    """Synchronous wrapper around :meth:`async_generate_reasoning`."""
    event_loop = create_event_loop()
    event_loop.run_until_complete(self.async_generate_reasoning(method_params))

async def async_generate_reasoning(self, method_params):
    """Generate chain-of-thought QA data from the knowledge graph."""
    raw_results = await generate_cot(
        self.graph_storage,
        self.synthesizer_llm_client,
        method_params=method_params,
    )
    formatted = format_generation_results(
        raw_results, output_data_format=self.config["output_data_format"]
    )
    await self.qa_storage.upsert(formatted)
    await self.qa_storage.index_done_callback()
|
382 |
|
|
|
387 |
async def async_clear(self):
    """Drop every persisted artifact so the next run starts from scratch."""
    for droppable in (
        self.full_docs_storage,
        self.text_chunks_storage,
        self.search_storage,
    ):
        await droppable.drop()
    # NOTE(review): graph storage exposes clear() rather than drop();
    # presumably equivalent — confirm against NetworkXStorage.
    await self.graph_storage.clear()
    await self.rephrase_storage.drop()
    await self.qa_storage.drop()
|
graphgen/models/__init__.py
CHANGED
@@ -1,22 +1,20 @@
|
|
1 |
-
from .
|
2 |
-
from .text.text_pair import TextPair
|
3 |
-
|
4 |
-
from .llm.topk_token_model import Token, TopkTokenModel
|
5 |
-
from .llm.openai_model import OpenAIModel
|
6 |
-
from .llm.tokenizer import Tokenizer
|
7 |
-
|
8 |
-
from .storage.networkx_storage import NetworkXStorage
|
9 |
-
from .storage.json_storage import JsonKVStorage
|
10 |
-
|
11 |
-
from .search.wiki_search import WikiSearch
|
12 |
-
|
13 |
from .evaluate.length_evaluator import LengthEvaluator
|
14 |
from .evaluate.mtld_evaluator import MTLDEvaluator
|
15 |
from .evaluate.reward_evaluator import RewardEvaluator
|
16 |
from .evaluate.uni_evaluator import UniEvaluator
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
from .strategy.travserse_strategy import TraverseStrategy
|
19 |
-
|
|
|
20 |
|
21 |
__all__ = [
|
22 |
# llm models
|
@@ -28,8 +26,12 @@ __all__ = [
|
|
28 |
"Chunk",
|
29 |
"NetworkXStorage",
|
30 |
"JsonKVStorage",
|
|
|
31 |
# search models
|
32 |
"WikiSearch",
|
|
|
|
|
|
|
33 |
# evaluate models
|
34 |
"TextPair",
|
35 |
"LengthEvaluator",
|
@@ -38,4 +40,6 @@ __all__ = [
|
|
38 |
"UniEvaluator",
|
39 |
# strategy models
|
40 |
"TraverseStrategy",
|
|
|
|
|
41 |
]
|
|
|
1 |
+
from .community.community_detector import CommunityDetector
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from .evaluate.length_evaluator import LengthEvaluator
|
3 |
from .evaluate.mtld_evaluator import MTLDEvaluator
|
4 |
from .evaluate.reward_evaluator import RewardEvaluator
|
5 |
from .evaluate.uni_evaluator import UniEvaluator
|
6 |
+
from .llm.openai_model import OpenAIModel
|
7 |
+
from .llm.tokenizer import Tokenizer
|
8 |
+
from .llm.topk_token_model import Token, TopkTokenModel
|
9 |
+
from .search.db.uniprot_search import UniProtSearch
|
10 |
+
from .search.kg.wiki_search import WikiSearch
|
11 |
+
from .search.web.bing_search import BingSearch
|
12 |
+
from .search.web.google_search import GoogleSearch
|
13 |
+
from .storage.json_storage import JsonKVStorage, JsonListStorage
|
14 |
+
from .storage.networkx_storage import NetworkXStorage
|
15 |
from .strategy.travserse_strategy import TraverseStrategy
|
16 |
+
from .text.chunk import Chunk
|
17 |
+
from .text.text_pair import TextPair
|
18 |
|
19 |
__all__ = [
|
20 |
# llm models
|
|
|
26 |
"Chunk",
|
27 |
"NetworkXStorage",
|
28 |
"JsonKVStorage",
|
29 |
+
"JsonListStorage",
|
30 |
# search models
|
31 |
"WikiSearch",
|
32 |
+
"GoogleSearch",
|
33 |
+
"BingSearch",
|
34 |
+
"UniProtSearch",
|
35 |
# evaluate models
|
36 |
"TextPair",
|
37 |
"LengthEvaluator",
|
|
|
40 |
"UniEvaluator",
|
41 |
# strategy models
|
42 |
"TraverseStrategy",
|
43 |
+
# community models
|
44 |
+
"CommunityDetector",
|
45 |
]
|
graphgen/models/community/__init__.py
ADDED
File without changes
|
graphgen/models/community/community_detector.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Any, Dict, List
|
4 |
+
|
5 |
+
from graphgen.models.storage.networkx_storage import NetworkXStorage
|
6 |
+
|
7 |
+
|
8 |
+
@dataclass
class CommunityDetector:
    """Partitions a knowledge graph into communities (Leiden only, for now)."""

    graph_storage: NetworkXStorage = None  # backing graph store
    method: str = "leiden"                 # detection algorithm name
    method_params: Dict[str, Any] = None   # kwargs forwarded to the algorithm

    async def detect_communities(self) -> Dict[str, int]:
        """Run the configured algorithm; returns node name -> community id."""
        if self.method == "leiden":
            return await self._leiden_communities(**self.method_params or {})
        raise ValueError(f"Unknown community detection method: {self.method}")

    async def get_graph(self):
        """Fetch the underlying networkx graph from storage."""
        return await self.graph_storage.get_graph()

    async def _leiden_communities(
        self, max_size: int = None, **kwargs
    ) -> Dict[str, int]:
        """
        Detect communities using the Leiden algorithm.
        If max_size is given, any community larger than max_size will be split
        into smaller sub-communities each having at most max_size nodes.
        """
        # Heavy optional deps are imported lazily so the module stays cheap.
        import igraph as ig
        import networkx as nx
        from leidenalg import ModularityVertexPartition, find_partition

        nx_graph = await self.get_graph()
        nx_graph.remove_nodes_from(list(nx.isolates(nx_graph)))

        ig_graph = ig.Graph.TupleList(nx_graph.edges(), directed=False)

        seed = kwargs.get("random_seed", 42)
        restrict_to_lcc = kwargs.get("use_lcc", False)

        node_to_community: Dict[str, int] = {}
        if restrict_to_lcc:
            # Partition only the largest connected component.
            giant = ig_graph.components().giant()
            partition = find_partition(giant, ModularityVertexPartition, seed=seed)
            for community_id, members in enumerate(partition):
                for vertex in members:
                    node_to_community[giant.vs[vertex]["name"]] = community_id
        else:
            # Partition each connected component separately, offsetting ids so
            # communities from different components never collide.
            next_id = 0
            for component in ig_graph.components():
                component_graph = ig_graph.induced_subgraph(component)
                partition = find_partition(
                    component_graph, ModularityVertexPartition, seed=seed
                )
                for community_id, members in enumerate(partition):
                    for vertex in members:
                        node_name = component_graph.vs[vertex]["name"]
                        node_to_community[node_name] = community_id + next_id
                next_id += len(partition)

        # Split large communities only when a positive max_size is requested.
        if max_size is None or max_size <= 0:
            return node_to_community

        return await self._split_communities(node_to_community, max_size)

    @staticmethod
    async def _split_communities(
        communities: Dict[str, int], max_size: int
    ) -> Dict[str, int]:
        """
        Split communities larger than max_size into smaller sub-communities.
        """
        members_by_cid: Dict[int, List[str]] = defaultdict(list)
        for node, cid in communities.items():
            members_by_cid[cid].append(node)

        # Chunking a community of <= max_size nodes yields exactly one chunk,
        # so one uniform loop covers both the "keep" and "split" cases.
        relabelled: Dict[str, int] = {}
        fresh_cid = 0
        for nodes in members_by_cid.values():
            for start in range(0, len(nodes), max_size):
                for node in nodes[start : start + max_size]:
                    relabelled[node] = fresh_cid
                fresh_cid += 1

        return relabelled
|
graphgen/models/llm/openai_model.py
CHANGED
@@ -1,18 +1,21 @@
|
|
1 |
import math
|
|
|
2 |
from dataclasses import dataclass, field
|
3 |
-
from typing import
|
|
|
4 |
import openai
|
5 |
-
from openai import
|
6 |
from tenacity import (
|
7 |
retry,
|
|
|
8 |
stop_after_attempt,
|
9 |
wait_exponential,
|
10 |
-
retry_if_exception_type,
|
11 |
)
|
12 |
|
13 |
-
from graphgen.models.llm.topk_token_model import TopkTokenModel, Token
|
14 |
-
from graphgen.models.llm.tokenizer import Tokenizer
|
15 |
from graphgen.models.llm.limitter import RPM, TPM
|
|
|
|
|
|
|
16 |
|
17 |
def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
|
18 |
token_logprobs = response.choices[0].logprobs.content
|
@@ -20,13 +23,23 @@ def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
|
|
20 |
for token_prob in token_logprobs:
|
21 |
prob = math.exp(token_prob.logprob)
|
22 |
candidate_tokens = [
|
23 |
-
Token(t.token, math.exp(t.logprob))
|
24 |
-
for t in token_prob.top_logprobs
|
25 |
]
|
26 |
token = Token(token_prob.token, prob, top_candidates=candidate_tokens)
|
27 |
tokens.append(token)
|
28 |
return tokens
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
@dataclass
|
31 |
class OpenAIModel(TopkTokenModel):
|
32 |
model_name: str = "gpt-4o-mini"
|
@@ -42,12 +55,13 @@ class OpenAIModel(TopkTokenModel):
|
|
42 |
rpm: RPM = field(default_factory=lambda: RPM(rpm=1000))
|
43 |
tpm: TPM = field(default_factory=lambda: TPM(tpm=50000))
|
44 |
|
|
|
45 |
|
46 |
def __post_init__(self):
|
47 |
assert self.api_key is not None, "Please provide api key to access openai api."
|
48 |
-
|
49 |
-
self.api_key
|
50 |
-
|
51 |
|
52 |
def _pre_generate(self, text: str, history: List[str]) -> Dict:
|
53 |
kwargs = {
|
@@ -69,16 +83,19 @@ class OpenAIModel(TopkTokenModel):
|
|
69 |
assert len(history) % 2 == 0, "History should have even number of elements."
|
70 |
messages = history + messages
|
71 |
|
72 |
-
kwargs[
|
73 |
return kwargs
|
74 |
|
75 |
-
|
76 |
@retry(
|
77 |
stop=stop_after_attempt(5),
|
78 |
wait=wait_exponential(multiplier=1, min=4, max=10),
|
79 |
-
retry=retry_if_exception_type(
|
|
|
|
|
80 |
)
|
81 |
-
async def generate_topk_per_token(
|
|
|
|
|
82 |
kwargs = self._pre_generate(text, history)
|
83 |
if self.topk_per_token > 0:
|
84 |
kwargs["logprobs"] = True
|
@@ -87,9 +104,8 @@ class OpenAIModel(TopkTokenModel):
|
|
87 |
# Limit max_tokens to 1 to avoid long completions
|
88 |
kwargs["max_tokens"] = 1
|
89 |
|
90 |
-
completion = await self.client.chat.completions.create(
|
91 |
-
model=self.model_name,
|
92 |
-
**kwargs
|
93 |
)
|
94 |
|
95 |
tokens = get_top_response_tokens(completion)
|
@@ -99,32 +115,41 @@ class OpenAIModel(TopkTokenModel):
|
|
99 |
@retry(
|
100 |
stop=stop_after_attempt(5),
|
101 |
wait=wait_exponential(multiplier=1, min=4, max=10),
|
102 |
-
retry=retry_if_exception_type(
|
|
|
|
|
103 |
)
|
104 |
-
async def generate_answer(
|
|
|
|
|
105 |
kwargs = self._pre_generate(text, history)
|
106 |
kwargs["temperature"] = temperature
|
107 |
|
108 |
prompt_tokens = 0
|
109 |
-
for message in kwargs[
|
110 |
-
prompt_tokens += len(
|
111 |
-
|
|
|
|
|
112 |
|
113 |
if self.request_limit:
|
114 |
await self.rpm.wait(silent=True)
|
115 |
await self.tpm.wait(estimated_tokens, silent=True)
|
116 |
|
117 |
-
completion = await self.client.chat.completions.create(
|
118 |
-
model=self.model_name,
|
119 |
-
**kwargs
|
120 |
)
|
121 |
if hasattr(completion, "usage"):
|
122 |
-
self.token_usage.append(
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
130 |
raise NotImplementedError
|
|
|
1 |
import math
|
2 |
+
import re
|
3 |
from dataclasses import dataclass, field
|
4 |
+
from typing import Dict, List, Optional
|
5 |
+
|
6 |
import openai
|
7 |
+
from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError
|
8 |
from tenacity import (
|
9 |
retry,
|
10 |
+
retry_if_exception_type,
|
11 |
stop_after_attempt,
|
12 |
wait_exponential,
|
|
|
13 |
)
|
14 |
|
|
|
|
|
15 |
from graphgen.models.llm.limitter import RPM, TPM
|
16 |
+
from graphgen.models.llm.tokenizer import Tokenizer
|
17 |
+
from graphgen.models.llm.topk_token_model import Token, TopkTokenModel
|
18 |
+
|
19 |
|
20 |
def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
|
21 |
token_logprobs = response.choices[0].logprobs.content
|
|
|
23 |
for token_prob in token_logprobs:
|
24 |
prob = math.exp(token_prob.logprob)
|
25 |
candidate_tokens = [
|
26 |
+
Token(t.token, math.exp(t.logprob)) for t in token_prob.top_logprobs
|
|
|
27 |
]
|
28 |
token = Token(token_prob.token, prob, top_candidates=candidate_tokens)
|
29 |
tokens.append(token)
|
30 |
return tokens
|
31 |
|
32 |
+
|
33 |
+
def filter_think_tags(text: str) -> str:
    """
    Strip ``<think>...</think>`` spans from *text*.

    Falls back to the stripped original text when removing the tags would
    leave nothing (e.g. the whole message is a single think block).
    """
    without_think = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
    if without_think:
        return without_think
    return text.strip()
|
41 |
+
|
42 |
+
|
43 |
@dataclass
|
44 |
class OpenAIModel(TopkTokenModel):
|
45 |
model_name: str = "gpt-4o-mini"
|
|
|
55 |
rpm: RPM = field(default_factory=lambda: RPM(rpm=1000))
|
56 |
tpm: TPM = field(default_factory=lambda: TPM(tpm=50000))
|
57 |
|
58 |
+
tokenizer_instance: Tokenizer = field(default_factory=Tokenizer)
|
59 |
|
60 |
def __post_init__(self):
    """Validate credentials and build the async OpenAI client.

    Raises:
        ValueError: if no API key was provided.
    """
    # Explicit raise instead of `assert`: assertions are stripped under
    # `python -O`, silently disabling this validation.
    if self.api_key is None:
        raise ValueError("Please provide api key to access openai api.")
    # `or "dummy"` keeps a falsy-but-not-None key (e.g. "") usable against
    # local OpenAI-compatible servers that ignore authentication.
    self.client = AsyncOpenAI(
        api_key=self.api_key or "dummy", base_url=self.base_url
    )
|
65 |
|
66 |
def _pre_generate(self, text: str, history: List[str]) -> Dict:
|
67 |
kwargs = {
|
|
|
83 |
assert len(history) % 2 == 0, "History should have even number of elements."
|
84 |
messages = history + messages
|
85 |
|
86 |
+
kwargs["messages"] = messages
|
87 |
return kwargs
|
88 |
|
|
|
89 |
@retry(
|
90 |
stop=stop_after_attempt(5),
|
91 |
wait=wait_exponential(multiplier=1, min=4, max=10),
|
92 |
+
retry=retry_if_exception_type(
|
93 |
+
(RateLimitError, APIConnectionError, APITimeoutError)
|
94 |
+
),
|
95 |
)
|
96 |
+
async def generate_topk_per_token(
|
97 |
+
self, text: str, history: Optional[List[str]] = None
|
98 |
+
) -> List[Token]:
|
99 |
kwargs = self._pre_generate(text, history)
|
100 |
if self.topk_per_token > 0:
|
101 |
kwargs["logprobs"] = True
|
|
|
104 |
# Limit max_tokens to 1 to avoid long completions
|
105 |
kwargs["max_tokens"] = 1
|
106 |
|
107 |
+
completion = await self.client.chat.completions.create( # pylint: disable=E1125
|
108 |
+
model=self.model_name, **kwargs
|
|
|
109 |
)
|
110 |
|
111 |
tokens = get_top_response_tokens(completion)
|
|
|
115 |
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def generate_answer(
    self, text: str, history: Optional[List[str]] = None, temperature: int = 0
) -> str:
    """Chat-complete *text* and return the reply with <think> blocks removed.

    Transient API failures are retried with exponential backoff; when
    ``request_limit`` is set, requests are throttled via the RPM/TPM limiters.
    """
    kwargs = self._pre_generate(text, history)
    kwargs["temperature"] = temperature

    # Estimate the request size so the tokens-per-minute limiter can budget it.
    prompt_tokens = sum(
        len(self.tokenizer_instance.encode_string(message["content"]))
        for message in kwargs["messages"]
    )
    estimated_tokens = prompt_tokens + kwargs["max_tokens"]

    if self.request_limit:
        await self.rpm.wait(silent=True)
        await self.tpm.wait(estimated_tokens, silent=True)

    completion = await self.client.chat.completions.create(  # pylint: disable=E1125
        model=self.model_name, **kwargs
    )
    if hasattr(completion, "usage"):
        usage = completion.usage
        self.token_usage.append(
            {
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
            }
        )
    return filter_think_tags(completion.choices[0].message.content)
|
151 |
+
|
152 |
+
async def generate_inputs_prob(
|
153 |
+
self, text: str, history: Optional[List[str]] = None
|
154 |
+
) -> List[Token]:
|
155 |
raise NotImplementedError
|
graphgen/models/search/db/__init__.py
ADDED
File without changes
|
graphgen/models/search/db/uniprot_search.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import requests
|
4 |
+
from fastapi import HTTPException
|
5 |
+
|
6 |
+
from graphgen.utils import logger
|
7 |
+
|
8 |
+
UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search"
|
9 |
+
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class UniProtSearch:
|
13 |
+
"""
|
14 |
+
UniProt Search client to search with UniProt.
|
15 |
+
1) Get the protein by accession number.
|
16 |
+
2) Search with keywords or protein names.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def get_entry(self, accession: str) -> dict:
|
20 |
+
"""
|
21 |
+
Get the UniProt entry by accession number(e.g., P04637).
|
22 |
+
"""
|
23 |
+
url = f"{UNIPROT_BASE}/{accession}.json"
|
24 |
+
return self._safe_get(url).json()
|
25 |
+
|
26 |
+
def search(
|
27 |
+
self,
|
28 |
+
query: str,
|
29 |
+
*,
|
30 |
+
size: int = 10,
|
31 |
+
cursor: str = None,
|
32 |
+
fields: list[str] = None,
|
33 |
+
) -> dict:
|
34 |
+
"""
|
35 |
+
Search UniProt with a query string.
|
36 |
+
:param query: The search query.
|
37 |
+
:param size: The number of results to return.
|
38 |
+
:param cursor: The cursor for pagination.
|
39 |
+
:param fields: The fields to return in the response.
|
40 |
+
:return: A dictionary containing the search results.
|
41 |
+
"""
|
42 |
+
params = {
|
43 |
+
"query": query,
|
44 |
+
"size": size,
|
45 |
+
}
|
46 |
+
if cursor:
|
47 |
+
params["cursor"] = cursor
|
48 |
+
if fields:
|
49 |
+
params["fields"] = ",".join(fields)
|
50 |
+
url = UNIPROT_BASE
|
51 |
+
return self._safe_get(url, params=params).json()
|
52 |
+
|
53 |
+
@staticmethod
|
54 |
+
def _safe_get(url: str, params: dict = None) -> requests.Response:
|
55 |
+
r = requests.get(
|
56 |
+
url,
|
57 |
+
params=params,
|
58 |
+
headers={"Accept": "application/json"},
|
59 |
+
timeout=10,
|
60 |
+
)
|
61 |
+
if not r.ok:
|
62 |
+
logger.error("Search engine error: %s", r.text)
|
63 |
+
raise HTTPException(r.status_code, "Search engine error.")
|
64 |
+
return r
|
graphgen/models/search/kg/__init__.py
ADDED
File without changes
|
graphgen/models/search/{wiki_search.py → kg/wiki_search.py}
RENAMED
@@ -1,8 +1,9 @@
|
|
1 |
-
from typing import List, Union
|
2 |
from dataclasses import dataclass
|
|
|
3 |
|
4 |
import wikipedia
|
5 |
from wikipedia import set_lang
|
|
|
6 |
from graphgen.utils import detect_main_language, logger
|
7 |
|
8 |
|
@@ -13,9 +14,9 @@ class WikiSearch:
|
|
13 |
assert language in ["en", "zh"], "Only support English and Chinese"
|
14 |
set_lang(language)
|
15 |
|
16 |
-
async def search(self, query: str) -> Union[List[str], None]:
|
17 |
self.set_language(detect_main_language(query))
|
18 |
-
return wikipedia.search(query)
|
19 |
|
20 |
async def summary(self, query: str) -> Union[str, None]:
|
21 |
self.set_language(detect_main_language(query))
|
|
|
|
|
1 |
from dataclasses import dataclass
|
2 |
+
from typing import List, Union
|
3 |
|
4 |
import wikipedia
|
5 |
from wikipedia import set_lang
|
6 |
+
|
7 |
from graphgen.utils import detect_main_language, logger
|
8 |
|
9 |
|
|
|
14 |
assert language in ["en", "zh"], "Only support English and Chinese"
|
15 |
set_lang(language)
|
16 |
|
17 |
+
async def search(self, query: str, num_results: int = 1) -> Union[List[str], None]:
|
18 |
self.set_language(detect_main_language(query))
|
19 |
+
return wikipedia.search(query, results=num_results, suggestion=False)
|
20 |
|
21 |
async def summary(self, query: str) -> Union[str, None]:
|
22 |
self.set_language(detect_main_language(query))
|
graphgen/models/search/web/__init__.py
ADDED
File without changes
|
graphgen/models/search/web/bing_search.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import requests
|
4 |
+
from fastapi import HTTPException
|
5 |
+
|
6 |
+
from graphgen.utils import logger
|
7 |
+
|
8 |
+
BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"
|
9 |
+
BING_MKT = "en-US"
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class BingSearch:
|
14 |
+
"""
|
15 |
+
Bing Search client to search with Bing.
|
16 |
+
"""
|
17 |
+
|
18 |
+
subscription_key: str
|
19 |
+
|
20 |
+
def search(self, query: str, num_results: int = 1):
|
21 |
+
"""
|
22 |
+
Search with Bing and return the contexts.
|
23 |
+
:param query: The search query.
|
24 |
+
:param num_results: The number of results to return.
|
25 |
+
:return: A list of search results.
|
26 |
+
"""
|
27 |
+
params = {"q": query, "mkt": BING_MKT, "count": num_results}
|
28 |
+
response = requests.get(
|
29 |
+
BING_SEARCH_V7_ENDPOINT,
|
30 |
+
headers={"Ocp-Apim-Subscription-Key": self.subscription_key},
|
31 |
+
params=params,
|
32 |
+
timeout=10,
|
33 |
+
)
|
34 |
+
if not response.ok:
|
35 |
+
logger.error("Search engine error: %s", response.text)
|
36 |
+
raise HTTPException(response.status_code, "Search engine error.")
|
37 |
+
json_content = response.json()
|
38 |
+
try:
|
39 |
+
contexts = json_content["webPages"]["value"][:num_results]
|
40 |
+
except KeyError:
|
41 |
+
logger.error("Error encountered: %s", json_content)
|
42 |
+
return []
|
43 |
+
return contexts
|
graphgen/models/search/web/google_search.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import requests
|
4 |
+
from fastapi import HTTPException
|
5 |
+
|
6 |
+
from graphgen.utils import logger
|
7 |
+
|
8 |
+
GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1"
|
9 |
+
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class GoogleSearch:
|
13 |
+
def __init__(self, subscription_key: str, cx: str):
|
14 |
+
"""
|
15 |
+
Initialize the Google Search client with the subscription key and custom search engine ID.
|
16 |
+
:param subscription_key: Your Google API subscription key.
|
17 |
+
:param cx: Your custom search engine ID.
|
18 |
+
"""
|
19 |
+
self.subscription_key = subscription_key
|
20 |
+
self.cx = cx
|
21 |
+
|
22 |
+
def search(self, query: str, num_results: int = 1):
|
23 |
+
"""
|
24 |
+
Search with Google and return the contexts.
|
25 |
+
:param query: The search query.
|
26 |
+
:param num_results: The number of results to return.
|
27 |
+
:return: A list of search results.
|
28 |
+
"""
|
29 |
+
params = {
|
30 |
+
"key": self.subscription_key,
|
31 |
+
"cx": self.cx,
|
32 |
+
"q": query,
|
33 |
+
"num": num_results,
|
34 |
+
}
|
35 |
+
response = requests.get(GOOGLE_SEARCH_ENDPOINT, params=params, timeout=10)
|
36 |
+
if not response.ok:
|
37 |
+
logger.error("Search engine error: %s", response.text)
|
38 |
+
raise HTTPException(response.status_code, "Search engine error.")
|
39 |
+
json_content = response.json()
|
40 |
+
try:
|
41 |
+
contexts = json_content["items"][:num_results]
|
42 |
+
except KeyError:
|
43 |
+
logger.error("Error encountered: %s", json_content)
|
44 |
+
return []
|
45 |
+
return contexts
|
graphgen/models/storage/base_storage.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
from dataclasses import dataclass
|
2 |
-
from typing import
|
|
|
3 |
from graphgen.models.embed.embedding import EmbeddingFunc
|
4 |
|
5 |
T = TypeVar("T")
|
6 |
|
|
|
7 |
@dataclass
|
8 |
class StorageNameSpace:
|
9 |
working_dir: str = None
|
@@ -17,9 +19,25 @@ class StorageNameSpace:
|
|
17 |
|
18 |
|
19 |
@dataclass
|
20 |
-
class
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
|
|
|
|
23 |
async def all_keys(self) -> list[str]:
|
24 |
raise NotImplementedError
|
25 |
|
@@ -41,6 +59,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
|
|
41 |
async def drop(self):
|
42 |
raise NotImplementedError
|
43 |
|
|
|
44 |
@dataclass
|
45 |
class BaseGraphStorage(StorageNameSpace):
|
46 |
embedding_func: EmbeddingFunc = None
|
@@ -71,7 +90,9 @@ class BaseGraphStorage(StorageNameSpace):
|
|
71 |
) -> Union[dict, None]:
|
72 |
raise NotImplementedError
|
73 |
|
74 |
-
async def update_edge(
|
|
|
|
|
75 |
raise NotImplementedError
|
76 |
|
77 |
async def get_all_edges(self) -> Union[list[dict], None]:
|
|
|
1 |
from dataclasses import dataclass
|
2 |
+
from typing import Generic, TypeVar, Union
|
3 |
+
|
4 |
from graphgen.models.embed.embedding import EmbeddingFunc
|
5 |
|
6 |
T = TypeVar("T")
|
7 |
|
8 |
+
|
9 |
@dataclass
|
10 |
class StorageNameSpace:
|
11 |
working_dir: str = None
|
|
|
19 |
|
20 |
|
21 |
@dataclass
|
22 |
+
class BaseListStorage(Generic[T], StorageNameSpace):
|
23 |
+
async def all_items(self) -> list[T]:
|
24 |
+
raise NotImplementedError
|
25 |
+
|
26 |
+
async def get_by_index(self, index: int) -> Union[T, None]:
|
27 |
+
raise NotImplementedError
|
28 |
+
|
29 |
+
async def append(self, data: T):
|
30 |
+
raise NotImplementedError
|
31 |
+
|
32 |
+
async def upsert(self, data: list[T]):
|
33 |
+
raise NotImplementedError
|
34 |
+
|
35 |
+
async def drop(self):
|
36 |
+
raise NotImplementedError
|
37 |
+
|
38 |
|
39 |
+
@dataclass
|
40 |
+
class BaseKVStorage(Generic[T], StorageNameSpace):
|
41 |
async def all_keys(self) -> list[str]:
|
42 |
raise NotImplementedError
|
43 |
|
|
|
59 |
async def drop(self):
|
60 |
raise NotImplementedError
|
61 |
|
62 |
+
|
63 |
@dataclass
|
64 |
class BaseGraphStorage(StorageNameSpace):
|
65 |
embedding_func: EmbeddingFunc = None
|
|
|
90 |
) -> Union[dict, None]:
|
91 |
raise NotImplementedError
|
92 |
|
93 |
+
async def update_edge(
|
94 |
+
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
95 |
+
):
|
96 |
raise NotImplementedError
|
97 |
|
98 |
async def get_all_edges(self) -> Union[list[dict], None]:
|
graphgen/models/storage/json_storage.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
import os
|
2 |
-
|
3 |
from dataclasses import dataclass
|
4 |
-
|
5 |
-
from graphgen.models.storage.base_storage import BaseKVStorage
|
|
|
6 |
|
7 |
|
8 |
@dataclass
|
@@ -49,3 +49,39 @@ class JsonKVStorage(BaseKVStorage):
|
|
49 |
|
50 |
async def drop(self):
|
51 |
self._data = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
|
|
2 |
from dataclasses import dataclass
|
3 |
+
|
4 |
+
from graphgen.models.storage.base_storage import BaseKVStorage, BaseListStorage
|
5 |
+
from graphgen.utils import load_json, logger, write_json
|
6 |
|
7 |
|
8 |
@dataclass
|
|
|
49 |
|
50 |
async def drop(self):
|
51 |
self._data = {}
|
52 |
+
|
53 |
+
|
54 |
+
@dataclass
|
55 |
+
class JsonListStorage(BaseListStorage):
|
56 |
+
_data: list = None
|
57 |
+
|
58 |
+
def __post_init__(self):
|
59 |
+
self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
|
60 |
+
self._data = load_json(self._file_name) or []
|
61 |
+
logger.info("Load List %s with %d data", self.namespace, len(self._data))
|
62 |
+
|
63 |
+
@property
|
64 |
+
def data(self):
|
65 |
+
return self._data
|
66 |
+
|
67 |
+
async def all_items(self) -> list:
|
68 |
+
return self._data
|
69 |
+
|
70 |
+
async def index_done_callback(self):
|
71 |
+
write_json(self._data, self._file_name)
|
72 |
+
|
73 |
+
async def get_by_index(self, index: int):
|
74 |
+
if index < 0 or index >= len(self._data):
|
75 |
+
return None
|
76 |
+
return self._data[index]
|
77 |
+
|
78 |
+
async def append(self, data):
|
79 |
+
self._data.append(data)
|
80 |
+
|
81 |
+
async def upsert(self, data: list):
|
82 |
+
left_data = [d for d in data if d not in self._data]
|
83 |
+
self._data.extend(left_data)
|
84 |
+
return left_data
|
85 |
+
|
86 |
+
async def drop(self):
|
87 |
+
self._data = []
|
graphgen/models/vis/__init__.py
ADDED
File without changes
|
graphgen/models/vis/community_visualizer.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Dict
|
3 |
+
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import networkx as nx
|
6 |
+
|
7 |
+
|
8 |
+
@dataclass
|
9 |
+
class Visualizer:
|
10 |
+
"""
|
11 |
+
Class for visualizing graphs using NetworkX and Matplotlib.
|
12 |
+
"""
|
13 |
+
|
14 |
+
graph: nx.Graph = None
|
15 |
+
communities: Dict[str, int] = None
|
16 |
+
layout: str = "spring"
|
17 |
+
max_nodes: int = 1000
|
18 |
+
node_size: int = 10
|
19 |
+
alpha: float = 0.6
|
20 |
+
|
21 |
+
def visualize(self, save_path: str = None):
|
22 |
+
n = self.graph.number_of_nodes()
|
23 |
+
if self.layout == "spring":
|
24 |
+
k = max(0.1, 1.0 / (n**0.5))
|
25 |
+
pos = nx.spring_layout(self.graph, k=k, seed=42)
|
26 |
+
else:
|
27 |
+
raise ValueError(f"Unknown layout: {self.layout}")
|
28 |
+
|
29 |
+
plt.figure(figsize=(10, 10))
|
30 |
+
|
31 |
+
node_colors = [self.communities.get(node, 0) for node in self.graph.nodes()]
|
32 |
+
|
33 |
+
nx.draw_networkx_nodes(
|
34 |
+
self.graph,
|
35 |
+
pos,
|
36 |
+
node_size=self.node_size,
|
37 |
+
node_color=node_colors,
|
38 |
+
cmap=plt.cm.tab20,
|
39 |
+
alpha=self.alpha,
|
40 |
+
)
|
41 |
+
nx.draw_networkx_edges(self.graph, pos, alpha=0.3, width=0.2)
|
42 |
+
plt.axis("off")
|
43 |
+
|
44 |
+
if save_path:
|
45 |
+
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
46 |
+
print("Saved to", save_path)
|
47 |
+
else:
|
48 |
+
plt.show()
|
graphgen/operators/__init__.py
CHANGED
@@ -1,16 +1,22 @@
|
|
1 |
-
from .
|
|
|
|
|
|
|
|
|
2 |
from .quiz import quiz
|
3 |
-
from .
|
4 |
-
|
5 |
-
|
|
|
|
|
6 |
|
7 |
__all__ = [
|
8 |
"extract_kg",
|
9 |
"quiz",
|
10 |
"judge_statement",
|
11 |
-
"
|
12 |
-
"search_wikipedia",
|
13 |
"traverse_graph_by_edge",
|
14 |
"traverse_graph_atomically",
|
15 |
-
"traverse_graph_for_multi_hop"
|
|
|
16 |
]
|
|
|
1 |
+
from graphgen.operators.generate.generate_cot import generate_cot
|
2 |
+
from graphgen.operators.kg.extract_kg import extract_kg
|
3 |
+
from graphgen.operators.search.search_all import search_all
|
4 |
+
|
5 |
+
from .judge import judge_statement
|
6 |
from .quiz import quiz
|
7 |
+
from .traverse_graph import (
|
8 |
+
traverse_graph_atomically,
|
9 |
+
traverse_graph_by_edge,
|
10 |
+
traverse_graph_for_multi_hop,
|
11 |
+
)
|
12 |
|
13 |
__all__ = [
|
14 |
"extract_kg",
|
15 |
"quiz",
|
16 |
"judge_statement",
|
17 |
+
"search_all",
|
|
|
18 |
"traverse_graph_by_edge",
|
19 |
"traverse_graph_atomically",
|
20 |
+
"traverse_graph_for_multi_hop",
|
21 |
+
"generate_cot",
|
22 |
]
|
graphgen/operators/generate/__init__.py
ADDED
File without changes
|
graphgen/operators/generate/generate_cot.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
from typing import Dict, List, Tuple
|
3 |
+
|
4 |
+
from tqdm.asyncio import tqdm as tqdm_async
|
5 |
+
|
6 |
+
from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIModel
|
7 |
+
from graphgen.templates import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
|
8 |
+
from graphgen.utils import compute_content_hash, detect_main_language
|
9 |
+
|
10 |
+
|
11 |
+
async def generate_cot(
|
12 |
+
graph_storage: NetworkXStorage,
|
13 |
+
synthesizer_llm_client: OpenAIModel,
|
14 |
+
method_params: Dict = None,
|
15 |
+
):
|
16 |
+
method = method_params.get("method", "leiden")
|
17 |
+
detector = CommunityDetector(
|
18 |
+
graph_storage=graph_storage, method=method, method_params=method_params
|
19 |
+
)
|
20 |
+
|
21 |
+
results = await detector.detect_communities()
|
22 |
+
|
23 |
+
# Convert results to a format suitable for summarization
|
24 |
+
communities = {}
|
25 |
+
for node, community_id in results.items():
|
26 |
+
if community_id not in communities:
|
27 |
+
communities[community_id] = []
|
28 |
+
communities[community_id].append(node)
|
29 |
+
|
30 |
+
if not communities:
|
31 |
+
return {}
|
32 |
+
|
33 |
+
semaphore = asyncio.Semaphore(value=1000)
|
34 |
+
|
35 |
+
async def _generate_from_single_community(
|
36 |
+
c_id: int, nodes: List[str]
|
37 |
+
) -> Tuple[int, Tuple[str, str, str]]:
|
38 |
+
"""Summarize a single community."""
|
39 |
+
async with semaphore:
|
40 |
+
entities: List[str] = []
|
41 |
+
relationships: List[str] = []
|
42 |
+
|
43 |
+
for n in nodes:
|
44 |
+
node_data = await graph_storage.get_node(n)
|
45 |
+
if node_data is not None:
|
46 |
+
entities.append(f"({n}: {node_data.get('description')})")
|
47 |
+
|
48 |
+
edges = await graph_storage.get_node_edges(n)
|
49 |
+
for edge in edges:
|
50 |
+
target = edge[1]
|
51 |
+
if target in nodes:
|
52 |
+
edge_data = await graph_storage.get_edge(n, target)
|
53 |
+
relationships.append(
|
54 |
+
f"({n}) - [{edge_data['description']}] -> ({target})"
|
55 |
+
)
|
56 |
+
|
57 |
+
entities_str = "\n".join(entities)
|
58 |
+
relationships_str = "\n".join(relationships)
|
59 |
+
|
60 |
+
language = (
|
61 |
+
"English"
|
62 |
+
if detect_main_language(entities_str + relationships_str) == "en"
|
63 |
+
else "Chinese"
|
64 |
+
)
|
65 |
+
|
66 |
+
prompt = COT_TEMPLATE_DESIGN_PROMPT[language]["TEMPLATE"].format(
|
67 |
+
entities=entities_str,
|
68 |
+
relationships=relationships_str,
|
69 |
+
)
|
70 |
+
|
71 |
+
cot_template = await synthesizer_llm_client.generate_answer(prompt)
|
72 |
+
|
73 |
+
if "问题:" in cot_template and "推理路径设计:" in cot_template:
|
74 |
+
question = cot_template.split("问题:")[1].split("推理路径设计:")[0].strip()
|
75 |
+
reasoning_path = cot_template.split("推理路径设计:")[1].strip()
|
76 |
+
elif (
|
77 |
+
"Question:" in cot_template and "Reasoning-Path Design:" in cot_template
|
78 |
+
):
|
79 |
+
question = (
|
80 |
+
cot_template.split("Question:")[1]
|
81 |
+
.split("Reasoning-Path Design:")[0]
|
82 |
+
.strip()
|
83 |
+
)
|
84 |
+
reasoning_path = cot_template.split("Reasoning-Path Design:")[1].strip()
|
85 |
+
else:
|
86 |
+
raise ValueError("COT template format is incorrect.")
|
87 |
+
|
88 |
+
prompt = COT_GENERATION_PROMPT[language]["TEMPLATE"].format(
|
89 |
+
entities=entities_str,
|
90 |
+
relationships=relationships_str,
|
91 |
+
question=question,
|
92 |
+
reasoning_template=reasoning_path,
|
93 |
+
)
|
94 |
+
|
95 |
+
cot_answer = await synthesizer_llm_client.generate_answer(prompt)
|
96 |
+
|
97 |
+
return c_id, (question, reasoning_path, cot_answer)
|
98 |
+
|
99 |
+
cid_nodes = list(communities.items())
|
100 |
+
|
101 |
+
results: Dict = {}
|
102 |
+
async for coro in tqdm_async(
|
103 |
+
asyncio.as_completed(
|
104 |
+
[_generate_from_single_community(cid, nodes) for cid, nodes in cid_nodes]
|
105 |
+
),
|
106 |
+
total=len(cid_nodes),
|
107 |
+
desc="[Generating COT] Generating CoT data from communities",
|
108 |
+
unit="community",
|
109 |
+
):
|
110 |
+
cid, (q, r, a) = await coro
|
111 |
+
results[compute_content_hash(q)] = {
|
112 |
+
"question": q,
|
113 |
+
"reasoning_path": r,
|
114 |
+
"answer": a,
|
115 |
+
}
|
116 |
+
|
117 |
+
return results
|
graphgen/operators/judge.py
CHANGED
@@ -1,17 +1,20 @@
|
|
1 |
-
import math
|
2 |
import asyncio
|
|
|
|
|
3 |
from tqdm.asyncio import tqdm as tqdm_async
|
4 |
-
|
5 |
-
from graphgen.
|
6 |
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
|
|
|
7 |
|
8 |
|
9 |
-
async def judge_statement(
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
15 |
"""
|
16 |
Get all edges and nodes and judge them
|
17 |
|
@@ -34,7 +37,12 @@ async def judge_statement( # pylint: disable=too-many-statements
|
|
34 |
edge_data = edge[2]
|
35 |
|
36 |
if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
|
37 |
-
logger.info(
|
|
|
|
|
|
|
|
|
|
|
38 |
return source_id, target_id, edge_data
|
39 |
|
40 |
description = edge_data["description"]
|
@@ -47,17 +55,27 @@ async def judge_statement( # pylint: disable=too-many-statements
|
|
47 |
gts = [gt for _, gt in descriptions]
|
48 |
for description, gt in descriptions:
|
49 |
judgement = await trainee_llm_client.generate_topk_per_token(
|
50 |
-
STATEMENT_JUDGEMENT_PROMPT[
|
|
|
|
|
51 |
)
|
52 |
judgements.append(judgement[0].top_candidates)
|
53 |
|
54 |
loss = yes_no_loss_entropy(judgements, gts)
|
55 |
|
56 |
-
logger.info(
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
edge_data["loss"] = loss
|
59 |
-
except Exception as e:
|
60 |
-
logger.error(
|
|
|
|
|
61 |
logger.info("Use default loss 0.1")
|
62 |
edge_data["loss"] = -math.log(0.1)
|
63 |
|
@@ -68,9 +86,9 @@ async def judge_statement( # pylint: disable=too-many-statements
|
|
68 |
|
69 |
results = []
|
70 |
for result in tqdm_async(
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
):
|
75 |
results.append(await result)
|
76 |
|
@@ -82,7 +100,9 @@ async def judge_statement( # pylint: disable=too-many-statements
|
|
82 |
node_data = node[1]
|
83 |
|
84 |
if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
|
85 |
-
logger.info(
|
|
|
|
|
86 |
return node_id, node_data
|
87 |
|
88 |
description = node_data["description"]
|
@@ -95,16 +115,20 @@ async def judge_statement( # pylint: disable=too-many-statements
|
|
95 |
gts = [gt for _, gt in descriptions]
|
96 |
for description, gt in descriptions:
|
97 |
judgement = await trainee_llm_client.generate_topk_per_token(
|
98 |
-
STATEMENT_JUDGEMENT_PROMPT[
|
|
|
|
|
99 |
)
|
100 |
judgements.append(judgement[0].top_candidates)
|
101 |
|
102 |
loss = yes_no_loss_entropy(judgements, gts)
|
103 |
|
104 |
-
logger.info(
|
|
|
|
|
105 |
|
106 |
node_data["loss"] = loss
|
107 |
-
except Exception as e:
|
108 |
logger.error("Error in judging entity %s: %s", node_id, e)
|
109 |
logger.info("Use default loss 0.1")
|
110 |
node_data["loss"] = -math.log(0.1)
|
@@ -116,72 +140,9 @@ async def judge_statement( # pylint: disable=too-many-statements
|
|
116 |
|
117 |
results = []
|
118 |
for result in tqdm_async(
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
):
|
123 |
-
results.append(await result)
|
124 |
-
|
125 |
-
return graph_storage
|
126 |
-
|
127 |
-
async def skip_judge_statement(
|
128 |
-
graph_storage: NetworkXStorage,
|
129 |
-
max_concurrent: int = 1000
|
130 |
-
):
|
131 |
-
"""
|
132 |
-
Skip the judgement of the statement
|
133 |
-
:param graph_storage: graph storage instance
|
134 |
-
:param max_concurrent: max concurrent
|
135 |
-
:return:
|
136 |
-
"""
|
137 |
-
semaphore = asyncio.Semaphore(max_concurrent)
|
138 |
-
|
139 |
-
async def _skip_single_relation(
|
140 |
-
edge: tuple,
|
141 |
-
):
|
142 |
-
async with semaphore:
|
143 |
-
source_id = edge[0]
|
144 |
-
target_id = edge[1]
|
145 |
-
edge_data = edge[2]
|
146 |
-
|
147 |
-
if "loss" in edge_data and edge_data["loss"] is not None:
|
148 |
-
logger.info("Edge %s -> %s already judged, loss: %s, skip", source_id, target_id, edge_data["loss"])
|
149 |
-
return source_id, target_id, edge_data
|
150 |
-
|
151 |
-
edge_data["loss"] = -math.log(0.1)
|
152 |
-
await graph_storage.update_edge(source_id, target_id, edge_data)
|
153 |
-
return source_id, target_id, edge_data
|
154 |
-
|
155 |
-
edges = await graph_storage.get_all_edges()
|
156 |
-
results = []
|
157 |
-
for result in tqdm_async(
|
158 |
-
asyncio.as_completed([_skip_single_relation(edge) for edge in edges]),
|
159 |
-
total=len(edges),
|
160 |
-
desc="Skipping judgement of relations"
|
161 |
-
):
|
162 |
-
results.append(await result)
|
163 |
-
|
164 |
-
async def _skip_single_entity(
|
165 |
-
node: tuple,
|
166 |
-
):
|
167 |
-
async with semaphore:
|
168 |
-
node_id = node[0]
|
169 |
-
node_data = node[1]
|
170 |
-
|
171 |
-
if "loss" in node_data and node_data["loss"] is not None:
|
172 |
-
logger.info("Node %s already judged, loss: %s, skip", node_id, node_data["loss"])
|
173 |
-
return node_id, node_data
|
174 |
-
|
175 |
-
node_data["loss"] = -math.log(0.1)
|
176 |
-
await graph_storage.update_node(node_id, node_data)
|
177 |
-
return node_id, node_data
|
178 |
-
|
179 |
-
nodes = await graph_storage.get_all_nodes()
|
180 |
-
results = []
|
181 |
-
for result in tqdm_async(
|
182 |
-
asyncio.as_completed([_skip_single_entity(node) for node in nodes]),
|
183 |
-
total=len(nodes),
|
184 |
-
desc="Skipping judgement of entities"
|
185 |
):
|
186 |
results.append(await result)
|
187 |
|
|
|
|
|
1 |
import asyncio
|
2 |
+
import math
|
3 |
+
|
4 |
from tqdm.asyncio import tqdm as tqdm_async
|
5 |
+
|
6 |
+
from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIModel
|
7 |
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
|
8 |
+
from graphgen.utils import logger, yes_no_loss_entropy
|
9 |
|
10 |
|
11 |
+
async def judge_statement( # pylint: disable=too-many-statements
|
12 |
+
trainee_llm_client: OpenAIModel,
|
13 |
+
graph_storage: NetworkXStorage,
|
14 |
+
rephrase_storage: JsonKVStorage,
|
15 |
+
re_judge: bool = False,
|
16 |
+
max_concurrent: int = 1000,
|
17 |
+
) -> NetworkXStorage:
|
18 |
"""
|
19 |
Get all edges and nodes and judge them
|
20 |
|
|
|
37 |
edge_data = edge[2]
|
38 |
|
39 |
if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
|
40 |
+
logger.info(
|
41 |
+
"Edge %s -> %s already judged, loss: %s, skip",
|
42 |
+
source_id,
|
43 |
+
target_id,
|
44 |
+
edge_data["loss"],
|
45 |
+
)
|
46 |
return source_id, target_id, edge_data
|
47 |
|
48 |
description = edge_data["description"]
|
|
|
55 |
gts = [gt for _, gt in descriptions]
|
56 |
for description, gt in descriptions:
|
57 |
judgement = await trainee_llm_client.generate_topk_per_token(
|
58 |
+
STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
|
59 |
+
statement=description
|
60 |
+
)
|
61 |
)
|
62 |
judgements.append(judgement[0].top_candidates)
|
63 |
|
64 |
loss = yes_no_loss_entropy(judgements, gts)
|
65 |
|
66 |
+
logger.info(
|
67 |
+
"Edge %s -> %s description: %s loss: %s",
|
68 |
+
source_id,
|
69 |
+
target_id,
|
70 |
+
description,
|
71 |
+
loss,
|
72 |
+
)
|
73 |
|
74 |
edge_data["loss"] = loss
|
75 |
+
except Exception as e: # pylint: disable=broad-except
|
76 |
+
logger.error(
|
77 |
+
"Error in judging relation %s -> %s: %s", source_id, target_id, e
|
78 |
+
)
|
79 |
logger.info("Use default loss 0.1")
|
80 |
edge_data["loss"] = -math.log(0.1)
|
81 |
|
|
|
86 |
|
87 |
results = []
|
88 |
for result in tqdm_async(
|
89 |
+
asyncio.as_completed([_judge_single_relation(edge) for edge in edges]),
|
90 |
+
total=len(edges),
|
91 |
+
desc="Judging relations",
|
92 |
):
|
93 |
results.append(await result)
|
94 |
|
|
|
100 |
node_data = node[1]
|
101 |
|
102 |
if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
|
103 |
+
logger.info(
|
104 |
+
"Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
|
105 |
+
)
|
106 |
return node_id, node_data
|
107 |
|
108 |
description = node_data["description"]
|
|
|
115 |
gts = [gt for _, gt in descriptions]
|
116 |
for description, gt in descriptions:
|
117 |
judgement = await trainee_llm_client.generate_topk_per_token(
|
118 |
+
STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
|
119 |
+
statement=description
|
120 |
+
)
|
121 |
)
|
122 |
judgements.append(judgement[0].top_candidates)
|
123 |
|
124 |
loss = yes_no_loss_entropy(judgements, gts)
|
125 |
|
126 |
+
logger.info(
|
127 |
+
"Node %s description: %s loss: %s", node_id, description, loss
|
128 |
+
)
|
129 |
|
130 |
node_data["loss"] = loss
|
131 |
+
except Exception as e: # pylint: disable=broad-except
|
132 |
logger.error("Error in judging entity %s: %s", node_id, e)
|
133 |
logger.info("Use default loss 0.1")
|
134 |
node_data["loss"] = -math.log(0.1)
|
|
|
140 |
|
141 |
results = []
|
142 |
for result in tqdm_async(
|
143 |
+
asyncio.as_completed([_judge_single_entity(node) for node in nodes]),
|
144 |
+
total=len(nodes),
|
145 |
+
desc="Judging entities",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
):
|
147 |
results.append(await result)
|
148 |
|
graphgen/operators/kg/__init__.py
ADDED
File without changes
|
graphgen/operators/{extract_kg.py → kg/extract_kg.py}
RENAMED
@@ -1,27 +1,33 @@
|
|
1 |
-
import re
|
2 |
import asyncio
|
3 |
-
|
4 |
from collections import defaultdict
|
|
|
5 |
|
6 |
import gradio as gr
|
7 |
from tqdm.asyncio import tqdm as tqdm_async
|
|
|
8 |
from graphgen.models import Chunk, OpenAIModel, Tokenizer
|
9 |
from graphgen.models.storage.base_storage import BaseGraphStorage
|
|
|
10 |
from graphgen.templates import KG_EXTRACTION_PROMPT
|
11 |
-
from graphgen.utils import (
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
# pylint: disable=too-many-statements
|
18 |
async def extract_kg(
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
):
|
26 |
"""
|
27 |
:param llm_client: Synthesizer LLM model to extract entities and relationships
|
@@ -50,25 +56,25 @@ async def extract_kg(
|
|
50 |
)
|
51 |
|
52 |
final_result = await llm_client.generate_answer(hint_prompt)
|
53 |
-
logger.info(
|
54 |
|
55 |
history = pack_history_conversations(hint_prompt, final_result)
|
56 |
for loop_index in range(max_loop):
|
57 |
if_loop_result = await llm_client.generate_answer(
|
58 |
-
text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"],
|
59 |
-
history=history
|
60 |
)
|
61 |
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
62 |
if if_loop_result != "yes":
|
63 |
break
|
64 |
|
65 |
glean_result = await llm_client.generate_answer(
|
66 |
-
text=KG_EXTRACTION_PROMPT[language]["CONTINUE"],
|
67 |
-
history=history
|
68 |
)
|
69 |
-
logger.info(
|
70 |
|
71 |
-
history += pack_history_conversations(
|
|
|
|
|
72 |
final_result += glean_result
|
73 |
if loop_index == max_loop - 1:
|
74 |
break
|
@@ -76,8 +82,9 @@ async def extract_kg(
|
|
76 |
records = split_string_by_multi_markers(
|
77 |
final_result,
|
78 |
[
|
79 |
-
|
80 |
-
|
|
|
81 |
)
|
82 |
|
83 |
nodes = defaultdict(list)
|
@@ -87,16 +94,20 @@ async def extract_kg(
|
|
87 |
record = re.search(r"\((.*)\)", record)
|
88 |
if record is None:
|
89 |
continue
|
90 |
-
record = record.group(1)
|
91 |
record_attributes = split_string_by_multi_markers(
|
92 |
record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
|
93 |
)
|
94 |
|
95 |
-
entity = await handle_single_entity_extraction(
|
|
|
|
|
96 |
if entity is not None:
|
97 |
nodes[entity["entity_name"]].append(entity)
|
98 |
continue
|
99 |
-
relation = await handle_single_relationship_extraction(
|
|
|
|
|
100 |
if relation is not None:
|
101 |
edges[(relation["src_id"], relation["tgt_id"])].append(relation)
|
102 |
return dict(nodes), dict(edges)
|
@@ -106,17 +117,25 @@ async def extract_kg(
|
|
106 |
async for result in tqdm_async(
|
107 |
asyncio.as_completed([_process_single_content(c) for c in chunks]),
|
108 |
total=len(chunks),
|
109 |
-
desc="[
|
110 |
unit="chunk",
|
111 |
):
|
112 |
try:
|
113 |
if progress_bar is not None:
|
114 |
-
progress_bar(
|
|
|
|
|
|
|
115 |
results.append(await result)
|
116 |
if progress_bar is not None and len(results) == chunk_number:
|
117 |
-
progress_bar(
|
118 |
-
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
nodes = defaultdict(list)
|
122 |
edges = defaultdict(list)
|
|
|
|
|
1 |
import asyncio
|
2 |
+
import re
|
3 |
from collections import defaultdict
|
4 |
+
from typing import List
|
5 |
|
6 |
import gradio as gr
|
7 |
from tqdm.asyncio import tqdm as tqdm_async
|
8 |
+
|
9 |
from graphgen.models import Chunk, OpenAIModel, Tokenizer
|
10 |
from graphgen.models.storage.base_storage import BaseGraphStorage
|
11 |
+
from graphgen.operators.kg.merge_kg import merge_edges, merge_nodes
|
12 |
from graphgen.templates import KG_EXTRACTION_PROMPT
|
13 |
+
from graphgen.utils import (
|
14 |
+
detect_if_chinese,
|
15 |
+
handle_single_entity_extraction,
|
16 |
+
handle_single_relationship_extraction,
|
17 |
+
logger,
|
18 |
+
pack_history_conversations,
|
19 |
+
split_string_by_multi_markers,
|
20 |
+
)
|
21 |
|
22 |
|
23 |
# pylint: disable=too-many-statements
|
24 |
async def extract_kg(
|
25 |
+
llm_client: OpenAIModel,
|
26 |
+
kg_instance: BaseGraphStorage,
|
27 |
+
tokenizer_instance: Tokenizer,
|
28 |
+
chunks: List[Chunk],
|
29 |
+
progress_bar: gr.Progress = None,
|
30 |
+
max_concurrent: int = 1000,
|
31 |
):
|
32 |
"""
|
33 |
:param llm_client: Synthesizer LLM model to extract entities and relationships
|
|
|
56 |
)
|
57 |
|
58 |
final_result = await llm_client.generate_answer(hint_prompt)
|
59 |
+
logger.info("First result: %s", final_result)
|
60 |
|
61 |
history = pack_history_conversations(hint_prompt, final_result)
|
62 |
for loop_index in range(max_loop):
|
63 |
if_loop_result = await llm_client.generate_answer(
|
64 |
+
text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
|
|
|
65 |
)
|
66 |
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
67 |
if if_loop_result != "yes":
|
68 |
break
|
69 |
|
70 |
glean_result = await llm_client.generate_answer(
|
71 |
+
text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
|
|
|
72 |
)
|
73 |
+
logger.info("Loop %s glean: %s", loop_index, glean_result)
|
74 |
|
75 |
+
history += pack_history_conversations(
|
76 |
+
KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
|
77 |
+
)
|
78 |
final_result += glean_result
|
79 |
if loop_index == max_loop - 1:
|
80 |
break
|
|
|
82 |
records = split_string_by_multi_markers(
|
83 |
final_result,
|
84 |
[
|
85 |
+
KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
|
86 |
+
KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
|
87 |
+
],
|
88 |
)
|
89 |
|
90 |
nodes = defaultdict(list)
|
|
|
94 |
record = re.search(r"\((.*)\)", record)
|
95 |
if record is None:
|
96 |
continue
|
97 |
+
record = record.group(1) # 提取括号内的内容
|
98 |
record_attributes = split_string_by_multi_markers(
|
99 |
record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
|
100 |
)
|
101 |
|
102 |
+
entity = await handle_single_entity_extraction(
|
103 |
+
record_attributes, chunk_id
|
104 |
+
)
|
105 |
if entity is not None:
|
106 |
nodes[entity["entity_name"]].append(entity)
|
107 |
continue
|
108 |
+
relation = await handle_single_relationship_extraction(
|
109 |
+
record_attributes, chunk_id
|
110 |
+
)
|
111 |
if relation is not None:
|
112 |
edges[(relation["src_id"], relation["tgt_id"])].append(relation)
|
113 |
return dict(nodes), dict(edges)
|
|
|
117 |
async for result in tqdm_async(
|
118 |
asyncio.as_completed([_process_single_content(c) for c in chunks]),
|
119 |
total=len(chunks),
|
120 |
+
desc="[2/4]Extracting entities and relationships from chunks",
|
121 |
unit="chunk",
|
122 |
):
|
123 |
try:
|
124 |
if progress_bar is not None:
|
125 |
+
progress_bar(
|
126 |
+
len(results) / chunk_number,
|
127 |
+
desc="[3/4]Extracting entities and relationships from chunks",
|
128 |
+
)
|
129 |
results.append(await result)
|
130 |
if progress_bar is not None and len(results) == chunk_number:
|
131 |
+
progress_bar(
|
132 |
+
1, desc="[3/4]Extracting entities and relationships from chunks"
|
133 |
+
)
|
134 |
+
except Exception as e: # pylint: disable=broad-except
|
135 |
+
logger.error(
|
136 |
+
"Error occurred while extracting entities and relationships from chunks: %s",
|
137 |
+
e,
|
138 |
+
)
|
139 |
|
140 |
nodes = defaultdict(list)
|
141 |
edges = defaultdict(list)
|
graphgen/operators/{merge_kg.py → kg/merge_kg.py}
RENAMED
@@ -1,19 +1,21 @@
|
|
1 |
-
from collections import Counter
|
2 |
import asyncio
|
|
|
|
|
3 |
from tqdm.asyncio import tqdm as tqdm_async
|
4 |
|
5 |
-
from graphgen.
|
6 |
-
from graphgen.utils import logger, detect_main_language
|
7 |
-
from graphgen.models import TopkTokenModel, Tokenizer
|
8 |
from graphgen.models.storage.base_storage import BaseGraphStorage
|
9 |
-
from graphgen.templates import
|
|
|
|
|
|
|
10 |
|
11 |
async def _handle_kg_summary(
|
12 |
entity_or_relation_name: str,
|
13 |
description: str,
|
14 |
llm_client: TopkTokenModel,
|
15 |
tokenizer_instance: Tokenizer,
|
16 |
-
max_summary_tokens: int = 200
|
17 |
) -> str:
|
18 |
"""
|
19 |
处理实体或关系的描述信息
|
@@ -33,17 +35,19 @@ async def _handle_kg_summary(
|
|
33 |
KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
|
34 |
|
35 |
tokens = tokenizer_instance.encode_string(description)
|
36 |
-
if len(tokens) <
|
37 |
return description
|
38 |
|
39 |
use_description = tokenizer_instance.decode_tokens(tokens[:max_summary_tokens])
|
40 |
prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
|
41 |
entity_name=entity_or_relation_name,
|
42 |
-
description_list=use_description.split(
|
43 |
-
**KG_SUMMARIZATION_PROMPT["FORMAT"]
|
44 |
)
|
45 |
new_description = await llm_client.generate_answer(prompt)
|
46 |
-
logger.info(
|
|
|
|
|
47 |
return new_description
|
48 |
|
49 |
|
@@ -52,7 +56,7 @@ async def merge_nodes(
|
|
52 |
kg_instance: BaseGraphStorage,
|
53 |
llm_client: TopkTokenModel,
|
54 |
tokenizer_instance: Tokenizer,
|
55 |
-
max_concurrent: int = 1000
|
56 |
):
|
57 |
"""
|
58 |
Merge nodes
|
@@ -77,39 +81,34 @@ async def merge_nodes(
|
|
77 |
if node is not None:
|
78 |
entity_types.append(node["entity_type"])
|
79 |
source_ids.extend(
|
80 |
-
split_string_by_multi_markers(node["source_id"], [
|
81 |
)
|
82 |
descriptions.append(node["description"])
|
83 |
|
84 |
# 统计当前节点数据和已有节点数据的entity_type出现次数,取出现次数最多的entity_type
|
85 |
entity_type = sorted(
|
86 |
-
Counter(
|
87 |
-
[dp["entity_type"] for dp in node_data] + entity_types
|
88 |
-
).items(),
|
89 |
key=lambda x: x[1],
|
90 |
reverse=True,
|
91 |
)[0][0]
|
92 |
|
93 |
-
description =
|
94 |
sorted(set([dp["description"] for dp in node_data] + descriptions))
|
95 |
)
|
96 |
description = await _handle_kg_summary(
|
97 |
entity_name, description, llm_client, tokenizer_instance
|
98 |
)
|
99 |
|
100 |
-
source_id =
|
101 |
set([dp["source_id"] for dp in node_data] + source_ids)
|
102 |
)
|
103 |
|
104 |
node_data = {
|
105 |
"entity_type": entity_type,
|
106 |
"description": description,
|
107 |
-
"source_id": source_id
|
108 |
}
|
109 |
-
await kg_instance.upsert_node(
|
110 |
-
entity_name,
|
111 |
-
node_data=node_data
|
112 |
-
)
|
113 |
node_data["entity_name"] = entity_name
|
114 |
return node_data
|
115 |
|
@@ -125,7 +124,7 @@ async def merge_nodes(
|
|
125 |
):
|
126 |
try:
|
127 |
entities_data.append(await result)
|
128 |
-
except Exception as e:
|
129 |
logger.error("Error occurred while inserting entities into storage: %s", e)
|
130 |
|
131 |
|
@@ -134,7 +133,7 @@ async def merge_edges(
|
|
134 |
kg_instance: BaseGraphStorage,
|
135 |
llm_client: TopkTokenModel,
|
136 |
tokenizer_instance: Tokenizer,
|
137 |
-
max_concurrent: int = 1000
|
138 |
):
|
139 |
"""
|
140 |
Merge edges
|
@@ -157,14 +156,14 @@ async def merge_edges(
|
|
157 |
edge = await kg_instance.get_edge(src_id, tgt_id)
|
158 |
if edge is not None:
|
159 |
source_ids.extend(
|
160 |
-
split_string_by_multi_markers(edge["source_id"], [
|
161 |
)
|
162 |
descriptions.append(edge["description"])
|
163 |
|
164 |
-
description =
|
165 |
sorted(set([dp["description"] for dp in edge_data] + descriptions))
|
166 |
)
|
167 |
-
source_id =
|
168 |
set([dp["source_id"] for dp in edge_data] + source_ids)
|
169 |
)
|
170 |
|
@@ -175,8 +174,8 @@ async def merge_edges(
|
|
175 |
node_data={
|
176 |
"source_id": source_id,
|
177 |
"description": description,
|
178 |
-
"entity_type": "UNKNOWN"
|
179 |
-
}
|
180 |
)
|
181 |
|
182 |
description = await _handle_kg_summary(
|
@@ -186,24 +185,20 @@ async def merge_edges(
|
|
186 |
await kg_instance.upsert_edge(
|
187 |
src_id,
|
188 |
tgt_id,
|
189 |
-
edge_data={
|
190 |
-
"source_id": source_id,
|
191 |
-
"description": description
|
192 |
-
}
|
193 |
)
|
194 |
|
195 |
-
edge_data = {
|
196 |
-
"src_id": src_id,
|
197 |
-
"tgt_id": tgt_id,
|
198 |
-
"description": description
|
199 |
-
}
|
200 |
return edge_data
|
201 |
|
202 |
logger.info("Inserting relationships into storage...")
|
203 |
relationships_data = []
|
204 |
for result in tqdm_async(
|
205 |
asyncio.as_completed(
|
206 |
-
[
|
|
|
|
|
|
|
207 |
),
|
208 |
total=len(edges_data),
|
209 |
desc="Inserting relationships into storage",
|
@@ -211,5 +206,7 @@ async def merge_edges(
|
|
211 |
):
|
212 |
try:
|
213 |
relationships_data.append(await result)
|
214 |
-
except Exception as e:
|
215 |
-
logger.error(
|
|
|
|
|
|
|
|
1 |
import asyncio
|
2 |
+
from collections import Counter
|
3 |
+
|
4 |
from tqdm.asyncio import tqdm as tqdm_async
|
5 |
|
6 |
+
from graphgen.models import Tokenizer, TopkTokenModel
|
|
|
|
|
7 |
from graphgen.models.storage.base_storage import BaseGraphStorage
|
8 |
+
from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
|
9 |
+
from graphgen.utils import detect_main_language, logger
|
10 |
+
from graphgen.utils.format import split_string_by_multi_markers
|
11 |
+
|
12 |
|
13 |
async def _handle_kg_summary(
|
14 |
entity_or_relation_name: str,
|
15 |
description: str,
|
16 |
llm_client: TopkTokenModel,
|
17 |
tokenizer_instance: Tokenizer,
|
18 |
+
max_summary_tokens: int = 200,
|
19 |
) -> str:
|
20 |
"""
|
21 |
处理实体或关系的描述信息
|
|
|
35 |
KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
|
36 |
|
37 |
tokens = tokenizer_instance.encode_string(description)
|
38 |
+
if len(tokens) < max_summary_tokens:
|
39 |
return description
|
40 |
|
41 |
use_description = tokenizer_instance.decode_tokens(tokens[:max_summary_tokens])
|
42 |
prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
|
43 |
entity_name=entity_or_relation_name,
|
44 |
+
description_list=use_description.split("<SEP>"),
|
45 |
+
**KG_SUMMARIZATION_PROMPT["FORMAT"],
|
46 |
)
|
47 |
new_description = await llm_client.generate_answer(prompt)
|
48 |
+
logger.info(
|
49 |
+
"Entity or relation %s summary: %s", entity_or_relation_name, new_description
|
50 |
+
)
|
51 |
return new_description
|
52 |
|
53 |
|
|
|
56 |
kg_instance: BaseGraphStorage,
|
57 |
llm_client: TopkTokenModel,
|
58 |
tokenizer_instance: Tokenizer,
|
59 |
+
max_concurrent: int = 1000,
|
60 |
):
|
61 |
"""
|
62 |
Merge nodes
|
|
|
81 |
if node is not None:
|
82 |
entity_types.append(node["entity_type"])
|
83 |
source_ids.extend(
|
84 |
+
split_string_by_multi_markers(node["source_id"], ["<SEP>"])
|
85 |
)
|
86 |
descriptions.append(node["description"])
|
87 |
|
88 |
# 统计当前节点数据和已有节点数据的entity_type出现次数,取出现次数最多的entity_type
|
89 |
entity_type = sorted(
|
90 |
+
Counter([dp["entity_type"] for dp in node_data] + entity_types).items(),
|
|
|
|
|
91 |
key=lambda x: x[1],
|
92 |
reverse=True,
|
93 |
)[0][0]
|
94 |
|
95 |
+
description = "<SEP>".join(
|
96 |
sorted(set([dp["description"] for dp in node_data] + descriptions))
|
97 |
)
|
98 |
description = await _handle_kg_summary(
|
99 |
entity_name, description, llm_client, tokenizer_instance
|
100 |
)
|
101 |
|
102 |
+
source_id = "<SEP>".join(
|
103 |
set([dp["source_id"] for dp in node_data] + source_ids)
|
104 |
)
|
105 |
|
106 |
node_data = {
|
107 |
"entity_type": entity_type,
|
108 |
"description": description,
|
109 |
+
"source_id": source_id,
|
110 |
}
|
111 |
+
await kg_instance.upsert_node(entity_name, node_data=node_data)
|
|
|
|
|
|
|
112 |
node_data["entity_name"] = entity_name
|
113 |
return node_data
|
114 |
|
|
|
124 |
):
|
125 |
try:
|
126 |
entities_data.append(await result)
|
127 |
+
except Exception as e: # pylint: disable=broad-except
|
128 |
logger.error("Error occurred while inserting entities into storage: %s", e)
|
129 |
|
130 |
|
|
|
133 |
kg_instance: BaseGraphStorage,
|
134 |
llm_client: TopkTokenModel,
|
135 |
tokenizer_instance: Tokenizer,
|
136 |
+
max_concurrent: int = 1000,
|
137 |
):
|
138 |
"""
|
139 |
Merge edges
|
|
|
156 |
edge = await kg_instance.get_edge(src_id, tgt_id)
|
157 |
if edge is not None:
|
158 |
source_ids.extend(
|
159 |
+
split_string_by_multi_markers(edge["source_id"], ["<SEP>"])
|
160 |
)
|
161 |
descriptions.append(edge["description"])
|
162 |
|
163 |
+
description = "<SEP>".join(
|
164 |
sorted(set([dp["description"] for dp in edge_data] + descriptions))
|
165 |
)
|
166 |
+
source_id = "<SEP>".join(
|
167 |
set([dp["source_id"] for dp in edge_data] + source_ids)
|
168 |
)
|
169 |
|
|
|
174 |
node_data={
|
175 |
"source_id": source_id,
|
176 |
"description": description,
|
177 |
+
"entity_type": "UNKNOWN",
|
178 |
+
},
|
179 |
)
|
180 |
|
181 |
description = await _handle_kg_summary(
|
|
|
185 |
await kg_instance.upsert_edge(
|
186 |
src_id,
|
187 |
tgt_id,
|
188 |
+
edge_data={"source_id": source_id, "description": description},
|
|
|
|
|
|
|
189 |
)
|
190 |
|
191 |
+
edge_data = {"src_id": src_id, "tgt_id": tgt_id, "description": description}
|
|
|
|
|
|
|
|
|
192 |
return edge_data
|
193 |
|
194 |
logger.info("Inserting relationships into storage...")
|
195 |
relationships_data = []
|
196 |
for result in tqdm_async(
|
197 |
asyncio.as_completed(
|
198 |
+
[
|
199 |
+
process_single_edge(src_id, tgt_id, v)
|
200 |
+
for (src_id, tgt_id), v in edges_data.items()
|
201 |
+
]
|
202 |
),
|
203 |
total=len(edges_data),
|
204 |
desc="Inserting relationships into storage",
|
|
|
206 |
):
|
207 |
try:
|
208 |
relationships_data.append(await result)
|
209 |
+
except Exception as e: # pylint: disable=broad-except
|
210 |
+
logger.error(
|
211 |
+
"Error occurred while inserting relationships into storage: %s", e
|
212 |
+
)
|
graphgen/operators/{split_graph.py → kg/split_kg.py}
RENAMED
@@ -1,14 +1,16 @@
|
|
1 |
import random
|
2 |
from collections import defaultdict
|
|
|
3 |
from tqdm.asyncio import tqdm as tqdm_async
|
4 |
-
from graphgen.utils import logger
|
5 |
|
6 |
from graphgen.models import NetworkXStorage, TraverseStrategy
|
|
|
|
|
7 |
|
8 |
async def _get_node_info(
|
9 |
node_id: str,
|
10 |
graph_storage: NetworkXStorage,
|
11 |
-
)-> dict:
|
12 |
"""
|
13 |
Get node info
|
14 |
|
@@ -17,10 +19,7 @@ async def _get_node_info(
|
|
17 |
:return: node info
|
18 |
"""
|
19 |
node_data = await graph_storage.get_node(node_id)
|
20 |
-
return {
|
21 |
-
"node_id": node_id,
|
22 |
-
**node_data
|
23 |
-
}
|
24 |
|
25 |
|
26 |
def _get_level_n_edges_by_max_width(
|
@@ -33,7 +32,7 @@ def _get_level_n_edges_by_max_width(
|
|
33 |
bidirectional: bool,
|
34 |
max_extra_edges: int,
|
35 |
edge_sampling: str,
|
36 |
-
loss_strategy: str = "only_edge"
|
37 |
) -> list:
|
38 |
"""
|
39 |
Get level n edges for an edge.
|
@@ -71,10 +70,17 @@ def _get_level_n_edges_by_max_width(
|
|
71 |
|
72 |
if len(candidate_edges) >= max_extra_edges:
|
73 |
if loss_strategy == "both":
|
74 |
-
er_tuples = [
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
76 |
elif loss_strategy == "only_edge":
|
77 |
-
candidate_edges = _sort_edges(candidate_edges, edge_sampling)[
|
|
|
|
|
78 |
else:
|
79 |
raise ValueError(f"Invalid loss strategy: {loss_strategy}")
|
80 |
|
@@ -101,16 +107,16 @@ def _get_level_n_edges_by_max_width(
|
|
101 |
|
102 |
|
103 |
def _get_level_n_edges_by_max_tokens(
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
) -> list:
|
115 |
"""
|
116 |
Get level n edges for an edge.
|
@@ -129,8 +135,11 @@ def _get_level_n_edges_by_max_tokens(
|
|
129 |
"""
|
130 |
src_id, tgt_id, src_edge_data = src_edge
|
131 |
|
132 |
-
max_tokens -= (
|
133 |
-
|
|
|
|
|
|
|
134 |
|
135 |
level_n_edges = []
|
136 |
|
@@ -151,7 +160,10 @@ def _get_level_n_edges_by_max_tokens(
|
|
151 |
break
|
152 |
|
153 |
if loss_strategy == "both":
|
154 |
-
er_tuples = [
|
|
|
|
|
|
|
155 |
candidate_edges = _sort_tuples(er_tuples, edge_sampling)
|
156 |
elif loss_strategy == "only_edge":
|
157 |
candidate_edges = _sort_edges(candidate_edges, edge_sampling)
|
@@ -196,15 +208,22 @@ def _sort_tuples(er_tuples: list, edge_sampling: str) -> list:
|
|
196 |
if edge_sampling == "random":
|
197 |
er_tuples = random.sample(er_tuples, len(er_tuples))
|
198 |
elif edge_sampling == "min_loss":
|
199 |
-
er_tuples = sorted(
|
|
|
|
|
|
|
200 |
elif edge_sampling == "max_loss":
|
201 |
-
er_tuples = sorted(
|
202 |
-
|
|
|
|
|
|
|
203 |
else:
|
204 |
raise ValueError(f"Invalid edge sampling: {edge_sampling}")
|
205 |
edges = [edge for _, edge in er_tuples]
|
206 |
return edges
|
207 |
|
|
|
208 |
def _sort_edges(edges: list, edge_sampling: str) -> list:
|
209 |
"""
|
210 |
Sort edges with edge sampling strategy
|
@@ -223,11 +242,12 @@ def _sort_edges(edges: list, edge_sampling: str) -> list:
|
|
223 |
raise ValueError(f"Invalid edge sampling: {edge_sampling}")
|
224 |
return edges
|
225 |
|
226 |
-
|
|
|
227 |
nodes: list,
|
228 |
edges: list,
|
229 |
graph_storage: NetworkXStorage,
|
230 |
-
traverse_strategy: TraverseStrategy
|
231 |
):
|
232 |
expand_method = traverse_strategy.expand_method
|
233 |
if expand_method == "max_width":
|
@@ -256,7 +276,10 @@ async def get_batches_with_strategy( # pylint: disable=too-many-branches
|
|
256 |
node_dict[node_name] = i
|
257 |
|
258 |
if traverse_strategy.loss_strategy == "both":
|
259 |
-
er_tuples = [
|
|
|
|
|
|
|
260 |
edges = _sort_tuples(er_tuples, edge_sampling)
|
261 |
elif traverse_strategy.loss_strategy == "only_edge":
|
262 |
edges = _sort_edges(edges, edge_sampling)
|
@@ -279,21 +302,36 @@ async def get_batches_with_strategy( # pylint: disable=too-many-branches
|
|
279 |
src_id = edge[0]
|
280 |
tgt_id = edge[1]
|
281 |
|
282 |
-
_process_nodes.extend(
|
283 |
-
|
|
|
284 |
_process_edges.append(edge)
|
285 |
|
286 |
if expand_method == "max_width":
|
287 |
level_n_edges = _get_level_n_edges_by_max_width(
|
288 |
-
edge_adj_list,
|
289 |
-
|
290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
)
|
292 |
else:
|
293 |
level_n_edges = _get_level_n_edges_by_max_tokens(
|
294 |
-
edge_adj_list,
|
295 |
-
|
296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
)
|
298 |
|
299 |
for _edge in level_n_edges:
|
@@ -302,8 +340,12 @@ async def get_batches_with_strategy( # pylint: disable=too-many-branches
|
|
302 |
_process_edges.append(_edge)
|
303 |
|
304 |
# 去重
|
305 |
-
_process_nodes = list(
|
306 |
-
|
|
|
|
|
|
|
|
|
307 |
|
308 |
processing_batches.append((_process_nodes, _process_edges))
|
309 |
|
@@ -312,15 +354,21 @@ async def get_batches_with_strategy( # pylint: disable=too-many-branches
|
|
312 |
# isolate nodes
|
313 |
isolated_node_strategy = traverse_strategy.isolated_node_strategy
|
314 |
if isolated_node_strategy == "add":
|
315 |
-
processing_batches = await _add_isolated_nodes(
|
316 |
-
|
|
|
|
|
|
|
|
|
|
|
317 |
|
318 |
return processing_batches
|
319 |
|
|
|
320 |
async def _add_isolated_nodes(
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
) -> list:
|
325 |
visited_nodes = set()
|
326 |
for _process_nodes, _process_edges in processing_batches:
|
|
|
1 |
import random
|
2 |
from collections import defaultdict
|
3 |
+
|
4 |
from tqdm.asyncio import tqdm as tqdm_async
|
|
|
5 |
|
6 |
from graphgen.models import NetworkXStorage, TraverseStrategy
|
7 |
+
from graphgen.utils import logger
|
8 |
+
|
9 |
|
10 |
async def _get_node_info(
|
11 |
node_id: str,
|
12 |
graph_storage: NetworkXStorage,
|
13 |
+
) -> dict:
|
14 |
"""
|
15 |
Get node info
|
16 |
|
|
|
19 |
:return: node info
|
20 |
"""
|
21 |
node_data = await graph_storage.get_node(node_id)
|
22 |
+
return {"node_id": node_id, **node_data}
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
def _get_level_n_edges_by_max_width(
|
|
|
32 |
bidirectional: bool,
|
33 |
max_extra_edges: int,
|
34 |
edge_sampling: str,
|
35 |
+
loss_strategy: str = "only_edge",
|
36 |
) -> list:
|
37 |
"""
|
38 |
Get level n edges for an edge.
|
|
|
70 |
|
71 |
if len(candidate_edges) >= max_extra_edges:
|
72 |
if loss_strategy == "both":
|
73 |
+
er_tuples = [
|
74 |
+
([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
|
75 |
+
for edge in candidate_edges
|
76 |
+
]
|
77 |
+
candidate_edges = _sort_tuples(er_tuples, edge_sampling)[
|
78 |
+
:max_extra_edges
|
79 |
+
]
|
80 |
elif loss_strategy == "only_edge":
|
81 |
+
candidate_edges = _sort_edges(candidate_edges, edge_sampling)[
|
82 |
+
:max_extra_edges
|
83 |
+
]
|
84 |
else:
|
85 |
raise ValueError(f"Invalid loss strategy: {loss_strategy}")
|
86 |
|
|
|
107 |
|
108 |
|
109 |
def _get_level_n_edges_by_max_tokens(
|
110 |
+
edge_adj_list: dict,
|
111 |
+
node_dict: dict,
|
112 |
+
edges: list,
|
113 |
+
nodes: list,
|
114 |
+
src_edge: tuple,
|
115 |
+
max_depth: int,
|
116 |
+
bidirectional: bool,
|
117 |
+
max_tokens: int,
|
118 |
+
edge_sampling: str,
|
119 |
+
loss_strategy: str = "only_edge",
|
120 |
) -> list:
|
121 |
"""
|
122 |
Get level n edges for an edge.
|
|
|
135 |
"""
|
136 |
src_id, tgt_id, src_edge_data = src_edge
|
137 |
|
138 |
+
max_tokens -= (
|
139 |
+
src_edge_data["length"]
|
140 |
+
+ nodes[node_dict[src_id]][1]["length"]
|
141 |
+
+ nodes[node_dict[tgt_id]][1]["length"]
|
142 |
+
)
|
143 |
|
144 |
level_n_edges = []
|
145 |
|
|
|
160 |
break
|
161 |
|
162 |
if loss_strategy == "both":
|
163 |
+
er_tuples = [
|
164 |
+
([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
|
165 |
+
for edge in candidate_edges
|
166 |
+
]
|
167 |
candidate_edges = _sort_tuples(er_tuples, edge_sampling)
|
168 |
elif loss_strategy == "only_edge":
|
169 |
candidate_edges = _sort_edges(candidate_edges, edge_sampling)
|
|
|
208 |
if edge_sampling == "random":
|
209 |
er_tuples = random.sample(er_tuples, len(er_tuples))
|
210 |
elif edge_sampling == "min_loss":
|
211 |
+
er_tuples = sorted(
|
212 |
+
er_tuples,
|
213 |
+
key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
|
214 |
+
)
|
215 |
elif edge_sampling == "max_loss":
|
216 |
+
er_tuples = sorted(
|
217 |
+
er_tuples,
|
218 |
+
key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
|
219 |
+
reverse=True,
|
220 |
+
)
|
221 |
else:
|
222 |
raise ValueError(f"Invalid edge sampling: {edge_sampling}")
|
223 |
edges = [edge for _, edge in er_tuples]
|
224 |
return edges
|
225 |
|
226 |
+
|
227 |
def _sort_edges(edges: list, edge_sampling: str) -> list:
|
228 |
"""
|
229 |
Sort edges with edge sampling strategy
|
|
|
242 |
raise ValueError(f"Invalid edge sampling: {edge_sampling}")
|
243 |
return edges
|
244 |
|
245 |
+
|
246 |
+
async def get_batches_with_strategy( # pylint: disable=too-many-branches
|
247 |
nodes: list,
|
248 |
edges: list,
|
249 |
graph_storage: NetworkXStorage,
|
250 |
+
traverse_strategy: TraverseStrategy,
|
251 |
):
|
252 |
expand_method = traverse_strategy.expand_method
|
253 |
if expand_method == "max_width":
|
|
|
276 |
node_dict[node_name] = i
|
277 |
|
278 |
if traverse_strategy.loss_strategy == "both":
|
279 |
+
er_tuples = [
|
280 |
+
([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
|
281 |
+
for edge in edges
|
282 |
+
]
|
283 |
edges = _sort_tuples(er_tuples, edge_sampling)
|
284 |
elif traverse_strategy.loss_strategy == "only_edge":
|
285 |
edges = _sort_edges(edges, edge_sampling)
|
|
|
302 |
src_id = edge[0]
|
303 |
tgt_id = edge[1]
|
304 |
|
305 |
+
_process_nodes.extend(
|
306 |
+
[await get_cached_node_info(src_id), await get_cached_node_info(tgt_id)]
|
307 |
+
)
|
308 |
_process_edges.append(edge)
|
309 |
|
310 |
if expand_method == "max_width":
|
311 |
level_n_edges = _get_level_n_edges_by_max_width(
|
312 |
+
edge_adj_list,
|
313 |
+
node_dict,
|
314 |
+
edges,
|
315 |
+
nodes,
|
316 |
+
edge,
|
317 |
+
max_depth,
|
318 |
+
traverse_strategy.bidirectional,
|
319 |
+
traverse_strategy.max_extra_edges,
|
320 |
+
edge_sampling,
|
321 |
+
traverse_strategy.loss_strategy,
|
322 |
)
|
323 |
else:
|
324 |
level_n_edges = _get_level_n_edges_by_max_tokens(
|
325 |
+
edge_adj_list,
|
326 |
+
node_dict,
|
327 |
+
edges,
|
328 |
+
nodes,
|
329 |
+
edge,
|
330 |
+
max_depth,
|
331 |
+
traverse_strategy.bidirectional,
|
332 |
+
traverse_strategy.max_tokens,
|
333 |
+
edge_sampling,
|
334 |
+
traverse_strategy.loss_strategy,
|
335 |
)
|
336 |
|
337 |
for _edge in level_n_edges:
|
|
|
340 |
_process_edges.append(_edge)
|
341 |
|
342 |
# 去重
|
343 |
+
_process_nodes = list(
|
344 |
+
{node["node_id"]: node for node in _process_nodes}.values()
|
345 |
+
)
|
346 |
+
_process_edges = list(
|
347 |
+
{(edge[0], edge[1]): edge for edge in _process_edges}.values()
|
348 |
+
)
|
349 |
|
350 |
processing_batches.append((_process_nodes, _process_edges))
|
351 |
|
|
|
354 |
# isolate nodes
|
355 |
isolated_node_strategy = traverse_strategy.isolated_node_strategy
|
356 |
if isolated_node_strategy == "add":
|
357 |
+
processing_batches = await _add_isolated_nodes(
|
358 |
+
nodes, processing_batches, graph_storage
|
359 |
+
)
|
360 |
+
logger.info(
|
361 |
+
"Processing batches after adding isolated nodes: %d",
|
362 |
+
len(processing_batches),
|
363 |
+
)
|
364 |
|
365 |
return processing_batches
|
366 |
|
367 |
+
|
368 |
async def _add_isolated_nodes(
|
369 |
+
nodes: list,
|
370 |
+
processing_batches: list,
|
371 |
+
graph_storage: NetworkXStorage,
|
372 |
) -> list:
|
373 |
visited_nodes = set()
|
374 |
for _process_nodes, _process_edges in processing_batches:
|
graphgen/operators/preprocess/__init__.py
ADDED
File without changes
|
graphgen/operators/{resolute_coreference.py → preprocess/resolute_coreference.py}
RENAMED
@@ -1,12 +1,13 @@
|
|
1 |
from typing import List
|
2 |
-
|
3 |
-
from graphgen.models import OpenAIModel
|
4 |
-
from graphgen.templates import
|
5 |
from graphgen.utils import detect_main_language
|
6 |
|
|
|
7 |
async def resolute_coreference(
|
8 |
-
|
9 |
-
|
10 |
"""
|
11 |
Resolute conference
|
12 |
|
@@ -23,9 +24,8 @@ async def resolute_coreference(
|
|
23 |
for _, chunk in enumerate(chunks[1:]):
|
24 |
language = detect_main_language(chunk.content)
|
25 |
result = await llm_client.generate_answer(
|
26 |
-
|
27 |
-
reference
|
28 |
-
input_sentence = chunk.content
|
29 |
)
|
30 |
)
|
31 |
results.append(Chunk(id=chunk.id, content=result))
|
|
|
1 |
from typing import List
|
2 |
+
|
3 |
+
from graphgen.models import Chunk, OpenAIModel
|
4 |
+
from graphgen.templates import COREFERENCE_RESOLUTION_PROMPT
|
5 |
from graphgen.utils import detect_main_language
|
6 |
|
7 |
+
|
8 |
async def resolute_coreference(
|
9 |
+
llm_client: OpenAIModel, chunks: List[Chunk]
|
10 |
+
) -> List[Chunk]:
|
11 |
"""
|
12 |
Resolute conference
|
13 |
|
|
|
24 |
for _, chunk in enumerate(chunks[1:]):
|
25 |
language = detect_main_language(chunk.content)
|
26 |
result = await llm_client.generate_answer(
|
27 |
+
COREFERENCE_RESOLUTION_PROMPT[language].format(
|
28 |
+
reference=results[0].content, input_sentence=chunk.content
|
|
|
29 |
)
|
30 |
)
|
31 |
results.append(Chunk(id=chunk.id, content=result))
|
graphgen/operators/search/__init__.py
ADDED
File without changes
|
graphgen/operators/search/db/__init__.py
ADDED
File without changes
|
graphgen/operators/search/db/search_uniprot.py
ADDED
File without changes
|
graphgen/operators/search/kg/__init__.py
ADDED
File without changes
|
graphgen/operators/search/kg/search_wikipedia.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm.asyncio import tqdm_asyncio as tqdm_async
|
2 |
+
|
3 |
+
from graphgen.models import WikiSearch
|
4 |
+
from graphgen.utils import logger
|
5 |
+
|
6 |
+
|
7 |
+
async def _process_single_entity(
|
8 |
+
entity_name: str,
|
9 |
+
wiki_search_client: WikiSearch,
|
10 |
+
) -> str | None:
|
11 |
+
"""
|
12 |
+
Process single entity by searching Wikipedia
|
13 |
+
:param entity_name
|
14 |
+
:param wiki_search_client
|
15 |
+
:return: summary of the entity or None if not found
|
16 |
+
"""
|
17 |
+
search_results = await wiki_search_client.search(entity_name)
|
18 |
+
if not search_results:
|
19 |
+
return None
|
20 |
+
|
21 |
+
summary = None
|
22 |
+
try:
|
23 |
+
summary = await wiki_search_client.summary(search_results[-1])
|
24 |
+
logger.info(
|
25 |
+
"Entity %s search result: %s summary: %s",
|
26 |
+
entity_name,
|
27 |
+
str(search_results),
|
28 |
+
summary,
|
29 |
+
)
|
30 |
+
except Exception as e: # pylint: disable=broad-except
|
31 |
+
logger.error("Error processing entity %s: %s", entity_name, str(e))
|
32 |
+
|
33 |
+
return summary
|
34 |
+
|
35 |
+
|
36 |
+
async def search_wikipedia(
|
37 |
+
wiki_search_client: WikiSearch,
|
38 |
+
entities: set[str],
|
39 |
+
) -> dict:
|
40 |
+
"""
|
41 |
+
Search wikipedia for entities
|
42 |
+
|
43 |
+
:param wiki_search_client: wiki search client
|
44 |
+
:param entities: list of entities to search
|
45 |
+
:return: nodes with search results
|
46 |
+
"""
|
47 |
+
wiki_data = {}
|
48 |
+
|
49 |
+
async for entity in tqdm_async(
|
50 |
+
entities, desc="Searching Wikipedia", total=len(entities)
|
51 |
+
):
|
52 |
+
try:
|
53 |
+
summary = await _process_single_entity(entity, wiki_search_client)
|
54 |
+
if summary:
|
55 |
+
wiki_data[entity] = summary
|
56 |
+
except Exception as e: # pylint: disable=broad-except
|
57 |
+
logger.error("Error processing entity %s: %s", entity, str(e))
|
58 |
+
return wiki_data
|
graphgen/operators/search/search_all.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
To use Google Web Search API,
|
3 |
+
follow the instructions [here](https://developers.google.com/custom-search/v1/overview)
|
4 |
+
to get your Google search api key.
|
5 |
+
|
6 |
+
To use Bing Web Search API,
|
7 |
+
follow the instructions [here](https://www.microsoft.com/en-us/bing/apis/bing-web-search-api)
|
8 |
+
and obtain your Bing subscription key.
|
9 |
+
"""
|
10 |
+
|
11 |
+
import os
|
12 |
+
|
13 |
+
from graphgen.utils import logger
|
14 |
+
|
15 |
+
|
16 |
+
async def search_all(
|
17 |
+
search_types: dict, search_entities: set[str]
|
18 |
+
) -> dict[str, dict[str, str]]:
|
19 |
+
"""
|
20 |
+
:param search_types
|
21 |
+
:param search_entities: list of entities to search
|
22 |
+
:return: nodes with search results
|
23 |
+
"""
|
24 |
+
|
25 |
+
results = {}
|
26 |
+
|
27 |
+
for search_type in search_types:
|
28 |
+
if search_type == "wikipedia":
|
29 |
+
from graphgen.models import WikiSearch
|
30 |
+
from graphgen.operators.search.kg.search_wikipedia import search_wikipedia
|
31 |
+
|
32 |
+
wiki_search_client = WikiSearch()
|
33 |
+
|
34 |
+
wiki_results = await search_wikipedia(wiki_search_client, search_entities)
|
35 |
+
for entity_name, description in wiki_results.items():
|
36 |
+
if description:
|
37 |
+
results[entity_name] = {"wikipedia": description}
|
38 |
+
elif search_type == "google":
|
39 |
+
from graphgen.models import GoogleSearch
|
40 |
+
from graphgen.operators.search.web.search_google import search_google
|
41 |
+
|
42 |
+
google_search_client = GoogleSearch(
|
43 |
+
subscription_key=os.environ["GOOGLE_SEARCH_API_KEY"],
|
44 |
+
cx=os.environ["GOOGLE_SEARCH_CX"],
|
45 |
+
)
|
46 |
+
|
47 |
+
google_results = await search_google(google_search_client, search_entities)
|
48 |
+
for entity_name, description in google_results.items():
|
49 |
+
if description:
|
50 |
+
results[entity_name] = results.get(entity_name, {})
|
51 |
+
results[entity_name]["google"] = description
|
52 |
+
elif search_type == "bing":
|
53 |
+
from graphgen.models import BingSearch
|
54 |
+
from graphgen.operators.search.web.search_bing import search_bing
|
55 |
+
|
56 |
+
bing_search_client = BingSearch(
|
57 |
+
subscription_key=os.environ["BING_SEARCH_API_KEY"]
|
58 |
+
)
|
59 |
+
|
60 |
+
bing_results = await search_bing(bing_search_client, search_entities)
|
61 |
+
for entity_name, description in bing_results.items():
|
62 |
+
if description:
|
63 |
+
results[entity_name] = results.get(entity_name, {})
|
64 |
+
results[entity_name]["bing"] = description
|
65 |
+
elif search_type == "uniprot":
|
66 |
+
# from graphgen.models import UniProtSearch
|
67 |
+
# from graphgen.operators.search.db.search_uniprot import search_uniprot
|
68 |
+
#
|
69 |
+
# uniprot_search_client = UniProtSearch()
|
70 |
+
#
|
71 |
+
# uniprot_results = await search_uniprot(
|
72 |
+
# uniprot_search_client, search_entities
|
73 |
+
# )
|
74 |
+
raise NotImplementedError(
|
75 |
+
"Processing of UniProt search results is not implemented yet."
|
76 |
+
)
|
77 |
+
|
78 |
+
else:
|
79 |
+
logger.error("Search type %s is not supported yet.", search_type)
|
80 |
+
continue
|
81 |
+
|
82 |
+
return results
|
graphgen/operators/search/web/__init__.py
ADDED
File without changes
|
graphgen/operators/search/web/search_bing.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import trafilatura
|
2 |
+
from tqdm.asyncio import tqdm_asyncio as tqdm_async
|
3 |
+
|
4 |
+
from graphgen.models import BingSearch
|
5 |
+
from graphgen.utils import logger
|
6 |
+
|
7 |
+
|
8 |
+
async def _process_single_entity(
|
9 |
+
entity_name: str, bing_search_client: BingSearch
|
10 |
+
) -> str | None:
|
11 |
+
"""
|
12 |
+
Process single entity by searching Bing.
|
13 |
+
:param entity_name: The name of the entity to search.
|
14 |
+
:param bing_search_client: The Bing search client.
|
15 |
+
:return: Summary of the entity or None if not found.
|
16 |
+
"""
|
17 |
+
search_results = bing_search_client.search(entity_name)
|
18 |
+
if not search_results:
|
19 |
+
return None
|
20 |
+
|
21 |
+
# Get more details from the first search result
|
22 |
+
first_result = search_results[0]
|
23 |
+
content = trafilatura.fetch_url(first_result["url"])
|
24 |
+
summary = trafilatura.extract(content, include_comments=False, include_links=False)
|
25 |
+
summary = summary.strip()
|
26 |
+
logger.info(
|
27 |
+
"Entity %s search result: %s",
|
28 |
+
entity_name,
|
29 |
+
summary,
|
30 |
+
)
|
31 |
+
return summary
|
32 |
+
|
33 |
+
|
34 |
+
async def search_bing(
|
35 |
+
bing_search_client: BingSearch,
|
36 |
+
entities: set[str],
|
37 |
+
) -> dict[str, str]:
|
38 |
+
"""
|
39 |
+
Search with Bing and return the contexts.
|
40 |
+
:return:
|
41 |
+
"""
|
42 |
+
bing_data = {}
|
43 |
+
|
44 |
+
async for entity in tqdm_async(
|
45 |
+
entities, desc="Searching Bing", total=len(entities)
|
46 |
+
):
|
47 |
+
try:
|
48 |
+
summary = await _process_single_entity(entity, bing_search_client)
|
49 |
+
if summary:
|
50 |
+
bing_data[entity] = summary
|
51 |
+
except Exception as e: # pylint: disable=broad-except
|
52 |
+
logger.error("Error processing entity %s: %s", entity, str(e))
|
53 |
+
return bing_data
|
graphgen/operators/search/web/search_google.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import trafilatura
|
2 |
+
from tqdm.asyncio import tqdm_asyncio as tqdm_async
|
3 |
+
|
4 |
+
from graphgen.models import GoogleSearch
|
5 |
+
from graphgen.utils import logger
|
6 |
+
|
7 |
+
|
8 |
+
async def _process_single_entity(
|
9 |
+
entity_name: str, google_search_client: GoogleSearch
|
10 |
+
) -> str | None:
|
11 |
+
search_results = google_search_client.search(entity_name)
|
12 |
+
if not search_results:
|
13 |
+
return None
|
14 |
+
|
15 |
+
# Get more details from the first search result
|
16 |
+
first_result = search_results[0]
|
17 |
+
content = trafilatura.fetch_url(first_result["link"])
|
18 |
+
summary = trafilatura.extract(content, include_comments=False, include_links=False)
|
19 |
+
summary = summary.strip()
|
20 |
+
logger.info(
|
21 |
+
"Entity %s search result: %s",
|
22 |
+
entity_name,
|
23 |
+
summary,
|
24 |
+
)
|
25 |
+
return summary
|
26 |
+
|
27 |
+
|
28 |
+
async def search_google(
|
29 |
+
google_search_client: GoogleSearch,
|
30 |
+
entities: set[str],
|
31 |
+
) -> dict:
|
32 |
+
"""
|
33 |
+
Search with Google and return the contexts.
|
34 |
+
:param google_search_client: Google search client
|
35 |
+
:param entities: list of entities to search
|
36 |
+
:return:
|
37 |
+
"""
|
38 |
+
google_data = {}
|
39 |
+
|
40 |
+
async for entity in tqdm_async(
|
41 |
+
entities, desc="Searching Google", total=len(entities)
|
42 |
+
):
|
43 |
+
try:
|
44 |
+
summary = await _process_single_entity(entity, google_search_client)
|
45 |
+
if summary:
|
46 |
+
google_data[entity] = summary
|
47 |
+
except Exception as e: # pylint: disable=broad-except
|
48 |
+
logger.error("Error processing entity %s: %s", entity, str(e))
|
49 |
+
return google_data
|
graphgen/operators/search_wikipedia.py
DELETED
@@ -1,71 +0,0 @@
|
|
1 |
-
import asyncio
|
2 |
-
from graphgen.models import WikiSearch, OpenAIModel
|
3 |
-
from graphgen.models.storage.base_storage import BaseGraphStorage
|
4 |
-
from graphgen.templates import SEARCH_JUDGEMENT_PROMPT
|
5 |
-
from graphgen.utils import logger
|
6 |
-
|
7 |
-
|
8 |
-
async def _process_single_entity(entity_name: str,
|
9 |
-
description: str,
|
10 |
-
llm_client: OpenAIModel,
|
11 |
-
wiki_search_client: WikiSearch) -> tuple[str, None] | tuple[str, str]:
|
12 |
-
"""
|
13 |
-
Process single entity
|
14 |
-
|
15 |
-
"""
|
16 |
-
search_results = await wiki_search_client.search(entity_name)
|
17 |
-
if not search_results:
|
18 |
-
return entity_name, None
|
19 |
-
examples = "\n".join(SEARCH_JUDGEMENT_PROMPT["EXAMPLES"])
|
20 |
-
search_results.append("None of the above")
|
21 |
-
|
22 |
-
search_results_str = "\n".join([f"{i + 1}. {sr}" for i, sr in enumerate(search_results)])
|
23 |
-
prompt = SEARCH_JUDGEMENT_PROMPT["TEMPLATE"].format(
|
24 |
-
examples=examples,
|
25 |
-
entity_name=entity_name,
|
26 |
-
description=description,
|
27 |
-
search_results=search_results_str,
|
28 |
-
)
|
29 |
-
response = await llm_client.generate_answer(prompt)
|
30 |
-
|
31 |
-
try:
|
32 |
-
response = response.strip()
|
33 |
-
response = int(response)
|
34 |
-
if response < 1 or response >= len(search_results):
|
35 |
-
response = None
|
36 |
-
else:
|
37 |
-
response = await wiki_search_client.summary(search_results[response - 1])
|
38 |
-
except ValueError:
|
39 |
-
response = None
|
40 |
-
|
41 |
-
logger.info("Entity %s search result: %s response: %s", entity_name, str(search_results), response)
|
42 |
-
|
43 |
-
return entity_name, response
|
44 |
-
|
45 |
-
async def search_wikipedia(llm_client: OpenAIModel,
|
46 |
-
wiki_search_client: WikiSearch,
|
47 |
-
knowledge_graph_instance: BaseGraphStorage,) -> dict:
|
48 |
-
"""
|
49 |
-
Search wikipedia for entities
|
50 |
-
|
51 |
-
:param llm_client: LLM model
|
52 |
-
:param wiki_search_client: wiki search client
|
53 |
-
:param knowledge_graph_instance: knowledge graph instance
|
54 |
-
:return: nodes with search results
|
55 |
-
"""
|
56 |
-
|
57 |
-
|
58 |
-
nodes = await knowledge_graph_instance.get_all_nodes()
|
59 |
-
nodes = list(nodes)
|
60 |
-
wiki_data = {}
|
61 |
-
|
62 |
-
tasks = [
|
63 |
-
_process_single_entity(node[0].strip('"'), node[1]["description"], llm_client, wiki_search_client)
|
64 |
-
for node in nodes
|
65 |
-
]
|
66 |
-
|
67 |
-
for task in asyncio.as_completed(tasks):
|
68 |
-
result = await task
|
69 |
-
wiki_data[result[0]] = result[1]
|
70 |
-
|
71 |
-
return wiki_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/operators/traverse_graph.py
CHANGED
@@ -1,49 +1,67 @@
|
|
1 |
import asyncio
|
2 |
-
import gradio as gr
|
3 |
|
|
|
4 |
from tqdm.asyncio import tqdm as tqdm_async
|
5 |
|
6 |
-
from graphgen.models import
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
sem = asyncio.Semaphore(1000)
|
|
|
18 |
async def handle_edge(edge: tuple) -> tuple:
|
19 |
async with sem:
|
20 |
-
if
|
21 |
-
edge[2][
|
22 |
-
await asyncio.get_event_loop().run_in_executor(
|
23 |
-
|
24 |
-
|
|
|
25 |
return edge
|
26 |
|
27 |
async def handle_node(node: dict) -> dict:
|
28 |
async with sem:
|
29 |
-
if
|
30 |
-
node[1][
|
31 |
-
await asyncio.get_event_loop().run_in_executor(
|
32 |
-
|
33 |
-
|
|
|
34 |
return node
|
35 |
|
36 |
new_edges = []
|
37 |
new_nodes = []
|
38 |
|
39 |
-
for result in tqdm_async(
|
40 |
-
|
|
|
|
|
|
|
41 |
new_edge = await result
|
42 |
await graph_storage.update_edge(new_edge[0], new_edge[1], new_edge[2])
|
43 |
new_edges.append(new_edge)
|
44 |
|
45 |
-
for result in tqdm_async(
|
46 |
-
|
|
|
|
|
|
|
47 |
new_node = await result
|
48 |
await graph_storage.update_node(new_node[0], new_node[1])
|
49 |
new_nodes.append(new_node)
|
@@ -51,60 +69,75 @@ async def _pre_tokenize(graph_storage: NetworkXStorage,
|
|
51 |
await graph_storage.index_done_callback()
|
52 |
return new_edges, new_nodes
|
53 |
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
59 |
entities = [
|
60 |
-
f"{_process_node['node_id']}: {_process_node['description']}"
|
|
|
61 |
]
|
62 |
relations = [
|
63 |
f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
|
64 |
for _process_edge in _process_edges
|
65 |
]
|
66 |
|
67 |
-
entities_str = "\n".join(
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
if add_context:
|
72 |
-
original_ids =
|
73 |
-
|
|
|
74 |
|
75 |
original_ids = list(set(original_ids))
|
76 |
original_text = await text_chunks_storage.get_by_ids(original_ids)
|
77 |
-
original_text = "\n".join(
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
-
prompt = ANSWER_REPHRASING_PROMPT[language][
|
80 |
language=language,
|
81 |
original_text=original_text,
|
82 |
entities=entities_str,
|
83 |
-
relationships=relations_str
|
84 |
)
|
85 |
return prompt
|
86 |
|
87 |
-
prompt = ANSWER_REPHRASING_PROMPT[language][
|
88 |
-
language=language,
|
89 |
-
entities=entities_str,
|
90 |
-
relationships=relations_str
|
91 |
)
|
92 |
return prompt
|
93 |
|
94 |
-
def get_loss_tercile(losses: list) -> (float, float):
|
95 |
-
losses = sorted(losses)
|
96 |
-
q1_index = int(len(losses) * (1 / 3))
|
97 |
-
q2_index = int(len(losses) * (2 / 3))
|
98 |
-
|
99 |
-
return losses[q1_index], losses[q2_index]
|
100 |
|
101 |
def get_average_loss(batch: tuple, loss_strategy: str) -> float:
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
def _post_process_synthetic_data(data):
|
110 |
block = data.split("\n\n")
|
@@ -113,26 +146,18 @@ def _post_process_synthetic_data(data):
|
|
113 |
if "Question:" in line and "Answer:" in line:
|
114 |
question = line.split("Question:")[1].split("Answer:")[0].strip()
|
115 |
answer = line.split("Answer:")[1].strip()
|
116 |
-
qas.append({
|
117 |
-
"question": question,
|
118 |
-
"answer": answer
|
119 |
-
})
|
120 |
elif "问题:" in line and "答案:" in line:
|
121 |
question = line.split("问题:")[1].split("答案:")[0].strip()
|
122 |
answer = line.split("答案:")[1].strip()
|
123 |
-
qas.append({
|
124 |
-
"question": question,
|
125 |
-
"answer": answer
|
126 |
-
})
|
127 |
elif "问题:" in line and "回答:" in line:
|
128 |
question = line.split("问题:")[1].split("回答:")[0].strip()
|
129 |
answer = line.split("回答:")[1].strip()
|
130 |
-
qas.append({
|
131 |
-
"question": question,
|
132 |
-
"answer": answer
|
133 |
-
})
|
134 |
return qas
|
135 |
|
|
|
136 |
async def traverse_graph_by_edge(
|
137 |
llm_client: OpenAIModel,
|
138 |
tokenizer: Tokenizer,
|
@@ -140,7 +165,7 @@ async def traverse_graph_by_edge(
|
|
140 |
traverse_strategy: TraverseStrategy,
|
141 |
text_chunks_storage: JsonKVStorage,
|
142 |
progress_bar: gr.Progress = None,
|
143 |
-
max_concurrent: int = 1000
|
144 |
) -> dict:
|
145 |
"""
|
146 |
Traverse the graph
|
@@ -158,28 +183,24 @@ async def traverse_graph_by_edge(
|
|
158 |
semaphore = asyncio.Semaphore(max_concurrent)
|
159 |
|
160 |
async def _process_nodes_and_edges(
|
161 |
-
|
162 |
-
|
163 |
) -> str:
|
164 |
prompt = await _construct_rephrasing_prompt(
|
165 |
-
_process_nodes,
|
166 |
-
_process_edges,
|
167 |
-
text_chunks_storage,
|
168 |
-
add_context = False
|
169 |
)
|
170 |
context = await llm_client.generate_answer(prompt)
|
171 |
|
172 |
# post-process the context
|
173 |
if context.startswith("Rephrased Text:"):
|
174 |
-
context = context[len("Rephrased Text:"):].strip()
|
175 |
elif context.startswith("重述文本:"):
|
176 |
-
context = context[len("重述文本:"):].strip()
|
177 |
|
178 |
return context
|
179 |
|
180 |
async def _process_single_batch(
|
181 |
-
_process_batch: tuple,
|
182 |
-
question_type: str = "single"
|
183 |
) -> dict:
|
184 |
async with semaphore:
|
185 |
context = await _process_nodes_and_edges(
|
@@ -188,21 +209,26 @@ async def traverse_graph_by_edge(
|
|
188 |
)
|
189 |
|
190 |
language = "Chinese" if detect_main_language(context) == "zh" else "English"
|
191 |
-
pre_length = sum(node[
|
192 |
-
|
|
|
193 |
|
194 |
if question_type == "single":
|
195 |
question = await llm_client.generate_answer(
|
196 |
-
QUESTION_GENERATION_PROMPT[language][
|
197 |
answer=context
|
198 |
)
|
199 |
)
|
200 |
if question.startswith("Question:"):
|
201 |
-
question = question[len("Question:"):].strip()
|
202 |
elif question.startswith("问题:"):
|
203 |
-
question = question[len("问题:"):].strip()
|
204 |
|
205 |
-
logger.info(
|
|
|
|
|
|
|
|
|
206 |
logger.info("Pre-length: %s", pre_length)
|
207 |
logger.info("Question: %s", question)
|
208 |
logger.info("Answer: %s", context)
|
@@ -211,12 +237,14 @@ async def traverse_graph_by_edge(
|
|
211 |
compute_content_hash(context): {
|
212 |
"question": question,
|
213 |
"answer": context,
|
214 |
-
"loss": get_average_loss(
|
|
|
|
|
215 |
}
|
216 |
}
|
217 |
|
218 |
content = await llm_client.generate_answer(
|
219 |
-
QUESTION_GENERATION_PROMPT[language][
|
220 |
doc=context
|
221 |
)
|
222 |
)
|
@@ -224,19 +252,27 @@ async def traverse_graph_by_edge(
|
|
224 |
|
225 |
if len(qas) == 0:
|
226 |
print(content)
|
227 |
-
logger.error(
|
|
|
|
|
228 |
return {}
|
229 |
|
230 |
final_results = {}
|
231 |
-
logger.info(
|
|
|
|
|
|
|
|
|
232 |
logger.info("Pre-length: %s", pre_length)
|
233 |
for qa in qas:
|
234 |
-
logger.info("Question: %s", qa[
|
235 |
-
logger.info("Answer: %s", qa[
|
236 |
-
final_results[compute_content_hash(qa[
|
237 |
-
"question": qa[
|
238 |
-
"answer": qa[
|
239 |
-
"loss": get_average_loss(
|
|
|
|
|
240 |
}
|
241 |
return final_results
|
242 |
|
@@ -247,22 +283,25 @@ async def traverse_graph_by_edge(
|
|
247 |
edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
|
248 |
|
249 |
processing_batches = await get_batches_with_strategy(
|
250 |
-
nodes,
|
251 |
-
edges,
|
252 |
-
graph_storage,
|
253 |
-
traverse_strategy
|
254 |
)
|
255 |
|
256 |
-
for result in tqdm_async(
|
257 |
-
|
258 |
-
|
|
|
|
|
|
|
|
|
259 |
try:
|
260 |
if progress_bar is not None:
|
261 |
-
progress_bar(
|
|
|
|
|
262 |
results.update(await result)
|
263 |
if progress_bar is not None and len(results) == len(processing_batches):
|
264 |
progress_bar(1, desc="[4/4]Generating QAs")
|
265 |
-
except Exception as e:
|
266 |
logger.error("Error occurred while generating QA: %s", e)
|
267 |
|
268 |
return results
|
@@ -275,7 +314,7 @@ async def traverse_graph_atomically(
|
|
275 |
traverse_strategy: TraverseStrategy,
|
276 |
text_chunks_storage: JsonKVStorage,
|
277 |
progress_bar: gr.Progress = None,
|
278 |
-
max_concurrent: int = 1000
|
279 |
) -> dict:
|
280 |
"""
|
281 |
Traverse the graph atomicly
|
@@ -292,22 +331,21 @@ async def traverse_graph_atomically(
|
|
292 |
assert traverse_strategy.qa_form == "atomic"
|
293 |
|
294 |
semaphore = asyncio.Semaphore(max_concurrent)
|
295 |
-
|
296 |
-
|
297 |
-
):
|
298 |
if len(node_or_edge) == 2:
|
299 |
-
des = node_or_edge[0] + ": " + node_or_edge[1][
|
300 |
-
loss = node_or_edge[1][
|
301 |
else:
|
302 |
-
des = node_or_edge[2][
|
303 |
-
loss = node_or_edge[2][
|
304 |
|
305 |
async with semaphore:
|
306 |
try:
|
307 |
language = "Chinese" if detect_main_language(des) == "zh" else "English"
|
308 |
|
309 |
qa = await llm_client.generate_answer(
|
310 |
-
QUESTION_GENERATION_PROMPT[language][
|
311 |
doc=des
|
312 |
)
|
313 |
)
|
@@ -321,8 +359,8 @@ async def traverse_graph_atomically(
|
|
321 |
else:
|
322 |
return {}
|
323 |
|
324 |
-
question = question.strip("
|
325 |
-
answer = answer.strip("
|
326 |
|
327 |
logger.info("Question: %s", question)
|
328 |
logger.info("Answer: %s", answer)
|
@@ -330,10 +368,10 @@ async def traverse_graph_atomically(
|
|
330 |
compute_content_hash(question): {
|
331 |
"question": question,
|
332 |
"answer": answer,
|
333 |
-
"loss": loss
|
334 |
}
|
335 |
}
|
336 |
-
except Exception as e:
|
337 |
logger.error("Error occurred while generating question: %s", e)
|
338 |
return {}
|
339 |
|
@@ -345,24 +383,26 @@ async def traverse_graph_atomically(
|
|
345 |
|
346 |
tasks = []
|
347 |
for node in nodes:
|
348 |
-
if "<SEP>" in node[1][
|
349 |
-
description_list = node[1][
|
350 |
for item in description_list:
|
351 |
-
tasks.append((node[0], {"description": item,
|
352 |
else:
|
353 |
tasks.append((node[0], node[1]))
|
354 |
for edge in edges:
|
355 |
-
if "<SEP>" in edge[2][
|
356 |
-
description_list = edge[2][
|
357 |
for item in description_list:
|
358 |
-
tasks.append(
|
|
|
|
|
359 |
else:
|
360 |
tasks.append((edge[0], edge[1], edge[2]))
|
361 |
|
362 |
for result in tqdm_async(
|
363 |
asyncio.as_completed([_generate_question(task) for task in tasks]),
|
364 |
total=len(tasks),
|
365 |
-
desc="[4/4]Generating QAs"
|
366 |
):
|
367 |
try:
|
368 |
if progress_bar is not None:
|
@@ -370,10 +410,11 @@ async def traverse_graph_atomically(
|
|
370 |
results.update(await result)
|
371 |
if progress_bar is not None and len(results) == len(tasks):
|
372 |
progress_bar(1, desc="[4/4]Generating QAs")
|
373 |
-
except Exception as e:
|
374 |
logger.error("Error occurred while generating QA: %s", e)
|
375 |
return results
|
376 |
|
|
|
377 |
async def traverse_graph_for_multi_hop(
|
378 |
llm_client: OpenAIModel,
|
379 |
tokenizer: Tokenizer,
|
@@ -381,7 +422,7 @@ async def traverse_graph_for_multi_hop(
|
|
381 |
traverse_strategy: TraverseStrategy,
|
382 |
text_chunks_storage: JsonKVStorage,
|
383 |
progress_bar: gr.Progress = None,
|
384 |
-
max_concurrent: int = 1000
|
385 |
) -> dict:
|
386 |
"""
|
387 |
Traverse the graph for multi-hop
|
@@ -395,8 +436,6 @@ async def traverse_graph_for_multi_hop(
|
|
395 |
:param max_concurrent
|
396 |
:return: question and answer
|
397 |
"""
|
398 |
-
assert traverse_strategy.qa_form == "multi_hop"
|
399 |
-
|
400 |
semaphore = asyncio.Semaphore(max_concurrent)
|
401 |
|
402 |
results = {}
|
@@ -406,24 +445,24 @@ async def traverse_graph_for_multi_hop(
|
|
406 |
edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
|
407 |
|
408 |
processing_batches = await get_batches_with_strategy(
|
409 |
-
nodes,
|
410 |
-
edges,
|
411 |
-
graph_storage,
|
412 |
-
traverse_strategy
|
413 |
)
|
414 |
|
415 |
-
async def _process_single_batch(
|
416 |
-
_process_batch: tuple
|
417 |
-
) -> dict:
|
418 |
async with semaphore:
|
419 |
try:
|
420 |
-
language =
|
|
|
|
|
|
|
|
|
421 |
|
422 |
_process_nodes = _process_batch[0]
|
423 |
_process_edges = _process_batch[1]
|
424 |
|
425 |
entities = [
|
426 |
-
f"{_process_node['node_id']}: {_process_node['description']}"
|
|
|
427 |
]
|
428 |
|
429 |
relations = [
|
@@ -431,12 +470,18 @@ async def traverse_graph_for_multi_hop(
|
|
431 |
for _process_edge in _process_edges
|
432 |
]
|
433 |
|
434 |
-
entities_str = "\n".join(
|
435 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
436 |
|
437 |
prompt = MULTI_HOP_GENERATION_PROMPT[language].format(
|
438 |
-
entities=entities_str,
|
439 |
-
relationships=relations_str
|
440 |
)
|
441 |
|
442 |
context = await llm_client.generate_answer(prompt)
|
@@ -451,8 +496,8 @@ async def traverse_graph_for_multi_hop(
|
|
451 |
else:
|
452 |
return {}
|
453 |
|
454 |
-
question = question.strip("
|
455 |
-
answer = answer.strip("
|
456 |
|
457 |
logger.info("Question: %s", question)
|
458 |
logger.info("Answer: %s", answer)
|
@@ -461,25 +506,31 @@ async def traverse_graph_for_multi_hop(
|
|
461 |
compute_content_hash(question): {
|
462 |
"question": question,
|
463 |
"answer": answer,
|
464 |
-
"loss": get_average_loss(
|
|
|
|
|
465 |
}
|
466 |
}
|
467 |
|
468 |
-
except Exception as e:
|
469 |
logger.error("Error occurred while processing batch: %s", e)
|
470 |
return {}
|
471 |
|
472 |
async for result in tqdm_async(
|
473 |
-
asyncio.as_completed(
|
|
|
|
|
474 |
total=len(processing_batches),
|
475 |
-
desc="[4/4]Generating QAs"
|
476 |
):
|
477 |
try:
|
478 |
if progress_bar is not None:
|
479 |
-
progress_bar(
|
|
|
|
|
480 |
results.update(await result)
|
481 |
if progress_bar is not None and len(results) == len(processing_batches):
|
482 |
progress_bar(1, desc="[4/4]Generating QAs")
|
483 |
-
except Exception as e:
|
484 |
logger.error("Error occurred while generating QA: %s", e)
|
485 |
return results
|
|
|
1 |
import asyncio
|
|
|
2 |
|
3 |
+
import gradio as gr
|
4 |
from tqdm.asyncio import tqdm as tqdm_async
|
5 |
|
6 |
+
from graphgen.models import (
|
7 |
+
JsonKVStorage,
|
8 |
+
NetworkXStorage,
|
9 |
+
OpenAIModel,
|
10 |
+
Tokenizer,
|
11 |
+
TraverseStrategy,
|
12 |
+
)
|
13 |
+
from graphgen.operators.kg.split_kg import get_batches_with_strategy
|
14 |
+
from graphgen.templates import (
|
15 |
+
ANSWER_REPHRASING_PROMPT,
|
16 |
+
MULTI_HOP_GENERATION_PROMPT,
|
17 |
+
QUESTION_GENERATION_PROMPT,
|
18 |
+
)
|
19 |
+
from graphgen.utils import compute_content_hash, detect_main_language, logger
|
20 |
+
|
21 |
+
|
22 |
+
async def _pre_tokenize(
|
23 |
+
graph_storage: NetworkXStorage, tokenizer: Tokenizer, edges: list, nodes: list
|
24 |
+
) -> tuple:
|
25 |
|
26 |
sem = asyncio.Semaphore(1000)
|
27 |
+
|
28 |
async def handle_edge(edge: tuple) -> tuple:
|
29 |
async with sem:
|
30 |
+
if "length" not in edge[2]:
|
31 |
+
edge[2]["length"] = len(
|
32 |
+
await asyncio.get_event_loop().run_in_executor(
|
33 |
+
None, tokenizer.encode_string, edge[2]["description"]
|
34 |
+
)
|
35 |
+
)
|
36 |
return edge
|
37 |
|
38 |
async def handle_node(node: dict) -> dict:
|
39 |
async with sem:
|
40 |
+
if "length" not in node[1]:
|
41 |
+
node[1]["length"] = len(
|
42 |
+
await asyncio.get_event_loop().run_in_executor(
|
43 |
+
None, tokenizer.encode_string, node[1]["description"]
|
44 |
+
)
|
45 |
+
)
|
46 |
return node
|
47 |
|
48 |
new_edges = []
|
49 |
new_nodes = []
|
50 |
|
51 |
+
for result in tqdm_async(
|
52 |
+
asyncio.as_completed([handle_edge(edge) for edge in edges]),
|
53 |
+
total=len(edges),
|
54 |
+
desc="Pre-tokenizing edges",
|
55 |
+
):
|
56 |
new_edge = await result
|
57 |
await graph_storage.update_edge(new_edge[0], new_edge[1], new_edge[2])
|
58 |
new_edges.append(new_edge)
|
59 |
|
60 |
+
for result in tqdm_async(
|
61 |
+
asyncio.as_completed([handle_node(node) for node in nodes]),
|
62 |
+
total=len(nodes),
|
63 |
+
desc="Pre-tokenizing nodes",
|
64 |
+
):
|
65 |
new_node = await result
|
66 |
await graph_storage.update_node(new_node[0], new_node[1])
|
67 |
new_nodes.append(new_node)
|
|
|
69 |
await graph_storage.index_done_callback()
|
70 |
return new_edges, new_nodes
|
71 |
|
72 |
+
|
73 |
+
async def _construct_rephrasing_prompt(
|
74 |
+
_process_nodes: list,
|
75 |
+
_process_edges: list,
|
76 |
+
text_chunks_storage: JsonKVStorage,
|
77 |
+
add_context: bool = False,
|
78 |
+
) -> str:
|
79 |
entities = [
|
80 |
+
f"{_process_node['node_id']}: {_process_node['description']}"
|
81 |
+
for _process_node in _process_nodes
|
82 |
]
|
83 |
relations = [
|
84 |
f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
|
85 |
for _process_edge in _process_edges
|
86 |
]
|
87 |
|
88 |
+
entities_str = "\n".join(
|
89 |
+
[f"{index + 1}. {entity}" for index, entity in enumerate(entities)]
|
90 |
+
)
|
91 |
+
relations_str = "\n".join(
|
92 |
+
[f"{index + 1}. {relation}" for index, relation in enumerate(relations)]
|
93 |
+
)
|
94 |
+
language = (
|
95 |
+
"Chinese"
|
96 |
+
if detect_main_language(entities_str + relations_str) == "zh"
|
97 |
+
else "English"
|
98 |
+
)
|
99 |
|
100 |
if add_context:
|
101 |
+
original_ids = [
|
102 |
+
node["source_id"].split("<SEP>")[0] for node in _process_nodes
|
103 |
+
] + [edge[2]["source_id"].split("<SEP>")[0] for edge in _process_edges]
|
104 |
|
105 |
original_ids = list(set(original_ids))
|
106 |
original_text = await text_chunks_storage.get_by_ids(original_ids)
|
107 |
+
original_text = "\n".join(
|
108 |
+
[
|
109 |
+
f"{index + 1}. {text['content']}"
|
110 |
+
for index, text in enumerate(original_text)
|
111 |
+
]
|
112 |
+
)
|
113 |
|
114 |
+
prompt = ANSWER_REPHRASING_PROMPT[language]["CONTEXT_TEMPLATE"].format(
|
115 |
language=language,
|
116 |
original_text=original_text,
|
117 |
entities=entities_str,
|
118 |
+
relationships=relations_str,
|
119 |
)
|
120 |
return prompt
|
121 |
|
122 |
+
prompt = ANSWER_REPHRASING_PROMPT[language]["TEMPLATE"].format(
|
123 |
+
language=language, entities=entities_str, relationships=relations_str
|
|
|
|
|
124 |
)
|
125 |
return prompt
|
126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
def get_average_loss(batch: tuple, loss_strategy: str) -> float:
|
129 |
+
try:
|
130 |
+
if loss_strategy == "only_edge":
|
131 |
+
return sum(edge[2]["loss"] for edge in batch[1]) / len(batch[1])
|
132 |
+
if loss_strategy == "both":
|
133 |
+
return sum(edge[2]["loss"] for edge in batch[1]) + sum(
|
134 |
+
node["loss"] for node in batch[0]
|
135 |
+
) / (len(batch[0]) + len(batch[1]))
|
136 |
+
raise ValueError("Invalid loss strategy")
|
137 |
+
except Exception as e: # pylint: disable=broad-except
|
138 |
+
logger.error("Error calculating average loss: %s", e)
|
139 |
+
return -1.0
|
140 |
+
|
141 |
|
142 |
def _post_process_synthetic_data(data):
|
143 |
block = data.split("\n\n")
|
|
|
146 |
if "Question:" in line and "Answer:" in line:
|
147 |
question = line.split("Question:")[1].split("Answer:")[0].strip()
|
148 |
answer = line.split("Answer:")[1].strip()
|
149 |
+
qas.append({"question": question, "answer": answer})
|
|
|
|
|
|
|
150 |
elif "问题:" in line and "答案:" in line:
|
151 |
question = line.split("问题:")[1].split("答案:")[0].strip()
|
152 |
answer = line.split("答案:")[1].strip()
|
153 |
+
qas.append({"question": question, "answer": answer})
|
|
|
|
|
|
|
154 |
elif "问题:" in line and "回答:" in line:
|
155 |
question = line.split("问题:")[1].split("回答:")[0].strip()
|
156 |
answer = line.split("回答:")[1].strip()
|
157 |
+
qas.append({"question": question, "answer": answer})
|
|
|
|
|
|
|
158 |
return qas
|
159 |
|
160 |
+
|
161 |
async def traverse_graph_by_edge(
|
162 |
llm_client: OpenAIModel,
|
163 |
tokenizer: Tokenizer,
|
|
|
165 |
traverse_strategy: TraverseStrategy,
|
166 |
text_chunks_storage: JsonKVStorage,
|
167 |
progress_bar: gr.Progress = None,
|
168 |
+
max_concurrent: int = 1000,
|
169 |
) -> dict:
|
170 |
"""
|
171 |
Traverse the graph
|
|
|
183 |
semaphore = asyncio.Semaphore(max_concurrent)
|
184 |
|
185 |
async def _process_nodes_and_edges(
|
186 |
+
_process_nodes: list,
|
187 |
+
_process_edges: list,
|
188 |
) -> str:
|
189 |
prompt = await _construct_rephrasing_prompt(
|
190 |
+
_process_nodes, _process_edges, text_chunks_storage, add_context=False
|
|
|
|
|
|
|
191 |
)
|
192 |
context = await llm_client.generate_answer(prompt)
|
193 |
|
194 |
# post-process the context
|
195 |
if context.startswith("Rephrased Text:"):
|
196 |
+
context = context[len("Rephrased Text:") :].strip()
|
197 |
elif context.startswith("重述文本:"):
|
198 |
+
context = context[len("重述文本:") :].strip()
|
199 |
|
200 |
return context
|
201 |
|
202 |
async def _process_single_batch(
|
203 |
+
_process_batch: tuple, question_type: str = "single"
|
|
|
204 |
) -> dict:
|
205 |
async with semaphore:
|
206 |
context = await _process_nodes_and_edges(
|
|
|
209 |
)
|
210 |
|
211 |
language = "Chinese" if detect_main_language(context) == "zh" else "English"
|
212 |
+
pre_length = sum(node["length"] for node in _process_batch[0]) + sum(
|
213 |
+
edge[2]["length"] for edge in _process_batch[1]
|
214 |
+
)
|
215 |
|
216 |
if question_type == "single":
|
217 |
question = await llm_client.generate_answer(
|
218 |
+
QUESTION_GENERATION_PROMPT[language]["SINGLE_TEMPLATE"].format(
|
219 |
answer=context
|
220 |
)
|
221 |
)
|
222 |
if question.startswith("Question:"):
|
223 |
+
question = question[len("Question:") :].strip()
|
224 |
elif question.startswith("问题:"):
|
225 |
+
question = question[len("问题:") :].strip()
|
226 |
|
227 |
+
logger.info(
|
228 |
+
"%d nodes and %d edges processed",
|
229 |
+
len(_process_batch[0]),
|
230 |
+
len(_process_batch[1]),
|
231 |
+
)
|
232 |
logger.info("Pre-length: %s", pre_length)
|
233 |
logger.info("Question: %s", question)
|
234 |
logger.info("Answer: %s", context)
|
|
|
237 |
compute_content_hash(context): {
|
238 |
"question": question,
|
239 |
"answer": context,
|
240 |
+
"loss": get_average_loss(
|
241 |
+
_process_batch, traverse_strategy.loss_strategy
|
242 |
+
),
|
243 |
}
|
244 |
}
|
245 |
|
246 |
content = await llm_client.generate_answer(
|
247 |
+
QUESTION_GENERATION_PROMPT[language]["MULTI_TEMPLATE"].format(
|
248 |
doc=context
|
249 |
)
|
250 |
)
|
|
|
252 |
|
253 |
if len(qas) == 0:
|
254 |
print(content)
|
255 |
+
logger.error(
|
256 |
+
"Error occurred while processing batch, question or answer is None"
|
257 |
+
)
|
258 |
return {}
|
259 |
|
260 |
final_results = {}
|
261 |
+
logger.info(
|
262 |
+
"%d nodes and %d edges processed",
|
263 |
+
len(_process_batch[0]),
|
264 |
+
len(_process_batch[1]),
|
265 |
+
)
|
266 |
logger.info("Pre-length: %s", pre_length)
|
267 |
for qa in qas:
|
268 |
+
logger.info("Question: %s", qa["question"])
|
269 |
+
logger.info("Answer: %s", qa["answer"])
|
270 |
+
final_results[compute_content_hash(qa["question"])] = {
|
271 |
+
"question": qa["question"],
|
272 |
+
"answer": qa["answer"],
|
273 |
+
"loss": get_average_loss(
|
274 |
+
_process_batch, traverse_strategy.loss_strategy
|
275 |
+
),
|
276 |
}
|
277 |
return final_results
|
278 |
|
|
|
283 |
edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
|
284 |
|
285 |
processing_batches = await get_batches_with_strategy(
|
286 |
+
nodes, edges, graph_storage, traverse_strategy
|
|
|
|
|
|
|
287 |
)
|
288 |
|
289 |
+
for result in tqdm_async(
|
290 |
+
asyncio.as_completed(
|
291 |
+
[_process_single_batch(batch) for batch in processing_batches]
|
292 |
+
),
|
293 |
+
total=len(processing_batches),
|
294 |
+
desc="[4/4]Generating QAs",
|
295 |
+
):
|
296 |
try:
|
297 |
if progress_bar is not None:
|
298 |
+
progress_bar(
|
299 |
+
len(results) / len(processing_batches), desc="[4/4]Generating QAs"
|
300 |
+
)
|
301 |
results.update(await result)
|
302 |
if progress_bar is not None and len(results) == len(processing_batches):
|
303 |
progress_bar(1, desc="[4/4]Generating QAs")
|
304 |
+
except Exception as e: # pylint: disable=broad-except
|
305 |
logger.error("Error occurred while generating QA: %s", e)
|
306 |
|
307 |
return results
|
|
|
314 |
traverse_strategy: TraverseStrategy,
|
315 |
text_chunks_storage: JsonKVStorage,
|
316 |
progress_bar: gr.Progress = None,
|
317 |
+
max_concurrent: int = 1000,
|
318 |
) -> dict:
|
319 |
"""
|
320 |
Traverse the graph atomicly
|
|
|
331 |
assert traverse_strategy.qa_form == "atomic"
|
332 |
|
333 |
semaphore = asyncio.Semaphore(max_concurrent)
|
334 |
+
|
335 |
+
async def _generate_question(node_or_edge: tuple):
|
|
|
336 |
if len(node_or_edge) == 2:
|
337 |
+
des = node_or_edge[0] + ": " + node_or_edge[1]["description"]
|
338 |
+
loss = node_or_edge[1]["loss"]
|
339 |
else:
|
340 |
+
des = node_or_edge[2]["description"]
|
341 |
+
loss = node_or_edge[2]["loss"]
|
342 |
|
343 |
async with semaphore:
|
344 |
try:
|
345 |
language = "Chinese" if detect_main_language(des) == "zh" else "English"
|
346 |
|
347 |
qa = await llm_client.generate_answer(
|
348 |
+
QUESTION_GENERATION_PROMPT[language]["SINGLE_QA_TEMPLATE"].format(
|
349 |
doc=des
|
350 |
)
|
351 |
)
|
|
|
359 |
else:
|
360 |
return {}
|
361 |
|
362 |
+
question = question.strip('"')
|
363 |
+
answer = answer.strip('"')
|
364 |
|
365 |
logger.info("Question: %s", question)
|
366 |
logger.info("Answer: %s", answer)
|
|
|
368 |
compute_content_hash(question): {
|
369 |
"question": question,
|
370 |
"answer": answer,
|
371 |
+
"loss": loss,
|
372 |
}
|
373 |
}
|
374 |
+
except Exception as e: # pylint: disable=broad-except
|
375 |
logger.error("Error occurred while generating question: %s", e)
|
376 |
return {}
|
377 |
|
|
|
383 |
|
384 |
tasks = []
|
385 |
for node in nodes:
|
386 |
+
if "<SEP>" in node[1]["description"]:
|
387 |
+
description_list = node[1]["description"].split("<SEP>")
|
388 |
for item in description_list:
|
389 |
+
tasks.append((node[0], {"description": item, "loss": node[1]["loss"]}))
|
390 |
else:
|
391 |
tasks.append((node[0], node[1]))
|
392 |
for edge in edges:
|
393 |
+
if "<SEP>" in edge[2]["description"]:
|
394 |
+
description_list = edge[2]["description"].split("<SEP>")
|
395 |
for item in description_list:
|
396 |
+
tasks.append(
|
397 |
+
(edge[0], edge[1], {"description": item, "loss": edge[2]["loss"]})
|
398 |
+
)
|
399 |
else:
|
400 |
tasks.append((edge[0], edge[1], edge[2]))
|
401 |
|
402 |
for result in tqdm_async(
|
403 |
asyncio.as_completed([_generate_question(task) for task in tasks]),
|
404 |
total=len(tasks),
|
405 |
+
desc="[4/4]Generating QAs",
|
406 |
):
|
407 |
try:
|
408 |
if progress_bar is not None:
|
|
|
410 |
results.update(await result)
|
411 |
if progress_bar is not None and len(results) == len(tasks):
|
412 |
progress_bar(1, desc="[4/4]Generating QAs")
|
413 |
+
except Exception as e: # pylint: disable=broad-except
|
414 |
logger.error("Error occurred while generating QA: %s", e)
|
415 |
return results
|
416 |
|
417 |
+
|
418 |
async def traverse_graph_for_multi_hop(
|
419 |
llm_client: OpenAIModel,
|
420 |
tokenizer: Tokenizer,
|
|
|
422 |
traverse_strategy: TraverseStrategy,
|
423 |
text_chunks_storage: JsonKVStorage,
|
424 |
progress_bar: gr.Progress = None,
|
425 |
+
max_concurrent: int = 1000,
|
426 |
) -> dict:
|
427 |
"""
|
428 |
Traverse the graph for multi-hop
|
|
|
436 |
:param max_concurrent
|
437 |
:return: question and answer
|
438 |
"""
|
|
|
|
|
439 |
semaphore = asyncio.Semaphore(max_concurrent)
|
440 |
|
441 |
results = {}
|
|
|
445 |
edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
|
446 |
|
447 |
processing_batches = await get_batches_with_strategy(
|
448 |
+
nodes, edges, graph_storage, traverse_strategy
|
|
|
|
|
|
|
449 |
)
|
450 |
|
451 |
+
async def _process_single_batch(_process_batch: tuple) -> dict:
|
|
|
|
|
452 |
async with semaphore:
|
453 |
try:
|
454 |
+
language = (
|
455 |
+
"Chinese"
|
456 |
+
if detect_main_language(_process_batch[0][0]["description"]) == "zh"
|
457 |
+
else "English"
|
458 |
+
)
|
459 |
|
460 |
_process_nodes = _process_batch[0]
|
461 |
_process_edges = _process_batch[1]
|
462 |
|
463 |
entities = [
|
464 |
+
f"{_process_node['node_id']}: {_process_node['description']}"
|
465 |
+
for _process_node in _process_nodes
|
466 |
]
|
467 |
|
468 |
relations = [
|
|
|
470 |
for _process_edge in _process_edges
|
471 |
]
|
472 |
|
473 |
+
entities_str = "\n".join(
|
474 |
+
[f"{index + 1}. {entity}" for index, entity in enumerate(entities)]
|
475 |
+
)
|
476 |
+
relations_str = "\n".join(
|
477 |
+
[
|
478 |
+
f"{index + 1}. {relation}"
|
479 |
+
for index, relation in enumerate(relations)
|
480 |
+
]
|
481 |
+
)
|
482 |
|
483 |
prompt = MULTI_HOP_GENERATION_PROMPT[language].format(
|
484 |
+
entities=entities_str, relationships=relations_str
|
|
|
485 |
)
|
486 |
|
487 |
context = await llm_client.generate_answer(prompt)
|
|
|
496 |
else:
|
497 |
return {}
|
498 |
|
499 |
+
question = question.strip('"')
|
500 |
+
answer = answer.strip('"')
|
501 |
|
502 |
logger.info("Question: %s", question)
|
503 |
logger.info("Answer: %s", answer)
|
|
|
506 |
compute_content_hash(question): {
|
507 |
"question": question,
|
508 |
"answer": answer,
|
509 |
+
"loss": get_average_loss(
|
510 |
+
_process_batch, traverse_strategy.loss_strategy
|
511 |
+
),
|
512 |
}
|
513 |
}
|
514 |
|
515 |
+
except Exception as e: # pylint: disable=broad-except
|
516 |
logger.error("Error occurred while processing batch: %s", e)
|
517 |
return {}
|
518 |
|
519 |
async for result in tqdm_async(
|
520 |
+
asyncio.as_completed(
|
521 |
+
[_process_single_batch(batch) for batch in processing_batches]
|
522 |
+
),
|
523 |
total=len(processing_batches),
|
524 |
+
desc="[4/4]Generating QAs",
|
525 |
):
|
526 |
try:
|
527 |
if progress_bar is not None:
|
528 |
+
progress_bar(
|
529 |
+
len(results) / len(processing_batches), desc="[4/4]Generating QAs"
|
530 |
+
)
|
531 |
results.update(await result)
|
532 |
if progress_bar is not None and len(results) == len(processing_batches):
|
533 |
progress_bar(1, desc="[4/4]Generating QAs")
|
534 |
+
except Exception as e: # pylint: disable=broad-except
|
535 |
logger.error("Error occurred while generating QA: %s", e)
|
536 |
return results
|