github-actions[bot] commited on
Commit
fb9c306
·
1 Parent(s): e453a65

Auto-sync from demo at Thu Aug 28 09:22:58 UTC 2025

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .env.example +0 -6
  2. .gitattributes +0 -35
  3. .gitignore +0 -179
  4. README.md +0 -14
  5. app.py +281 -214
  6. graphgen/configs/README.md +1 -0
  7. graphgen/configs/aggregated_config.yaml +21 -0
  8. graphgen/configs/atomic_config.yaml +21 -0
  9. graphgen/configs/config.yaml.example +0 -16
  10. graphgen/configs/cot_config.yaml +13 -0
  11. graphgen/configs/graphgen_config.yaml +0 -16
  12. graphgen/configs/multi_hop_config.yaml +21 -0
  13. graphgen/generate.py +64 -62
  14. graphgen/graphgen.py +232 -97
  15. graphgen/models/__init__.py +18 -14
  16. graphgen/models/community/__init__.py +0 -0
  17. graphgen/models/community/community_detector.py +95 -0
  18. graphgen/models/llm/openai_model.py +58 -33
  19. graphgen/models/search/db/__init__.py +0 -0
  20. graphgen/models/search/db/uniprot_search.py +64 -0
  21. graphgen/models/search/kg/__init__.py +0 -0
  22. graphgen/models/search/{wiki_search.py → kg/wiki_search.py} +4 -3
  23. graphgen/models/search/web/__init__.py +0 -0
  24. graphgen/models/search/web/bing_search.py +43 -0
  25. graphgen/models/search/web/google_search.py +45 -0
  26. graphgen/models/storage/base_storage.py +25 -4
  27. graphgen/models/storage/json_storage.py +39 -3
  28. graphgen/models/vis/__init__.py +0 -0
  29. graphgen/models/vis/community_visualizer.py +48 -0
  30. graphgen/operators/__init__.py +13 -7
  31. graphgen/operators/generate/__init__.py +0 -0
  32. graphgen/operators/generate/generate_cot.py +117 -0
  33. graphgen/operators/judge.py +48 -87
  34. graphgen/operators/kg/__init__.py +0 -0
  35. graphgen/operators/{extract_kg.py → kg/extract_kg.py} +48 -29
  36. graphgen/operators/{merge_kg.py → kg/merge_kg.py} +38 -41
  37. graphgen/operators/{split_graph.py → kg/split_kg.py} +92 -44
  38. graphgen/operators/preprocess/__init__.py +0 -0
  39. graphgen/operators/{resolute_coreference.py → preprocess/resolute_coreference.py} +8 -8
  40. graphgen/operators/search/__init__.py +0 -0
  41. graphgen/operators/search/db/__init__.py +0 -0
  42. graphgen/operators/search/db/search_uniprot.py +0 -0
  43. graphgen/operators/search/kg/__init__.py +0 -0
  44. graphgen/operators/search/kg/search_wikipedia.py +58 -0
  45. graphgen/operators/search/search_all.py +82 -0
  46. graphgen/operators/search/web/__init__.py +0 -0
  47. graphgen/operators/search/web/search_bing.py +53 -0
  48. graphgen/operators/search/web/search_google.py +49 -0
  49. graphgen/operators/search_wikipedia.py +0 -71
  50. graphgen/operators/traverse_graph.py +199 -148
.env.example DELETED
@@ -1,6 +0,0 @@
1
- SYNTHESIZER_MODEL=
2
- SYNTHESIZER_BASE_URL=
3
- SYNTHESIZER_API_KEY=
4
- TRAINEE_MODEL=
5
- TRAINEE_BASE_URL=
6
- TRAINEE_API_KEY=
 
 
 
 
 
 
 
.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore DELETED
@@ -1,179 +0,0 @@
1
- # Byte-compiled / optimized / DLL files
2
- __pycache__/
3
- *.py[cod]
4
- *$py.class
5
-
6
- # C extensions
7
- *.so
8
-
9
- # Distribution / packaging
10
- .Python
11
- build/
12
- develop-eggs/
13
- dist/
14
- downloads/
15
- eggs/
16
- .eggs/
17
- lib/
18
- lib64/
19
- parts/
20
- sdist/
21
- var/
22
- wheels/
23
- share/python-wheels/
24
- *.egg-info/
25
- .installed.cfg
26
- *.egg
27
- MANIFEST
28
-
29
- # PyInstaller
30
- # Usually these files are written by a python script from a template
31
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
- *.manifest
33
- *.spec
34
-
35
- # Installer logs
36
- pip-log.txt
37
- pip-delete-this-directory.txt
38
-
39
- # Unit test / coverage reports
40
- htmlcov/
41
- .tox/
42
- .nox/
43
- .coverage
44
- .coverage.*
45
- .cache
46
- nosetests.xml
47
- coverage.xml
48
- *.cover
49
- *.py,cover
50
- .hypothesis/
51
- .pytest_cache/
52
- cover/
53
-
54
- # Translations
55
- *.mo
56
- *.pot
57
-
58
- # Django stuff:
59
- *.log
60
- local_settings.py
61
- db.sqlite3
62
- db.sqlite3-journal
63
-
64
- # Flask stuff:
65
- instance/
66
- .webassets-cache
67
-
68
- # Scrapy stuff:
69
- .scrapy
70
-
71
- # Sphinx documentation
72
- docs/_build/
73
-
74
- # PyBuilder
75
- .pybuilder/
76
- target/
77
-
78
- # Jupyter Notebook
79
- .ipynb_checkpoints
80
-
81
- # IPython
82
- profile_default/
83
- ipython_config.py
84
-
85
- # pyenv
86
- # For a library or package, you might want to ignore these files since the code is
87
- # intended to run in multiple environments; otherwise, check them in:
88
- # .python-version
89
-
90
- # pipenv
91
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
- # install all needed dependencies.
95
- #Pipfile.lock
96
-
97
- # UV
98
- # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
- # This is especially recommended for binary packages to ensure reproducibility, and is more
100
- # commonly ignored for libraries.
101
- #uv.lock
102
-
103
- # poetry
104
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
- # This is especially recommended for binary packages to ensure reproducibility, and is more
106
- # commonly ignored for libraries.
107
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
- #poetry.lock
109
-
110
- # pdm
111
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
- #pdm.lock
113
- # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
- # in version control.
115
- # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
- .pdm.toml
117
- .pdm-python
118
- .pdm-build/
119
-
120
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
- __pypackages__/
122
-
123
- # Celery stuff
124
- celerybeat-schedule
125
- celerybeat.pid
126
-
127
- # SageMath parsed files
128
- *.sage.py
129
-
130
- # Environments
131
- .env
132
- .venv
133
- env/
134
- venv/
135
- ENV/
136
- env.bak/
137
- venv.bak/
138
-
139
- # Spyder project settings
140
- .spyderproject
141
- .spyproject
142
-
143
- # Rope project settings
144
- .ropeproject
145
-
146
- # mkdocs documentation
147
- /site
148
-
149
- # mypy
150
- .mypy_cache/
151
- .dmypy.json
152
- dmypy.json
153
-
154
- # Pyre type checker
155
- .pyre/
156
-
157
- # pytype static type analyzer
158
- .pytype/
159
-
160
- # Cython debug symbols
161
- cython_debug/
162
-
163
- # PyCharm
164
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
- # and can be added to the global gitignore or merged into this file. For a more nuclear
167
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
- .idea/
169
-
170
- # Ruff stuff:
171
- .ruff_cache/
172
-
173
- # PyPI configuration file
174
- .pypirc
175
-
176
- cache
177
- *.pyc
178
- *.html
179
- .gradio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md DELETED
@@ -1,14 +0,0 @@
1
- ---
2
- title: GraphGen
3
- emoji: 🐠
4
- colorFrom: gray
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 5.32.1
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: A framework for synthetic data generation based on KG.
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,20 +1,19 @@
 
1
  import os
2
  import sys
3
- import json
4
  import tempfile
5
 
6
- import pandas as pd
7
  import gradio as gr
8
-
9
- from gradio_i18n import Translate, gettext as _
10
-
11
- from webui.base import GraphGenParams
12
- from webui.test_api import test_api_connection
13
- from webui.cache_utils import setup_workspace, cleanup_workspace
14
- from webui.count_tokens import count_tokens
15
 
16
  # pylint: disable=wrong-import-position
17
- root_dir = os.path.dirname(os.path.abspath(__file__))
18
  sys.path.append(root_dir)
19
 
20
  from graphgen.graphgen import GraphGen
@@ -22,7 +21,6 @@ from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy
22
  from graphgen.models.llm.limitter import RPM, TPM
23
  from graphgen.utils import set_logger
24
 
25
-
26
  css = """
27
  .center-row {
28
  display: flex;
@@ -37,9 +35,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
37
  log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
38
 
39
  set_logger(log_file, if_stream=False)
40
- graph_gen = GraphGen(
41
- working_dir=working_dir
42
- )
43
 
44
  # Set up LLM clients
45
  graph_gen.synthesizer_llm_client = OpenAIModel(
@@ -47,8 +43,8 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
47
  base_url=env.get("SYNTHESIZER_BASE_URL", ""),
48
  api_key=env.get("SYNTHESIZER_API_KEY", ""),
49
  request_limit=True,
50
- rpm= RPM(env.get("RPM", 1000)),
51
- tpm= TPM(env.get("TPM", 50000)),
52
  )
53
 
54
  graph_gen.trainee_llm_client = OpenAIModel(
@@ -56,16 +52,15 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
56
  base_url=env.get("TRAINEE_BASE_URL", ""),
57
  api_key=env.get("TRAINEE_API_KEY", ""),
58
  request_limit=True,
59
- rpm= RPM(env.get("RPM", 1000)),
60
- tpm= TPM(env.get("TPM", 50000)),
61
  )
62
 
63
- graph_gen.tokenizer_instance = Tokenizer(
64
- config.get("tokenizer", "cl100k_base"))
65
 
66
  strategy_config = config.get("traverse_strategy", {})
67
  graph_gen.traverse_strategy = TraverseStrategy(
68
- qa_form=config.get("qa_form"),
69
  expand_method=strategy_config.get("expand_method"),
70
  bidirectional=strategy_config.get("bidirectional"),
71
  max_extra_edges=strategy_config.get("max_extra_edges"),
@@ -73,11 +68,12 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
73
  max_depth=strategy_config.get("max_depth"),
74
  edge_sampling=strategy_config.get("edge_sampling"),
75
  isolated_node_strategy=strategy_config.get("isolated_node_strategy"),
76
- loss_strategy=str(strategy_config.get("loss_strategy"))
77
  )
78
 
79
  return graph_gen
80
 
 
81
  # pylint: disable=too-many-statements
82
  def run_graphgen(params, progress=gr.Progress()):
83
  def sum_tokens(client):
@@ -87,10 +83,9 @@ def run_graphgen(params, progress=gr.Progress()):
87
  "if_trainee_model": params.if_trainee_model,
88
  "input_file": params.input_file,
89
  "tokenizer": params.tokenizer,
90
- "qa_form": params.qa_form,
91
- "web_search": False,
92
  "quiz_samples": params.quiz_samples,
93
  "traverse_strategy": {
 
94
  "bidirectional": params.bidirectional,
95
  "expand_method": params.expand_method,
96
  "max_extra_edges": params.max_extra_edges,
@@ -98,7 +93,7 @@ def run_graphgen(params, progress=gr.Progress()):
98
  "max_depth": params.max_depth,
99
  "edge_sampling": params.edge_sampling,
100
  "isolated_node_strategy": params.isolated_node_strategy,
101
- "loss_strategy": params.loss_strategy
102
  },
103
  "chunk_size": params.chunk_size,
104
  }
@@ -115,11 +110,15 @@ def run_graphgen(params, progress=gr.Progress()):
115
  }
116
 
117
  # Test API connection
118
- test_api_connection(env["SYNTHESIZER_BASE_URL"],
119
- env["SYNTHESIZER_API_KEY"], env["SYNTHESIZER_MODEL"])
120
- if config['if_trainee_model']:
121
- test_api_connection(env["TRAINEE_BASE_URL"],
122
- env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"])
 
 
 
 
123
 
124
  # Initialize GraphGen
125
  graph_gen = init_graph_gen(config, env)
@@ -129,7 +128,7 @@ def run_graphgen(params, progress=gr.Progress()):
129
 
130
  try:
131
  # Load input data
132
- file = config['input_file']
133
  if isinstance(file, list):
134
  file = file[0]
135
 
@@ -137,24 +136,22 @@ def run_graphgen(params, progress=gr.Progress()):
137
 
138
  if file.endswith(".jsonl"):
139
  data_type = "raw"
140
- with open(file, "r", encoding='utf-8') as f:
141
  data.extend(json.loads(line) for line in f)
142
  elif file.endswith(".json"):
143
  data_type = "chunked"
144
- with open(file, "r", encoding='utf-8') as f:
145
  data.extend(json.load(f))
146
  elif file.endswith(".txt"):
147
  # 读取文件后根据chunk_size转成raw格式的数据
148
  data_type = "raw"
149
  content = ""
150
- with open(file, "r", encoding='utf-8') as f:
151
  lines = f.readlines()
152
  for line in lines:
153
  content += line.strip() + " "
154
  size = int(config.get("chunk_size", 512))
155
- chunks = [
156
- content[i:i + size] for i in range(0, len(content), size)
157
- ]
158
  data.extend([{"content": chunk} for chunk in chunks])
159
  else:
160
  raise ValueError(f"Unsupported file type: {file}")
@@ -162,9 +159,9 @@ def run_graphgen(params, progress=gr.Progress()):
162
  # Process the data
163
  graph_gen.insert(data, data_type)
164
 
165
- if config['if_trainee_model']:
166
  # Generate quiz
167
- graph_gen.quiz(max_samples=config['quiz_samples'])
168
 
169
  # Judge statements
170
  graph_gen.judge()
@@ -174,47 +171,44 @@ def run_graphgen(params, progress=gr.Progress()):
174
  graph_gen.judge(skip=True)
175
 
176
  # Traverse graph
177
- graph_gen.traverse()
178
 
179
  # Save output
180
  output_data = graph_gen.qa_storage.data
181
  with tempfile.NamedTemporaryFile(
182
- mode="w",
183
- suffix=".jsonl",
184
- delete=False,
185
- encoding="utf-8") as tmpfile:
186
  json.dump(output_data, tmpfile, ensure_ascii=False)
187
  output_file = tmpfile.name
188
 
189
  synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
190
- trainee_tokens = sum_tokens(graph_gen.trainee_llm_client) if config['if_trainee_model'] else 0
 
 
 
 
191
  total_tokens = synthesizer_tokens + trainee_tokens
192
 
193
  data_frame = params.token_counter
194
  try:
195
  _update_data = [
196
- [
197
- data_frame.iloc[0, 0],
198
- data_frame.iloc[0, 1],
199
- str(total_tokens)
200
- ]
201
  ]
202
- new_df = pd.DataFrame(
203
- _update_data,
204
- columns=data_frame.columns
205
- )
206
  data_frame = new_df
207
 
208
  except Exception as e:
209
  raise gr.Error(f"DataFrame operation error: {str(e)}")
210
 
211
- return output_file, gr.DataFrame(label='Token Stats',
212
- headers=["Source Text Token Count", "Expected Token Usage", "Token Used"],
213
- datatype="str",
214
- interactive=False,
215
- value=data_frame,
216
- visible=True,
217
- wrap=True)
 
 
218
 
219
  except Exception as e: # pylint: disable=broad-except
220
  raise gr.Error(f"Error occurred: {str(e)}")
@@ -223,16 +217,18 @@ def run_graphgen(params, progress=gr.Progress()):
223
  # Clean up workspace
224
  cleanup_workspace(graph_gen.working_dir)
225
 
226
- with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
227
- css=css) as demo):
228
  # Header
229
- gr.Image(value="https://github.com/open-sciencelab/GraphGen/blob/main/resources/images/logo.png?raw=true",
230
- label="GraphGen Banner",
231
- elem_id="banner",
232
- interactive=False,
233
- container=False,
234
- show_download_button=False,
235
- show_fullscreen_button=False)
 
 
236
  lang_btn = gr.Radio(
237
  choices=[
238
  ("English", "en"),
@@ -245,7 +241,8 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
245
  elem_classes=["center-row"],
246
  )
247
 
248
- gr.HTML("""
 
249
  <div style="display: flex; gap: 8px; margin-left: auto; align-items: center; justify-content: center;">
250
  <a href="https://github.com/open-sciencelab/GraphGen/releases">
251
  <img src="https://img.shields.io/badge/Version-v0.1.0-blue" alt="Version">
@@ -260,80 +257,98 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
260
  <img src="https://img.shields.io/badge/arXiv-pdf-yellow" alt="arXiv">
261
  </a>
262
  </div>
263
- """)
 
264
  with Translate(
265
- os.path.join(root_dir, 'webui', 'translation.json'),
266
- lang_btn,
267
- placeholder_langs=["en", "zh"],
268
- persistant=
269
- False, # True to save the language setting in the browser. Requires gradio >= 5.6.0
270
  ):
271
  lang_btn.render()
272
 
273
  gr.Markdown(
274
- value = "# " + _("Title") + "\n\n" + \
275
- "### [GraphGen](https://github.com/open-sciencelab/GraphGen) " + _("Intro")
 
 
 
276
  )
277
 
278
- if_trainee_model = gr.Checkbox(label=_("Use Trainee Model"),
279
- value=False,
280
- interactive=True)
281
 
282
  with gr.Accordion(label=_("Model Config"), open=False):
283
- synthesizer_url = gr.Textbox(label="Synthesizer URL",
284
- value="https://api.siliconflow.cn/v1",
285
- info=_("Synthesizer URL Info"),
286
- interactive=True)
287
- synthesizer_model = gr.Textbox(label="Synthesizer Model",
288
- value="Qwen/Qwen2.5-7B-Instruct",
289
- info=_("Synthesizer Model Info"),
290
- interactive=True)
291
- trainee_url = gr.Textbox(label="Trainee URL",
292
- value="https://api.siliconflow.cn/v1",
293
- info=_("Trainee URL Info"),
294
- interactive=True,
295
- visible=if_trainee_model.value is True)
 
 
 
 
 
 
296
  trainee_model = gr.Textbox(
297
  label="Trainee Model",
298
  value="Qwen/Qwen2.5-7B-Instruct",
299
  info=_("Trainee Model Info"),
300
  interactive=True,
301
- visible=if_trainee_model.value is True)
 
302
  trainee_api_key = gr.Textbox(
303
- label=_("SiliconCloud Token for Trainee Model"),
304
- type="password",
305
- value="",
306
- info="https://cloud.siliconflow.cn/account/ak",
307
- visible=if_trainee_model.value is True)
308
-
309
 
310
  with gr.Accordion(label=_("Generation Config"), open=False):
311
- chunk_size = gr.Slider(label="Chunk Size",
312
- minimum=256,
313
- maximum=4096,
314
- value=512,
315
- step=256,
316
- interactive=True)
317
- tokenizer = gr.Textbox(label="Tokenizer",
318
- value="cl100k_base",
319
- interactive=True)
320
- qa_form = gr.Radio(choices=["atomic", "multi_hop", "aggregated"],
321
- label="QA Form",
322
- value="aggregated",
323
- interactive=True)
324
- quiz_samples = gr.Number(label="Quiz Samples",
325
- value=2,
326
- minimum=1,
327
- interactive=True,
328
- visible=if_trainee_model.value is True)
329
- bidirectional = gr.Checkbox(label="Bidirectional",
330
- value=True,
331
- interactive=True)
332
-
333
- expand_method = gr.Radio(choices=["max_width", "max_tokens"],
334
- label="Expand Method",
335
- value="max_tokens",
336
- interactive=True)
 
 
 
 
 
 
 
 
337
  max_extra_edges = gr.Slider(
338
  minimum=1,
339
  maximum=10,
@@ -341,44 +356,54 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
341
  label="Max Extra Edges",
342
  step=1,
343
  interactive=True,
344
- visible=expand_method.value == "max_width")
345
- max_tokens = gr.Slider(minimum=64,
346
- maximum=1024,
347
- value=256,
348
- label="Max Tokens",
349
- step=64,
350
- interactive=True,
351
- visible=(expand_method.value
352
- != "max_width"))
353
-
354
- max_depth = gr.Slider(minimum=1,
355
- maximum=5,
356
- value=2,
357
- label="Max Depth",
358
- step=1,
359
- interactive=True)
 
 
 
 
360
  edge_sampling = gr.Radio(
361
  choices=["max_loss", "min_loss", "random"],
362
  label="Edge Sampling",
363
  value="max_loss",
364
  interactive=True,
365
- visible=if_trainee_model.value is True)
366
- isolated_node_strategy = gr.Radio(choices=["add", "ignore"],
367
- label="Isolated Node Strategy",
368
- value="ignore",
369
- interactive=True)
370
- loss_strategy = gr.Radio(choices=["only_edge", "both"],
371
- label="Loss Strategy",
372
- value="only_edge",
373
- interactive=True)
 
 
 
 
 
374
 
375
  with gr.Row(equal_height=True):
376
  with gr.Column(scale=3):
377
  api_key = gr.Textbox(
378
- label=_("SiliconCloud Token"),
379
  type="password",
380
  value="",
381
- info="https://cloud.siliconflow.cn/account/ak")
 
382
  with gr.Column(scale=1):
383
  test_connection_btn = gr.Button(_("Test Connection"))
384
 
@@ -392,7 +417,8 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
392
  value=1000,
393
  step=100,
394
  interactive=True,
395
- visible=True)
 
396
  with gr.Column():
397
  tpm = gr.Slider(
398
  label="TPM",
@@ -401,8 +427,8 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
401
  value=50000,
402
  step=1000,
403
  interactive=True,
404
- visible=True)
405
-
406
 
407
  with gr.Blocks():
408
  with gr.Row(equal_height=True):
@@ -413,15 +439,17 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
413
  file_types=[".txt", ".json", ".jsonl"],
414
  interactive=True,
415
  )
416
- examples_dir = os.path.join(root_dir, 'webui', 'examples')
417
- gr.Examples(examples=[
418
- [os.path.join(examples_dir, "txt_demo.txt")],
419
- [os.path.join(examples_dir, "raw_demo.jsonl")],
420
- [os.path.join(examples_dir, "chunked_demo.json")],
421
- ],
422
- inputs=upload_file,
423
- label=_("Example Files"),
424
- examples_per_page=3)
 
 
425
  with gr.Column(scale=1):
426
  output = gr.File(
427
  label="Output(See Github FAQ)",
@@ -430,12 +458,18 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
430
  )
431
 
432
  with gr.Blocks():
433
- token_counter = gr.DataFrame(label='Token Stats',
434
- headers=["Source Text Token Count", "Estimated Token Usage", "Token Used"],
435
- datatype="str",
436
- interactive=False,
437
- visible=False,
438
- wrap=True)
 
 
 
 
 
 
439
 
440
  submit_btn = gr.Button(_("Run GraphGen"))
441
 
@@ -443,23 +477,36 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
443
  test_connection_btn.click(
444
  test_api_connection,
445
  inputs=[synthesizer_url, api_key, synthesizer_model],
446
- outputs=[])
 
447
 
448
  if if_trainee_model.value:
449
- test_connection_btn.click(test_api_connection,
450
- inputs=[trainee_url, api_key, trainee_model],
451
- outputs=[])
 
 
452
 
453
- expand_method.change(lambda method:
454
- (gr.update(visible=method == "max_width"),
455
- gr.update(visible=method != "max_width")),
456
- inputs=expand_method,
457
- outputs=[max_extra_edges, max_tokens])
 
 
 
458
 
459
  if_trainee_model.change(
460
  lambda use_trainee: [gr.update(visible=use_trainee)] * 5,
461
  inputs=if_trainee_model,
462
- outputs=[trainee_url, trainee_model, quiz_samples, edge_sampling, trainee_api_key])
 
 
 
 
 
 
 
463
 
464
  upload_file.change(
465
  lambda x: (gr.update(visible=True)),
@@ -479,41 +526,61 @@ with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
479
  )
480
 
481
  submit_btn.click(
482
- lambda *args: run_graphgen(GraphGenParams(
483
- if_trainee_model=args[0],
484
- input_file=args[1],
485
- tokenizer=args[2],
486
- qa_form=args[3],
487
- bidirectional=args[4],
488
- expand_method=args[5],
489
- max_extra_edges=args[6],
490
- max_tokens=args[7],
491
- max_depth=args[8],
492
- edge_sampling=args[9],
493
- isolated_node_strategy=args[10],
494
- loss_strategy=args[11],
495
- synthesizer_url=args[12],
496
- synthesizer_model=args[13],
497
- trainee_model=args[14],
498
- api_key=args[15],
499
- chunk_size=args[16],
500
- rpm=args[17],
501
- tpm=args[18],
502
- quiz_samples=args[19],
503
- trainee_url=args[20],
504
- trainee_api_key=args[21],
505
- token_counter=args[22],
506
- )),
 
 
507
  inputs=[
508
- if_trainee_model, upload_file, tokenizer, qa_form,
509
- bidirectional, expand_method, max_extra_edges, max_tokens,
510
- max_depth, edge_sampling, isolated_node_strategy,
511
- loss_strategy, synthesizer_url, synthesizer_model, trainee_model,
512
- api_key, chunk_size, rpm, tpm, quiz_samples, trainee_url, trainee_api_key, token_counter
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
  ],
514
  outputs=[output, token_counter],
515
  )
516
 
517
  if __name__ == "__main__":
518
  demo.queue(api_open=False, default_concurrency_limit=2)
519
- demo.launch(server_name='0.0.0.0')
 
1
+ import json
2
  import os
3
  import sys
 
4
  import tempfile
5
 
 
6
  import gradio as gr
7
+ import pandas as pd
8
+ from base import GraphGenParams
9
+ from cache_utils import cleanup_workspace, setup_workspace
10
+ from count_tokens import count_tokens
11
+ from gradio_i18n import Translate
12
+ from gradio_i18n import gettext as _
13
+ from test_api import test_api_connection
14
 
15
  # pylint: disable=wrong-import-position
16
+ root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
17
  sys.path.append(root_dir)
18
 
19
  from graphgen.graphgen import GraphGen
 
21
  from graphgen.models.llm.limitter import RPM, TPM
22
  from graphgen.utils import set_logger
23
 
 
24
  css = """
25
  .center-row {
26
  display: flex;
 
35
  log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
36
 
37
  set_logger(log_file, if_stream=False)
38
+ graph_gen = GraphGen(working_dir=working_dir)
 
 
39
 
40
  # Set up LLM clients
41
  graph_gen.synthesizer_llm_client = OpenAIModel(
 
43
  base_url=env.get("SYNTHESIZER_BASE_URL", ""),
44
  api_key=env.get("SYNTHESIZER_API_KEY", ""),
45
  request_limit=True,
46
+ rpm=RPM(env.get("RPM", 1000)),
47
+ tpm=TPM(env.get("TPM", 50000)),
48
  )
49
 
50
  graph_gen.trainee_llm_client = OpenAIModel(
 
52
  base_url=env.get("TRAINEE_BASE_URL", ""),
53
  api_key=env.get("TRAINEE_API_KEY", ""),
54
  request_limit=True,
55
+ rpm=RPM(env.get("RPM", 1000)),
56
+ tpm=TPM(env.get("TPM", 50000)),
57
  )
58
 
59
+ graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
 
60
 
61
  strategy_config = config.get("traverse_strategy", {})
62
  graph_gen.traverse_strategy = TraverseStrategy(
63
+ qa_form=strategy_config.get("qa_form"),
64
  expand_method=strategy_config.get("expand_method"),
65
  bidirectional=strategy_config.get("bidirectional"),
66
  max_extra_edges=strategy_config.get("max_extra_edges"),
 
68
  max_depth=strategy_config.get("max_depth"),
69
  edge_sampling=strategy_config.get("edge_sampling"),
70
  isolated_node_strategy=strategy_config.get("isolated_node_strategy"),
71
+ loss_strategy=str(strategy_config.get("loss_strategy")),
72
  )
73
 
74
  return graph_gen
75
 
76
+
77
  # pylint: disable=too-many-statements
78
  def run_graphgen(params, progress=gr.Progress()):
79
  def sum_tokens(client):
 
83
  "if_trainee_model": params.if_trainee_model,
84
  "input_file": params.input_file,
85
  "tokenizer": params.tokenizer,
 
 
86
  "quiz_samples": params.quiz_samples,
87
  "traverse_strategy": {
88
+ "qa_form": params.qa_form,
89
  "bidirectional": params.bidirectional,
90
  "expand_method": params.expand_method,
91
  "max_extra_edges": params.max_extra_edges,
 
93
  "max_depth": params.max_depth,
94
  "edge_sampling": params.edge_sampling,
95
  "isolated_node_strategy": params.isolated_node_strategy,
96
+ "loss_strategy": params.loss_strategy,
97
  },
98
  "chunk_size": params.chunk_size,
99
  }
 
110
  }
111
 
112
  # Test API connection
113
+ test_api_connection(
114
+ env["SYNTHESIZER_BASE_URL"],
115
+ env["SYNTHESIZER_API_KEY"],
116
+ env["SYNTHESIZER_MODEL"],
117
+ )
118
+ if config["if_trainee_model"]:
119
+ test_api_connection(
120
+ env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]
121
+ )
122
 
123
  # Initialize GraphGen
124
  graph_gen = init_graph_gen(config, env)
 
128
 
129
  try:
130
  # Load input data
131
+ file = config["input_file"]
132
  if isinstance(file, list):
133
  file = file[0]
134
 
 
136
 
137
  if file.endswith(".jsonl"):
138
  data_type = "raw"
139
+ with open(file, "r", encoding="utf-8") as f:
140
  data.extend(json.loads(line) for line in f)
141
  elif file.endswith(".json"):
142
  data_type = "chunked"
143
+ with open(file, "r", encoding="utf-8") as f:
144
  data.extend(json.load(f))
145
  elif file.endswith(".txt"):
146
  # 读取文件后根据chunk_size转成raw格式的数据
147
  data_type = "raw"
148
  content = ""
149
+ with open(file, "r", encoding="utf-8") as f:
150
  lines = f.readlines()
151
  for line in lines:
152
  content += line.strip() + " "
153
  size = int(config.get("chunk_size", 512))
154
+ chunks = [content[i : i + size] for i in range(0, len(content), size)]
 
 
155
  data.extend([{"content": chunk} for chunk in chunks])
156
  else:
157
  raise ValueError(f"Unsupported file type: {file}")
 
159
  # Process the data
160
  graph_gen.insert(data, data_type)
161
 
162
+ if config["if_trainee_model"]:
163
  # Generate quiz
164
+ graph_gen.quiz(max_samples=config["quiz_samples"])
165
 
166
  # Judge statements
167
  graph_gen.judge()
 
171
  graph_gen.judge(skip=True)
172
 
173
  # Traverse graph
174
+ graph_gen.traverse(traverse_strategy=graph_gen.traverse_strategy)
175
 
176
  # Save output
177
  output_data = graph_gen.qa_storage.data
178
  with tempfile.NamedTemporaryFile(
179
+ mode="w", suffix=".jsonl", delete=False, encoding="utf-8"
180
+ ) as tmpfile:
 
 
181
  json.dump(output_data, tmpfile, ensure_ascii=False)
182
  output_file = tmpfile.name
183
 
184
  synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
185
+ trainee_tokens = (
186
+ sum_tokens(graph_gen.trainee_llm_client)
187
+ if config["if_trainee_model"]
188
+ else 0
189
+ )
190
  total_tokens = synthesizer_tokens + trainee_tokens
191
 
192
  data_frame = params.token_counter
193
  try:
194
  _update_data = [
195
+ [data_frame.iloc[0, 0], data_frame.iloc[0, 1], str(total_tokens)]
 
 
 
 
196
  ]
197
+ new_df = pd.DataFrame(_update_data, columns=data_frame.columns)
 
 
 
198
  data_frame = new_df
199
 
200
  except Exception as e:
201
  raise gr.Error(f"DataFrame operation error: {str(e)}")
202
 
203
+ return output_file, gr.DataFrame(
204
+ label="Token Stats",
205
+ headers=["Source Text Token Count", "Expected Token Usage", "Token Used"],
206
+ datatype="str",
207
+ interactive=False,
208
+ value=data_frame,
209
+ visible=True,
210
+ wrap=True,
211
+ )
212
 
213
  except Exception as e: # pylint: disable=broad-except
214
  raise gr.Error(f"Error occurred: {str(e)}")
 
217
  # Clean up workspace
218
  cleanup_workspace(graph_gen.working_dir)
219
 
220
+
221
+ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
222
  # Header
223
+ gr.Image(
224
+ value=os.path.join(root_dir, "resources", "images", "logo.png"),
225
+ label="GraphGen Banner",
226
+ elem_id="banner",
227
+ interactive=False,
228
+ container=False,
229
+ show_download_button=False,
230
+ show_fullscreen_button=False,
231
+ )
232
  lang_btn = gr.Radio(
233
  choices=[
234
  ("English", "en"),
 
241
  elem_classes=["center-row"],
242
  )
243
 
244
+ gr.HTML(
245
+ """
246
  <div style="display: flex; gap: 8px; margin-left: auto; align-items: center; justify-content: center;">
247
  <a href="https://github.com/open-sciencelab/GraphGen/releases">
248
  <img src="https://img.shields.io/badge/Version-v0.1.0-blue" alt="Version">
 
257
  <img src="https://img.shields.io/badge/arXiv-pdf-yellow" alt="arXiv">
258
  </a>
259
  </div>
260
+ """
261
+ )
262
  with Translate(
263
+ os.path.join(root_dir, "webui", "translation.json"),
264
+ lang_btn,
265
+ placeholder_langs=["en", "zh"],
266
+ persistant=False, # True to save the language setting in the browser. Requires gradio >= 5.6.0
 
267
  ):
268
  lang_btn.render()
269
 
270
  gr.Markdown(
271
+ value="# "
272
+ + _("Title")
273
+ + "\n\n"
274
+ + "### [GraphGen](https://github.com/open-sciencelab/GraphGen) "
275
+ + _("Intro")
276
  )
277
 
278
+ if_trainee_model = gr.Checkbox(
279
+ label=_("Use Trainee Model"), value=False, interactive=True
280
+ )
281
 
282
  with gr.Accordion(label=_("Model Config"), open=False):
283
+ synthesizer_url = gr.Textbox(
284
+ label="Synthesizer URL",
285
+ value="https://api.siliconflow.cn/v1",
286
+ info=_("Synthesizer URL Info"),
287
+ interactive=True,
288
+ )
289
+ synthesizer_model = gr.Textbox(
290
+ label="Synthesizer Model",
291
+ value="Qwen/Qwen2.5-7B-Instruct",
292
+ info=_("Synthesizer Model Info"),
293
+ interactive=True,
294
+ )
295
+ trainee_url = gr.Textbox(
296
+ label="Trainee URL",
297
+ value="https://api.siliconflow.cn/v1",
298
+ info=_("Trainee URL Info"),
299
+ interactive=True,
300
+ visible=if_trainee_model.value is True,
301
+ )
302
  trainee_model = gr.Textbox(
303
  label="Trainee Model",
304
  value="Qwen/Qwen2.5-7B-Instruct",
305
  info=_("Trainee Model Info"),
306
  interactive=True,
307
+ visible=if_trainee_model.value is True,
308
+ )
309
  trainee_api_key = gr.Textbox(
310
+ label=_("SiliconFlow Token for Trainee Model"),
311
+ type="password",
312
+ value="",
313
+ info="https://cloud.siliconflow.cn/account/ak",
314
+ visible=if_trainee_model.value is True,
315
+ )
316
 
317
  with gr.Accordion(label=_("Generation Config"), open=False):
318
+ chunk_size = gr.Slider(
319
+ label="Chunk Size",
320
+ minimum=256,
321
+ maximum=4096,
322
+ value=512,
323
+ step=256,
324
+ interactive=True,
325
+ )
326
+ tokenizer = gr.Textbox(
327
+ label="Tokenizer", value="cl100k_base", interactive=True
328
+ )
329
+ qa_form = gr.Radio(
330
+ choices=["atomic", "multi_hop", "aggregated"],
331
+ label="QA Form",
332
+ value="aggregated",
333
+ interactive=True,
334
+ )
335
+ quiz_samples = gr.Number(
336
+ label="Quiz Samples",
337
+ value=2,
338
+ minimum=1,
339
+ interactive=True,
340
+ visible=if_trainee_model.value is True,
341
+ )
342
+ bidirectional = gr.Checkbox(
343
+ label="Bidirectional", value=True, interactive=True
344
+ )
345
+
346
+ expand_method = gr.Radio(
347
+ choices=["max_width", "max_tokens"],
348
+ label="Expand Method",
349
+ value="max_tokens",
350
+ interactive=True,
351
+ )
352
  max_extra_edges = gr.Slider(
353
  minimum=1,
354
  maximum=10,
 
356
  label="Max Extra Edges",
357
  step=1,
358
  interactive=True,
359
+ visible=expand_method.value == "max_width",
360
+ )
361
+ max_tokens = gr.Slider(
362
+ minimum=64,
363
+ maximum=1024,
364
+ value=256,
365
+ label="Max Tokens",
366
+ step=64,
367
+ interactive=True,
368
+ visible=(expand_method.value != "max_width"),
369
+ )
370
+
371
+ max_depth = gr.Slider(
372
+ minimum=1,
373
+ maximum=5,
374
+ value=2,
375
+ label="Max Depth",
376
+ step=1,
377
+ interactive=True,
378
+ )
379
  edge_sampling = gr.Radio(
380
  choices=["max_loss", "min_loss", "random"],
381
  label="Edge Sampling",
382
  value="max_loss",
383
  interactive=True,
384
+ visible=if_trainee_model.value is True,
385
+ )
386
+ isolated_node_strategy = gr.Radio(
387
+ choices=["add", "ignore"],
388
+ label="Isolated Node Strategy",
389
+ value="ignore",
390
+ interactive=True,
391
+ )
392
+ loss_strategy = gr.Radio(
393
+ choices=["only_edge", "both"],
394
+ label="Loss Strategy",
395
+ value="only_edge",
396
+ interactive=True,
397
+ )
398
 
399
  with gr.Row(equal_height=True):
400
  with gr.Column(scale=3):
401
  api_key = gr.Textbox(
402
+ label=_("SiliconFlow Token"),
403
  type="password",
404
  value="",
405
+ info="https://cloud.siliconflow.cn/account/ak",
406
+ )
407
  with gr.Column(scale=1):
408
  test_connection_btn = gr.Button(_("Test Connection"))
409
 
 
417
  value=1000,
418
  step=100,
419
  interactive=True,
420
+ visible=True,
421
+ )
422
  with gr.Column():
423
  tpm = gr.Slider(
424
  label="TPM",
 
427
  value=50000,
428
  step=1000,
429
  interactive=True,
430
+ visible=True,
431
+ )
432
 
433
  with gr.Blocks():
434
  with gr.Row(equal_height=True):
 
439
  file_types=[".txt", ".json", ".jsonl"],
440
  interactive=True,
441
  )
442
+ examples_dir = os.path.join(root_dir, "webui", "examples")
443
+ gr.Examples(
444
+ examples=[
445
+ [os.path.join(examples_dir, "txt_demo.txt")],
446
+ [os.path.join(examples_dir, "raw_demo.jsonl")],
447
+ [os.path.join(examples_dir, "chunked_demo.json")],
448
+ ],
449
+ inputs=upload_file,
450
+ label=_("Example Files"),
451
+ examples_per_page=3,
452
+ )
453
  with gr.Column(scale=1):
454
  output = gr.File(
455
  label="Output(See Github FAQ)",
 
458
  )
459
 
460
  with gr.Blocks():
461
+ token_counter = gr.DataFrame(
462
+ label="Token Stats",
463
+ headers=[
464
+ "Source Text Token Count",
465
+ "Estimated Token Usage",
466
+ "Token Used",
467
+ ],
468
+ datatype="str",
469
+ interactive=False,
470
+ visible=False,
471
+ wrap=True,
472
+ )
473
 
474
  submit_btn = gr.Button(_("Run GraphGen"))
475
 
 
477
  test_connection_btn.click(
478
  test_api_connection,
479
  inputs=[synthesizer_url, api_key, synthesizer_model],
480
+ outputs=[],
481
+ )
482
 
483
  if if_trainee_model.value:
484
+ test_connection_btn.click(
485
+ test_api_connection,
486
+ inputs=[trainee_url, api_key, trainee_model],
487
+ outputs=[],
488
+ )
489
 
490
+ expand_method.change(
491
+ lambda method: (
492
+ gr.update(visible=method == "max_width"),
493
+ gr.update(visible=method != "max_width"),
494
+ ),
495
+ inputs=expand_method,
496
+ outputs=[max_extra_edges, max_tokens],
497
+ )
498
 
499
  if_trainee_model.change(
500
  lambda use_trainee: [gr.update(visible=use_trainee)] * 5,
501
  inputs=if_trainee_model,
502
+ outputs=[
503
+ trainee_url,
504
+ trainee_model,
505
+ quiz_samples,
506
+ edge_sampling,
507
+ trainee_api_key,
508
+ ],
509
+ )
510
 
511
  upload_file.change(
512
  lambda x: (gr.update(visible=True)),
 
526
  )
527
 
528
  submit_btn.click(
529
+ lambda *args: run_graphgen(
530
+ GraphGenParams(
531
+ if_trainee_model=args[0],
532
+ input_file=args[1],
533
+ tokenizer=args[2],
534
+ qa_form=args[3],
535
+ bidirectional=args[4],
536
+ expand_method=args[5],
537
+ max_extra_edges=args[6],
538
+ max_tokens=args[7],
539
+ max_depth=args[8],
540
+ edge_sampling=args[9],
541
+ isolated_node_strategy=args[10],
542
+ loss_strategy=args[11],
543
+ synthesizer_url=args[12],
544
+ synthesizer_model=args[13],
545
+ trainee_model=args[14],
546
+ api_key=args[15],
547
+ chunk_size=args[16],
548
+ rpm=args[17],
549
+ tpm=args[18],
550
+ quiz_samples=args[19],
551
+ trainee_url=args[20],
552
+ trainee_api_key=args[21],
553
+ token_counter=args[22],
554
+ )
555
+ ),
556
  inputs=[
557
+ if_trainee_model,
558
+ upload_file,
559
+ tokenizer,
560
+ qa_form,
561
+ bidirectional,
562
+ expand_method,
563
+ max_extra_edges,
564
+ max_tokens,
565
+ max_depth,
566
+ edge_sampling,
567
+ isolated_node_strategy,
568
+ loss_strategy,
569
+ synthesizer_url,
570
+ synthesizer_model,
571
+ trainee_model,
572
+ api_key,
573
+ chunk_size,
574
+ rpm,
575
+ tpm,
576
+ quiz_samples,
577
+ trainee_url,
578
+ trainee_api_key,
579
+ token_counter,
580
  ],
581
  outputs=[output, token_counter],
582
  )
583
 
584
  if __name__ == "__main__":
585
  demo.queue(api_open=False, default_concurrency_limit=2)
586
+ demo.launch(server_name="0.0.0.0")
graphgen/configs/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # Configs for GraphGen
graphgen/configs/aggregated_config.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_data_type: raw # raw, chunked
2
+ input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
3
+ output_data_type: aggregated # atomic, aggregated, multi_hop, cot
4
+ output_data_format: ChatML # Alpaca, Sharegpt, ChatML
5
+ tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
6
+ search: # web search configuration
7
+ enabled: false # whether to enable web search
8
+ search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9
+ quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
10
+ enabled: true
11
+ quiz_samples: 2 # number of quiz samples to generate
12
+ re_judge: false # whether to re-judge the existing quiz samples
13
+ traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
14
+ bidirectional: true # whether to traverse the graph in both directions
15
+ edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
16
+ expand_method: max_width # expand method, support: max_width, max_tokens
17
+ isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
18
+ max_depth: 5 # maximum depth for graph traversal
19
+ max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
20
+ max_tokens: 256 # restricts input length (if expand_method="max_tokens")
21
+ loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
graphgen/configs/atomic_config.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_data_type: raw # raw, chunked
2
+ input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
3
+ output_data_type: atomic # atomic, aggregated, multi_hop, cot
4
+ output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
5
+ tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
6
+ search: # web search configuration
7
+ enabled: false # whether to enable web search
8
+ search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9
+ quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
10
+ enabled: true
11
+ quiz_samples: 2 # number of quiz samples to generate
12
+ re_judge: false # whether to re-judge the existing quiz samples
13
+ traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
14
+ bidirectional: true # whether to traverse the graph in both directions
15
+ edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
16
+ expand_method: max_width # expand method, support: max_width, max_tokens
17
+ isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
18
+ max_depth: 3 # maximum depth for graph traversal
19
+ max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
20
+ max_tokens: 256 # restricts input length (if expand_method="max_tokens")
21
+ loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
graphgen/configs/config.yaml.example DELETED
@@ -1,16 +0,0 @@
1
- data_type: raw
2
- input_file: resources/examples/raw_demo.jsonl
3
- tokenizer: cl100k_base
4
- quiz_samples: 2
5
- traverse_strategy:
6
- qa_form: atomic
7
- bidirectional: true
8
- edge_sampling: max_loss
9
- expand_method: max_tokens
10
- isolated_node_strategy: add
11
- max_depth: 2
12
- max_extra_edges: 5
13
- max_tokens: 256
14
- loss_strategy: only_edge
15
- web_search: false
16
- re_judge: false
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/configs/cot_config.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_data_type: raw # raw, chunked
2
+ input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
3
+ output_data_type: cot # atomic, aggregated, multi_hop, cot
4
+ output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML
5
+ tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
6
+ search: # web search configuration
7
+ enabled: false # whether to enable web search
8
+ search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9
+ method_params:
10
+ method: leiden
11
+ max_size: 20 # Maximum size of communities
12
+ use_lcc: false
13
+ random_seed: 42
graphgen/configs/graphgen_config.yaml DELETED
@@ -1,16 +0,0 @@
1
- data_type: raw
2
- input_file: resources/examples/raw_demo.jsonl
3
- tokenizer: cl100k_base
4
- quiz_samples: 2
5
- traverse_strategy:
6
- qa_form: aggregated
7
- bidirectional: true
8
- edge_sampling: max_loss
9
- expand_method: max_width
10
- isolated_node_strategy: ignore
11
- max_depth: 1
12
- max_extra_edges: 2
13
- max_tokens: 256
14
- loss_strategy: only_edge
15
- web_search: false
16
- re_judge: false
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/configs/multi_hop_config.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_data_type: raw # raw, chunked
2
+ input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
3
+ output_data_type: multi_hop # atomic, aggregated, multi_hop, cot
4
+ output_data_format: ChatML # Alpaca, Sharegpt, ChatML
5
+ tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
6
+ search: # web search configuration
7
+ enabled: false # whether to enable web search
8
+ search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9
+ quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
10
+ enabled: true
11
+ quiz_samples: 2 # number of quiz samples to generate
12
+ re_judge: false # whether to re-judge the existing quiz samples
13
+ traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
14
+ bidirectional: true # whether to traverse the graph in both directions
15
+ edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
16
+ expand_method: max_width # expand method, support: max_width, max_tokens
17
+ isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
18
+ max_depth: 1 # maximum depth for graph traversal
19
+ max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
20
+ max_tokens: 256 # restricts input length (if expand_method="max_tokens")
21
+ loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
graphgen/generate.py CHANGED
@@ -1,101 +1,103 @@
 
1
  import os
2
- import json
3
  import time
4
- import argparse
5
  from importlib.resources import files
 
6
  import yaml
7
  from dotenv import load_dotenv
8
 
9
  from .graphgen import GraphGen
10
- from .models import OpenAIModel, Tokenizer, TraverseStrategy
11
- from .utils import set_logger
12
 
13
  sys_path = os.path.abspath(os.path.dirname(__file__))
14
 
15
  load_dotenv()
16
 
 
17
  def set_working_dir(folder):
18
  os.makedirs(folder, exist_ok=True)
19
  os.makedirs(os.path.join(folder, "data", "graphgen"), exist_ok=True)
20
  os.makedirs(os.path.join(folder, "logs"), exist_ok=True)
21
 
 
22
  def save_config(config_path, global_config):
23
  if not os.path.exists(os.path.dirname(config_path)):
24
  os.makedirs(os.path.dirname(config_path))
25
- with open(config_path, "w", encoding='utf-8') as config_file:
26
- yaml.dump(global_config, config_file, default_flow_style=False, allow_unicode=True)
 
 
 
27
 
28
  def main():
29
  parser = argparse.ArgumentParser()
30
- parser.add_argument('--config_file',
31
- help='Config parameters for GraphGen.',
32
- # default=os.path.join(sys_path, "configs", "graphgen_config.yaml"),
33
- default=files('graphgen').joinpath("configs", "graphgen_config.yaml"),
34
- type=str)
35
- parser.add_argument('--output_dir',
36
- help='Output directory for GraphGen.',
37
- default=sys_path,
38
- required=True,
39
- type=str)
 
 
 
40
 
41
  args = parser.parse_args()
42
 
43
  working_dir = args.output_dir
44
  set_working_dir(working_dir)
45
- unique_id = int(time.time())
46
- set_logger(os.path.join(working_dir, "logs", f"graphgen_{unique_id}.log"), if_stream=False)
47
 
48
- with open(args.config_file, "r", encoding='utf-8') as f:
49
  config = yaml.load(f, Loader=yaml.FullLoader)
50
 
51
- input_file = config['input_file']
52
-
53
- if config['data_type'] == 'raw':
54
- with open(input_file, "r", encoding='utf-8') as f:
55
- data = [json.loads(line) for line in f]
56
- elif config['data_type'] == 'chunked':
57
- with open(input_file, "r", encoding='utf-8') as f:
58
- data = json.load(f)
59
- else:
60
- raise ValueError(f"Invalid data type: {config['data_type']}")
61
-
62
- synthesizer_llm_client = OpenAIModel(
63
- model_name=os.getenv("SYNTHESIZER_MODEL"),
64
- api_key=os.getenv("SYNTHESIZER_API_KEY"),
65
- base_url=os.getenv("SYNTHESIZER_BASE_URL")
66
- )
67
- trainee_llm_client = OpenAIModel(
68
- model_name=os.getenv("TRAINEE_MODEL"),
69
- api_key=os.getenv("TRAINEE_API_KEY"),
70
- base_url=os.getenv("TRAINEE_BASE_URL")
71
- )
72
-
73
- traverse_strategy = TraverseStrategy(
74
- **config['traverse_strategy']
75
  )
76
-
77
- graph_gen = GraphGen(
78
- working_dir=working_dir,
79
- unique_id=unique_id,
80
- synthesizer_llm_client=synthesizer_llm_client,
81
- trainee_llm_client=trainee_llm_client,
82
- if_web_search=config['web_search'],
83
- tokenizer_instance=Tokenizer(
84
- model_name=config['tokenizer']
85
  ),
86
- traverse_strategy=traverse_strategy
87
  )
88
 
89
- graph_gen.insert(data, config['data_type'])
90
-
91
- graph_gen.quiz(max_samples=config['quiz_samples'])
92
-
93
- graph_gen.judge(re_judge=config["re_judge"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- graph_gen.traverse()
 
 
96
 
97
- path = os.path.join(working_dir, "data", "graphgen", str(unique_id), f"config-{unique_id}.yaml")
98
- save_config(path, config)
99
 
100
- if __name__ == '__main__':
101
  main()
 
1
+ import argparse
2
  import os
 
3
  import time
 
4
  from importlib.resources import files
5
+
6
  import yaml
7
  from dotenv import load_dotenv
8
 
9
  from .graphgen import GraphGen
10
+ from .utils import logger, set_logger
 
11
 
12
  sys_path = os.path.abspath(os.path.dirname(__file__))
13
 
14
  load_dotenv()
15
 
16
+
17
  def set_working_dir(folder):
18
  os.makedirs(folder, exist_ok=True)
19
  os.makedirs(os.path.join(folder, "data", "graphgen"), exist_ok=True)
20
  os.makedirs(os.path.join(folder, "logs"), exist_ok=True)
21
 
22
+
23
  def save_config(config_path, global_config):
24
  if not os.path.exists(os.path.dirname(config_path)):
25
  os.makedirs(os.path.dirname(config_path))
26
+ with open(config_path, "w", encoding="utf-8") as config_file:
27
+ yaml.dump(
28
+ global_config, config_file, default_flow_style=False, allow_unicode=True
29
+ )
30
+
31
 
32
  def main():
33
  parser = argparse.ArgumentParser()
34
+ parser.add_argument(
35
+ "--config_file",
36
+ help="Config parameters for GraphGen.",
37
+ default=files("graphgen").joinpath("configs", "aggregated_config.yaml"),
38
+ type=str,
39
+ )
40
+ parser.add_argument(
41
+ "--output_dir",
42
+ help="Output directory for GraphGen.",
43
+ default=sys_path,
44
+ required=True,
45
+ type=str,
46
+ )
47
 
48
  args = parser.parse_args()
49
 
50
  working_dir = args.output_dir
51
  set_working_dir(working_dir)
 
 
52
 
53
+ with open(args.config_file, "r", encoding="utf-8") as f:
54
  config = yaml.load(f, Loader=yaml.FullLoader)
55
 
56
+ output_data_type = config["output_data_type"]
57
+ unique_id = int(time.time())
58
+ set_logger(
59
+ os.path.join(
60
+ working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log"
61
+ ),
62
+ if_stream=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  )
64
+ logger.info(
65
+ "GraphGen with unique ID %s logging to %s",
66
+ unique_id,
67
+ os.path.join(
68
+ working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log"
 
 
 
 
69
  ),
 
70
  )
71
 
72
+ graph_gen = GraphGen(working_dir=working_dir, unique_id=unique_id, config=config)
73
+
74
+ graph_gen.insert()
75
+
76
+ if config["search"]["enabled"]:
77
+ graph_gen.search()
78
+
79
+ # Use pipeline according to the output data type
80
+ if output_data_type in ["atomic", "aggregated", "multi_hop"]:
81
+ if "quiz_and_judge_strategy" in config and config[
82
+ "quiz_and_judge_strategy"
83
+ ].get("enabled", False):
84
+ graph_gen.quiz()
85
+ graph_gen.judge()
86
+ else:
87
+ logger.warning(
88
+ "Quiz and Judge strategy is disabled. Edge sampling falls back to random."
89
+ )
90
+ graph_gen.traverse_strategy.edge_sampling = "random"
91
+ graph_gen.traverse()
92
+ elif output_data_type == "cot":
93
+ graph_gen.generate_reasoning(method_params=config["method_params"])
94
+ else:
95
+ raise ValueError(f"Unsupported output data type: {output_data_type}")
96
 
97
+ output_path = os.path.join(working_dir, "data", "graphgen", str(unique_id))
98
+ save_config(os.path.join(output_path, f"config-{unique_id}.yaml"), config)
99
+ logger.info("GraphGen completed successfully. Data saved to %s", output_path)
100
 
 
 
101
 
102
+ if __name__ == "__main__":
103
  main()
graphgen/graphgen.py CHANGED
@@ -1,10 +1,8 @@
1
- # Adapt from https://github.com/HKUDS/LightRAG
2
-
3
  import asyncio
4
  import os
5
  import time
6
  from dataclasses import dataclass, field
7
- from typing import List, Union, cast
8
 
9
  import gradio as gr
10
  from tqdm.asyncio import tqdm as tqdm_async
@@ -12,85 +10,124 @@ from tqdm.asyncio import tqdm as tqdm_async
12
  from .models import (
13
  Chunk,
14
  JsonKVStorage,
 
15
  NetworkXStorage,
16
  OpenAIModel,
17
  Tokenizer,
18
  TraverseStrategy,
19
- WikiSearch,
20
  )
21
  from .models.storage.base_storage import StorageNameSpace
22
  from .operators import (
23
  extract_kg,
 
24
  judge_statement,
25
  quiz,
26
- search_wikipedia,
27
- skip_judge_statement,
28
  traverse_graph_atomically,
29
  traverse_graph_by_edge,
30
  traverse_graph_for_multi_hop,
31
  )
32
- from .utils import compute_content_hash, create_event_loop, logger
 
 
 
 
 
 
33
 
34
  sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
35
 
 
36
  @dataclass
37
  class GraphGen:
38
  unique_id: int = int(time.time())
39
  working_dir: str = os.path.join(sys_path, "cache")
40
-
41
- # text chunking
42
- chunk_size: int = 1024
43
- chunk_overlap_size: int = 100
44
 
45
  # llm
 
46
  synthesizer_llm_client: OpenAIModel = None
47
  trainee_llm_client: OpenAIModel = None
48
- tokenizer_instance: Tokenizer = None
49
 
50
- # web search
51
- if_web_search: bool = False
52
- wiki_client: WikiSearch = field(default_factory=WikiSearch)
 
 
 
 
 
 
53
 
54
- # traverse strategy
55
- traverse_strategy: TraverseStrategy = field(default_factory=TraverseStrategy)
56
 
57
  # webui
58
  progress_bar: gr.Progress = None
59
 
60
  def __post_init__(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  self.full_docs_storage: JsonKVStorage = JsonKVStorage(
62
  self.working_dir, namespace="full_docs"
63
  )
64
  self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
65
  self.working_dir, namespace="text_chunks"
66
  )
67
- self.wiki_storage: JsonKVStorage = JsonKVStorage(
68
- self.working_dir, namespace="wiki"
69
- )
70
  self.graph_storage: NetworkXStorage = NetworkXStorage(
71
  self.working_dir, namespace="graph"
72
  )
 
 
 
73
  self.rephrase_storage: JsonKVStorage = JsonKVStorage(
74
  self.working_dir, namespace="rephrase"
75
  )
76
- self.qa_storage: JsonKVStorage = JsonKVStorage(
77
- os.path.join(self.working_dir, "data", "graphgen", str(self.unique_id)), namespace=f"qa-{self.unique_id}"
 
78
  )
79
 
80
- async def async_split_chunks(self, data: Union[List[list], List[dict]], data_type: str) -> dict:
81
- # TODO: 是否进行指代消解
 
 
82
  if len(data) == 0:
83
  return {}
84
 
85
- new_docs = {}
86
  inserting_chunks = {}
87
  if data_type == "raw":
88
  assert isinstance(data, list) and isinstance(data[0], dict)
89
  # compute hash for each document
90
  new_docs = {
91
- compute_content_hash(doc['content'], prefix="doc-"): {'content': doc['content']} for doc in data
 
 
 
92
  }
93
- _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
 
 
94
  new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
95
  if len(new_docs) == 0:
96
  logger.warning("All docs are already in the storage")
@@ -100,63 +137,83 @@ class GraphGen:
100
  cur_index = 1
101
  doc_number = len(new_docs)
102
  async for doc_key, doc in tqdm_async(
103
- new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
104
- ):
105
  chunks = {
106
  compute_content_hash(dp["content"], prefix="chunk-"): {
107
  **dp,
108
- 'full_doc_id': doc_key
109
- } for dp in self.tokenizer_instance.chunk_by_token_size(doc["content"],
110
- self.chunk_overlap_size, self.chunk_size)
 
 
111
  }
112
  inserting_chunks.update(chunks)
113
 
114
  if self.progress_bar is not None:
115
- self.progress_bar(
116
- cur_index / doc_number, f"Chunking {doc_key}"
117
- )
118
  cur_index += 1
119
 
120
- _add_chunk_keys = await self.text_chunks_storage.filter_keys(list(inserting_chunks.keys()))
121
- inserting_chunks = {k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys}
 
 
 
 
122
  elif data_type == "chunked":
123
  assert isinstance(data, list) and isinstance(data[0], list)
124
  new_docs = {
125
- compute_content_hash("".join(chunk['content']), prefix="doc-"): {'content': "".join(chunk['content'])}
126
- for doc in data for chunk in doc
 
 
 
127
  }
128
- _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
 
 
129
  new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
130
  if len(new_docs) == 0:
131
  logger.warning("All docs are already in the storage")
132
  return {}
133
  logger.info("[New Docs] inserting %d docs", len(new_docs))
134
- async for doc in tqdm_async(data, desc="[1/4]Chunking documents", unit="doc"):
135
- doc_str = "".join([chunk['content'] for chunk in doc])
 
 
136
  for chunk in doc:
137
- chunk_key = compute_content_hash(chunk['content'], prefix="chunk-")
138
  inserting_chunks[chunk_key] = {
139
  **chunk,
140
- 'full_doc_id': compute_content_hash(doc_str, prefix="doc-")
141
  }
142
- _add_chunk_keys = await self.text_chunks_storage.filter_keys(list(inserting_chunks.keys()))
143
- inserting_chunks = {k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys}
 
 
 
 
 
 
144
 
145
  await self.full_docs_storage.upsert(new_docs)
146
  await self.text_chunks_storage.upsert(inserting_chunks)
147
 
148
  return inserting_chunks
149
 
150
- def insert(self, data: Union[List[list], List[dict]], data_type: str):
151
  loop = create_event_loop()
152
- loop.run_until_complete(self.async_insert(data, data_type))
153
 
154
- async def async_insert(self, data: Union[List[list], List[dict]], data_type: str):
155
  """
156
-
157
  insert chunks into the graph
158
  """
159
 
 
 
 
 
160
  inserting_chunks = await self.async_split_chunks(data, data_type)
161
 
162
  if len(inserting_chunks) == 0:
@@ -169,52 +226,96 @@ class GraphGen:
169
  llm_client=self.synthesizer_llm_client,
170
  kg_instance=self.graph_storage,
171
  tokenizer_instance=self.tokenizer_instance,
172
- chunks=[Chunk(id=k, content=v['content']) for k, v in inserting_chunks.items()],
173
- progress_bar = self.progress_bar,
 
 
174
  )
175
  if not _add_entities_and_relations:
176
  logger.warning("No entities or relations extracted")
177
  return
178
 
179
- logger.info("[Wiki Search] is %s", 'enabled' if self.if_web_search else 'disabled')
180
- if self.if_web_search:
181
- logger.info("[Wiki Search]...")
182
- _add_wiki_data = await search_wikipedia(
183
- llm_client= self.synthesizer_llm_client,
184
- wiki_search_client=self.wiki_client,
185
- knowledge_graph_instance=_add_entities_and_relations
186
- )
187
- await self.wiki_storage.upsert(_add_wiki_data)
188
-
189
  await self._insert_done()
190
 
191
  async def _insert_done(self):
192
  tasks = []
193
- for storage_instance in [self.full_docs_storage, self.text_chunks_storage,
194
- self.graph_storage, self.wiki_storage]:
 
 
 
 
195
  if storage_instance is None:
196
  continue
197
  tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
198
  await asyncio.gather(*tasks)
199
 
200
- def quiz(self, max_samples=1):
201
  loop = create_event_loop()
202
- loop.run_until_complete(self.async_quiz(max_samples))
203
 
204
- async def async_quiz(self, max_samples=1):
205
- await quiz(self.synthesizer_llm_client, self.graph_storage, self.rephrase_storage, max_samples)
206
- await self.rephrase_storage.index_done_callback()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
- def judge(self, re_judge=False, skip=False):
209
  loop = create_event_loop()
210
- loop.run_until_complete(self.async_judge(re_judge, skip))
 
 
 
 
 
 
 
 
 
 
211
 
212
- async def async_judge(self, re_judge=False, skip=False):
213
- if skip:
214
- _update_relations = await skip_judge_statement(self.graph_storage)
215
- else:
216
- _update_relations = await judge_statement(self.trainee_llm_client, self.graph_storage,
217
- self.rephrase_storage, re_judge)
 
 
 
 
 
 
218
  await _update_relations.index_done_callback()
219
 
220
  def traverse(self):
@@ -222,26 +323,60 @@ class GraphGen:
222
  loop.run_until_complete(self.async_traverse())
223
 
224
  async def async_traverse(self):
225
- if self.traverse_strategy.qa_form == "atomic":
226
- results = await traverse_graph_atomically(self.synthesizer_llm_client,
227
- self.tokenizer_instance,
228
- self.graph_storage,
229
- self.traverse_strategy,
230
- self.text_chunks_storage,
231
- self.progress_bar)
232
- elif self.traverse_strategy.qa_form == "multi_hop":
233
- results = await traverse_graph_for_multi_hop(self.synthesizer_llm_client,
234
- self.tokenizer_instance,
235
- self.graph_storage,
236
- self.traverse_strategy,
237
- self.text_chunks_storage,
238
- self.progress_bar)
239
- elif self.traverse_strategy.qa_form == "aggregated":
240
- results = await traverse_graph_by_edge(self.synthesizer_llm_client, self.tokenizer_instance,
241
- self.graph_storage, self.traverse_strategy, self.text_chunks_storage,
242
- self.progress_bar)
 
 
 
 
 
 
 
 
 
 
 
243
  else:
244
- raise ValueError(f"Unknown qa_form: {self.traverse_strategy.qa_form}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  await self.qa_storage.upsert(results)
246
  await self.qa_storage.index_done_callback()
247
 
@@ -252,7 +387,7 @@ class GraphGen:
252
  async def async_clear(self):
253
  await self.full_docs_storage.drop()
254
  await self.text_chunks_storage.drop()
255
- await self.wiki_storage.drop()
256
  await self.graph_storage.clear()
257
  await self.rephrase_storage.drop()
258
  await self.qa_storage.drop()
 
 
 
1
  import asyncio
2
  import os
3
  import time
4
  from dataclasses import dataclass, field
5
+ from typing import Dict, List, Union, cast
6
 
7
  import gradio as gr
8
  from tqdm.asyncio import tqdm as tqdm_async
 
10
  from .models import (
11
  Chunk,
12
  JsonKVStorage,
13
+ JsonListStorage,
14
  NetworkXStorage,
15
  OpenAIModel,
16
  Tokenizer,
17
  TraverseStrategy,
 
18
  )
19
  from .models.storage.base_storage import StorageNameSpace
20
  from .operators import (
21
  extract_kg,
22
+ generate_cot,
23
  judge_statement,
24
  quiz,
25
+ search_all,
 
26
  traverse_graph_atomically,
27
  traverse_graph_by_edge,
28
  traverse_graph_for_multi_hop,
29
  )
30
+ from .utils import (
31
+ compute_content_hash,
32
+ create_event_loop,
33
+ format_generation_results,
34
+ logger,
35
+ read_file,
36
+ )
37
 
38
  sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
39
 
40
+
41
  @dataclass
42
  class GraphGen:
43
  unique_id: int = int(time.time())
44
  working_dir: str = os.path.join(sys_path, "cache")
45
+ config: Dict = field(default_factory=dict)
 
 
 
46
 
47
  # llm
48
+ tokenizer_instance: Tokenizer = None
49
  synthesizer_llm_client: OpenAIModel = None
50
  trainee_llm_client: OpenAIModel = None
 
51
 
52
+ # text chunking
53
+ # TODO: make it configurable
54
+ chunk_size: int = 1024
55
+ chunk_overlap_size: int = 100
56
+
57
+ # search
58
+ search_config: dict = field(
59
+ default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}
60
+ )
61
 
62
+ # traversal
63
+ traverse_strategy: TraverseStrategy = None
64
 
65
  # webui
66
  progress_bar: gr.Progress = None
67
 
68
  def __post_init__(self):
69
+ self.tokenizer_instance: Tokenizer = Tokenizer(
70
+ model_name=self.config["tokenizer"]
71
+ )
72
+ self.synthesizer_llm_client: OpenAIModel = OpenAIModel(
73
+ model_name=os.getenv("SYNTHESIZER_MODEL"),
74
+ api_key=os.getenv("SYNTHESIZER_API_KEY"),
75
+ base_url=os.getenv("SYNTHESIZER_BASE_URL"),
76
+ tokenizer_instance=self.tokenizer_instance,
77
+ )
78
+ self.trainee_llm_client: OpenAIModel = OpenAIModel(
79
+ model_name=os.getenv("TRAINEE_MODEL"),
80
+ api_key=os.getenv("TRAINEE_API_KEY"),
81
+ base_url=os.getenv("TRAINEE_BASE_URL"),
82
+ tokenizer_instance=self.tokenizer_instance,
83
+ )
84
+ self.search_config = self.config["search"]
85
+
86
+ if "traverse_strategy" in self.config:
87
+ self.traverse_strategy = TraverseStrategy(
88
+ **self.config["traverse_strategy"]
89
+ )
90
+
91
  self.full_docs_storage: JsonKVStorage = JsonKVStorage(
92
  self.working_dir, namespace="full_docs"
93
  )
94
  self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
95
  self.working_dir, namespace="text_chunks"
96
  )
 
 
 
97
  self.graph_storage: NetworkXStorage = NetworkXStorage(
98
  self.working_dir, namespace="graph"
99
  )
100
+ self.search_storage: JsonKVStorage = JsonKVStorage(
101
+ self.working_dir, namespace="search"
102
+ )
103
  self.rephrase_storage: JsonKVStorage = JsonKVStorage(
104
  self.working_dir, namespace="rephrase"
105
  )
106
+ self.qa_storage: JsonListStorage = JsonListStorage(
107
+ os.path.join(self.working_dir, "data", "graphgen", str(self.unique_id)),
108
+ namespace=f"qa-{self.unique_id}",
109
  )
110
 
111
+ async def async_split_chunks(
112
+ self, data: List[Union[List, Dict]], data_type: str
113
+ ) -> dict:
114
+ # TODO: configurable whether to use coreference resolution
115
  if len(data) == 0:
116
  return {}
117
 
 
118
  inserting_chunks = {}
119
  if data_type == "raw":
120
  assert isinstance(data, list) and isinstance(data[0], dict)
121
  # compute hash for each document
122
  new_docs = {
123
+ compute_content_hash(doc["content"], prefix="doc-"): {
124
+ "content": doc["content"]
125
+ }
126
+ for doc in data
127
  }
128
+ _add_doc_keys = await self.full_docs_storage.filter_keys(
129
+ list(new_docs.keys())
130
+ )
131
  new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
132
  if len(new_docs) == 0:
133
  logger.warning("All docs are already in the storage")
 
137
  cur_index = 1
138
  doc_number = len(new_docs)
139
  async for doc_key, doc in tqdm_async(
140
+ new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
141
+ ):
142
  chunks = {
143
  compute_content_hash(dp["content"], prefix="chunk-"): {
144
  **dp,
145
+ "full_doc_id": doc_key,
146
+ }
147
+ for dp in self.tokenizer_instance.chunk_by_token_size(
148
+ doc["content"], self.chunk_overlap_size, self.chunk_size
149
+ )
150
  }
151
  inserting_chunks.update(chunks)
152
 
153
  if self.progress_bar is not None:
154
+ self.progress_bar(cur_index / doc_number, f"Chunking {doc_key}")
 
 
155
  cur_index += 1
156
 
157
+ _add_chunk_keys = await self.text_chunks_storage.filter_keys(
158
+ list(inserting_chunks.keys())
159
+ )
160
+ inserting_chunks = {
161
+ k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
162
+ }
163
  elif data_type == "chunked":
164
  assert isinstance(data, list) and isinstance(data[0], list)
165
  new_docs = {
166
+ compute_content_hash("".join(chunk["content"]), prefix="doc-"): {
167
+ "content": "".join(chunk["content"])
168
+ }
169
+ for doc in data
170
+ for chunk in doc
171
  }
172
+ _add_doc_keys = await self.full_docs_storage.filter_keys(
173
+ list(new_docs.keys())
174
+ )
175
  new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
176
  if len(new_docs) == 0:
177
  logger.warning("All docs are already in the storage")
178
  return {}
179
  logger.info("[New Docs] inserting %d docs", len(new_docs))
180
+ async for doc in tqdm_async(
181
+ data, desc="[1/4]Chunking documents", unit="doc"
182
+ ):
183
+ doc_str = "".join([chunk["content"] for chunk in doc])
184
  for chunk in doc:
185
+ chunk_key = compute_content_hash(chunk["content"], prefix="chunk-")
186
  inserting_chunks[chunk_key] = {
187
  **chunk,
188
+ "full_doc_id": compute_content_hash(doc_str, prefix="doc-"),
189
  }
190
+ _add_chunk_keys = await self.text_chunks_storage.filter_keys(
191
+ list(inserting_chunks.keys())
192
+ )
193
+ inserting_chunks = {
194
+ k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
195
+ }
196
+ else:
197
+ raise ValueError(f"Unknown data type: {data_type}")
198
 
199
  await self.full_docs_storage.upsert(new_docs)
200
  await self.text_chunks_storage.upsert(inserting_chunks)
201
 
202
  return inserting_chunks
203
 
204
+ def insert(self):
205
  loop = create_event_loop()
206
+ loop.run_until_complete(self.async_insert())
207
 
208
+ async def async_insert(self):
209
  """
 
210
  insert chunks into the graph
211
  """
212
 
213
+ input_file = self.config["input_file"]
214
+ data_type = self.config["input_data_type"]
215
+ data = read_file(input_file)
216
+
217
  inserting_chunks = await self.async_split_chunks(data, data_type)
218
 
219
  if len(inserting_chunks) == 0:
 
226
  llm_client=self.synthesizer_llm_client,
227
  kg_instance=self.graph_storage,
228
  tokenizer_instance=self.tokenizer_instance,
229
+ chunks=[
230
+ Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
231
+ ],
232
+ progress_bar=self.progress_bar,
233
  )
234
  if not _add_entities_and_relations:
235
  logger.warning("No entities or relations extracted")
236
  return
237
 
 
 
 
 
 
 
 
 
 
 
238
  await self._insert_done()
239
 
240
  async def _insert_done(self):
241
  tasks = []
242
+ for storage_instance in [
243
+ self.full_docs_storage,
244
+ self.text_chunks_storage,
245
+ self.graph_storage,
246
+ self.search_storage,
247
+ ]:
248
  if storage_instance is None:
249
  continue
250
  tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
251
  await asyncio.gather(*tasks)
252
 
253
+ def search(self):
254
  loop = create_event_loop()
255
+ loop.run_until_complete(self.async_search())
256
 
257
+ async def async_search(self):
258
+ logger.info(
259
+ "Search is %s", "enabled" if self.search_config["enabled"] else "disabled"
260
+ )
261
+ if self.search_config["enabled"]:
262
+ logger.info(
263
+ "[Search] %s ...", ", ".join(self.search_config["search_types"])
264
+ )
265
+ all_nodes = await self.graph_storage.get_all_nodes()
266
+ all_nodes_names = [node[0] for node in all_nodes]
267
+ new_search_entities = await self.full_docs_storage.filter_keys(
268
+ all_nodes_names
269
+ )
270
+ logger.info(
271
+ "[Search] Found %d entities to search", len(new_search_entities)
272
+ )
273
+ _add_search_data = await search_all(
274
+ search_types=self.search_config["search_types"],
275
+ search_entities=new_search_entities,
276
+ )
277
+ if _add_search_data:
278
+ await self.search_storage.upsert(_add_search_data)
279
+ logger.info("[Search] %d entities searched", len(_add_search_data))
280
+
281
+ # Format search results for inserting
282
+ search_results = []
283
+ for _, search_data in _add_search_data.items():
284
+ search_results.extend(
285
+ [
286
+ {"content": search_data[key]}
287
+ for key in list(search_data.keys())
288
+ ]
289
+ )
290
+ # TODO: fix insert after search
291
+ await self.async_insert()
292
 
293
+ def quiz(self):
294
  loop = create_event_loop()
295
+ loop.run_until_complete(self.async_quiz())
296
+
297
+ async def async_quiz(self):
298
+ max_samples = self.config["quiz_and_judge_strategy"]["quiz_samples"]
299
+ await quiz(
300
+ self.synthesizer_llm_client,
301
+ self.graph_storage,
302
+ self.rephrase_storage,
303
+ max_samples,
304
+ )
305
+ await self.rephrase_storage.index_done_callback()
306
 
307
+ def judge(self):
308
+ loop = create_event_loop()
309
+ loop.run_until_complete(self.async_judge())
310
+
311
+ async def async_judge(self):
312
+ re_judge = self.config["quiz_and_judge_strategy"]["re_judge"]
313
+ _update_relations = await judge_statement(
314
+ self.trainee_llm_client,
315
+ self.graph_storage,
316
+ self.rephrase_storage,
317
+ re_judge,
318
+ )
319
  await _update_relations.index_done_callback()
320
 
321
  def traverse(self):
 
323
  loop.run_until_complete(self.async_traverse())
324
 
325
  async def async_traverse(self):
326
+ output_data_type = self.config["output_data_type"]
327
+
328
+ if output_data_type == "atomic":
329
+ results = await traverse_graph_atomically(
330
+ self.synthesizer_llm_client,
331
+ self.tokenizer_instance,
332
+ self.graph_storage,
333
+ self.traverse_strategy,
334
+ self.text_chunks_storage,
335
+ self.progress_bar,
336
+ )
337
+ elif output_data_type == "multi_hop":
338
+ results = await traverse_graph_for_multi_hop(
339
+ self.synthesizer_llm_client,
340
+ self.tokenizer_instance,
341
+ self.graph_storage,
342
+ self.traverse_strategy,
343
+ self.text_chunks_storage,
344
+ self.progress_bar,
345
+ )
346
+ elif output_data_type == "aggregated":
347
+ results = await traverse_graph_by_edge(
348
+ self.synthesizer_llm_client,
349
+ self.tokenizer_instance,
350
+ self.graph_storage,
351
+ self.traverse_strategy,
352
+ self.text_chunks_storage,
353
+ self.progress_bar,
354
+ )
355
  else:
356
+ raise ValueError(f"Unknown qa_form: {output_data_type}")
357
+
358
+ results = format_generation_results(
359
+ results, output_data_format=self.config["output_data_format"]
360
+ )
361
+
362
+ await self.qa_storage.upsert(results)
363
+ await self.qa_storage.index_done_callback()
364
+
365
+ def generate_reasoning(self, method_params):
366
+ loop = create_event_loop()
367
+ loop.run_until_complete(self.async_generate_reasoning(method_params))
368
+
369
+ async def async_generate_reasoning(self, method_params):
370
+ results = await generate_cot(
371
+ self.graph_storage,
372
+ self.synthesizer_llm_client,
373
+ method_params=method_params,
374
+ )
375
+
376
+ results = format_generation_results(
377
+ results, output_data_format=self.config["output_data_format"]
378
+ )
379
+
380
  await self.qa_storage.upsert(results)
381
  await self.qa_storage.index_done_callback()
382
 
 
387
  async def async_clear(self):
388
  await self.full_docs_storage.drop()
389
  await self.text_chunks_storage.drop()
390
+ await self.search_storage.drop()
391
  await self.graph_storage.clear()
392
  await self.rephrase_storage.drop()
393
  await self.qa_storage.drop()
graphgen/models/__init__.py CHANGED
@@ -1,22 +1,20 @@
1
- from .text.chunk import Chunk
2
- from .text.text_pair import TextPair
3
-
4
- from .llm.topk_token_model import Token, TopkTokenModel
5
- from .llm.openai_model import OpenAIModel
6
- from .llm.tokenizer import Tokenizer
7
-
8
- from .storage.networkx_storage import NetworkXStorage
9
- from .storage.json_storage import JsonKVStorage
10
-
11
- from .search.wiki_search import WikiSearch
12
-
13
  from .evaluate.length_evaluator import LengthEvaluator
14
  from .evaluate.mtld_evaluator import MTLDEvaluator
15
  from .evaluate.reward_evaluator import RewardEvaluator
16
  from .evaluate.uni_evaluator import UniEvaluator
17
-
 
 
 
 
 
 
 
 
18
  from .strategy.travserse_strategy import TraverseStrategy
19
-
 
20
 
21
  __all__ = [
22
  # llm models
@@ -28,8 +26,12 @@ __all__ = [
28
  "Chunk",
29
  "NetworkXStorage",
30
  "JsonKVStorage",
 
31
  # search models
32
  "WikiSearch",
 
 
 
33
  # evaluate models
34
  "TextPair",
35
  "LengthEvaluator",
@@ -38,4 +40,6 @@ __all__ = [
38
  "UniEvaluator",
39
  # strategy models
40
  "TraverseStrategy",
 
 
41
  ]
 
1
+ from .community.community_detector import CommunityDetector
 
 
 
 
 
 
 
 
 
 
 
2
  from .evaluate.length_evaluator import LengthEvaluator
3
  from .evaluate.mtld_evaluator import MTLDEvaluator
4
  from .evaluate.reward_evaluator import RewardEvaluator
5
  from .evaluate.uni_evaluator import UniEvaluator
6
+ from .llm.openai_model import OpenAIModel
7
+ from .llm.tokenizer import Tokenizer
8
+ from .llm.topk_token_model import Token, TopkTokenModel
9
+ from .search.db.uniprot_search import UniProtSearch
10
+ from .search.kg.wiki_search import WikiSearch
11
+ from .search.web.bing_search import BingSearch
12
+ from .search.web.google_search import GoogleSearch
13
+ from .storage.json_storage import JsonKVStorage, JsonListStorage
14
+ from .storage.networkx_storage import NetworkXStorage
15
  from .strategy.travserse_strategy import TraverseStrategy
16
+ from .text.chunk import Chunk
17
+ from .text.text_pair import TextPair
18
 
19
  __all__ = [
20
  # llm models
 
26
  "Chunk",
27
  "NetworkXStorage",
28
  "JsonKVStorage",
29
+ "JsonListStorage",
30
  # search models
31
  "WikiSearch",
32
+ "GoogleSearch",
33
+ "BingSearch",
34
+ "UniProtSearch",
35
  # evaluate models
36
  "TextPair",
37
  "LengthEvaluator",
 
40
  "UniEvaluator",
41
  # strategy models
42
  "TraverseStrategy",
43
+ # community models
44
+ "CommunityDetector",
45
  ]
graphgen/models/community/__init__.py ADDED
File without changes
graphgen/models/community/community_detector.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List

from graphgen.models.storage.networkx_storage import NetworkXStorage


@dataclass
class CommunityDetector:
    """Detect communities in the stored graph; only Leiden is supported so far."""

    graph_storage: NetworkXStorage = None
    method: str = "leiden"
    method_params: Dict[str, Any] = None

    async def detect_communities(self) -> Dict[str, int]:
        """Return a node-name -> community-id mapping using the configured method."""
        if self.method != "leiden":
            raise ValueError(f"Unknown community detection method: {self.method}")
        return await self._leiden_communities(**(self.method_params or {}))

    async def get_graph(self):
        """Fetch the underlying networkx graph from storage."""
        return await self.graph_storage.get_graph()

    async def _leiden_communities(
        self, max_size: int = None, **kwargs
    ) -> Dict[str, int]:
        """
        Detect communities using the Leiden algorithm.
        If max_size is given, any community larger than max_size will be split
        into smaller sub-communities each having at most max_size nodes.
        """
        # Imported lazily: igraph/leidenalg are optional heavy dependencies.
        import igraph as ig
        import networkx as nx
        from leidenalg import ModularityVertexPartition, find_partition

        nx_graph = await self.get_graph()
        nx_graph.remove_nodes_from(list(nx.isolates(nx_graph)))

        ig_graph = ig.Graph.TupleList(nx_graph.edges(), directed=False)

        seed = kwargs.get("random_seed", 42)
        node2cid: Dict[str, int] = {}

        if kwargs.get("use_lcc", False):
            # Partition only the largest connected component.
            giant = ig_graph.components().giant()
            partition = find_partition(giant, ModularityVertexPartition, seed=seed)
            for cid, members in enumerate(partition):
                for vertex in members:
                    node2cid[giant.vs[vertex]["name"]] = cid
        else:
            # Partition every component; offset ids so they stay globally unique.
            base = 0
            for component in ig_graph.components():
                sub = ig_graph.induced_subgraph(component)
                partition = find_partition(sub, ModularityVertexPartition, seed=seed)
                for cid, members in enumerate(partition):
                    for vertex in members:
                        node2cid[sub.vs[vertex]["name"]] = base + cid
                base += len(partition)

        # Optionally break oversized communities into bounded chunks.
        if max_size is None or max_size <= 0:
            return node2cid
        return await self._split_communities(node2cid, max_size)

    @staticmethod
    async def _split_communities(
        communities: Dict[str, int], max_size: int
    ) -> Dict[str, int]:
        """
        Split communities larger than max_size into smaller sub-communities.
        """
        members_by_cid: Dict[int, List[str]] = defaultdict(list)
        for node, cid in communities.items():
            members_by_cid[cid].append(node)

        relabelled: Dict[str, int] = {}
        next_id = 0
        for nodes in members_by_cid.values():
            # Chunks of at most max_size; a small community yields one chunk.
            for start in range(0, len(nodes), max_size):
                for node in nodes[start : start + max_size]:
                    relabelled[node] = next_id
                next_id += 1

        return relabelled
graphgen/models/llm/openai_model.py CHANGED
@@ -1,18 +1,21 @@
1
  import math
 
2
  from dataclasses import dataclass, field
3
- from typing import List, Dict, Optional
 
4
  import openai
5
- from openai import AsyncOpenAI, RateLimitError, APIConnectionError, APITimeoutError
6
  from tenacity import (
7
  retry,
 
8
  stop_after_attempt,
9
  wait_exponential,
10
- retry_if_exception_type,
11
  )
12
 
13
- from graphgen.models.llm.topk_token_model import TopkTokenModel, Token
14
- from graphgen.models.llm.tokenizer import Tokenizer
15
  from graphgen.models.llm.limitter import RPM, TPM
 
 
 
16
 
17
  def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
18
  token_logprobs = response.choices[0].logprobs.content
@@ -20,13 +23,23 @@ def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
20
  for token_prob in token_logprobs:
21
  prob = math.exp(token_prob.logprob)
22
  candidate_tokens = [
23
- Token(t.token, math.exp(t.logprob))
24
- for t in token_prob.top_logprobs
25
  ]
26
  token = Token(token_prob.token, prob, top_candidates=candidate_tokens)
27
  tokens.append(token)
28
  return tokens
29
 
 
 
 
 
 
 
 
 
 
 
 
30
  @dataclass
31
  class OpenAIModel(TopkTokenModel):
32
  model_name: str = "gpt-4o-mini"
@@ -42,12 +55,13 @@ class OpenAIModel(TopkTokenModel):
42
  rpm: RPM = field(default_factory=lambda: RPM(rpm=1000))
43
  tpm: TPM = field(default_factory=lambda: TPM(tpm=50000))
44
 
 
45
 
46
  def __post_init__(self):
47
  assert self.api_key is not None, "Please provide api key to access openai api."
48
- if self.api_key == "":
49
- self.api_key = "none"
50
- self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
51
 
52
  def _pre_generate(self, text: str, history: List[str]) -> Dict:
53
  kwargs = {
@@ -69,16 +83,19 @@ class OpenAIModel(TopkTokenModel):
69
  assert len(history) % 2 == 0, "History should have even number of elements."
70
  messages = history + messages
71
 
72
- kwargs['messages']= messages
73
  return kwargs
74
 
75
-
76
  @retry(
77
  stop=stop_after_attempt(5),
78
  wait=wait_exponential(multiplier=1, min=4, max=10),
79
- retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
 
 
80
  )
81
- async def generate_topk_per_token(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
 
 
82
  kwargs = self._pre_generate(text, history)
83
  if self.topk_per_token > 0:
84
  kwargs["logprobs"] = True
@@ -87,9 +104,8 @@ class OpenAIModel(TopkTokenModel):
87
  # Limit max_tokens to 1 to avoid long completions
88
  kwargs["max_tokens"] = 1
89
 
90
- completion = await self.client.chat.completions.create( # pylint: disable=E1125
91
- model=self.model_name,
92
- **kwargs
93
  )
94
 
95
  tokens = get_top_response_tokens(completion)
@@ -99,32 +115,41 @@ class OpenAIModel(TopkTokenModel):
99
  @retry(
100
  stop=stop_after_attempt(5),
101
  wait=wait_exponential(multiplier=1, min=4, max=10),
102
- retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
 
 
103
  )
104
- async def generate_answer(self, text: str, history: Optional[List[str]] = None, temperature: int = 0) -> str:
 
 
105
  kwargs = self._pre_generate(text, history)
106
  kwargs["temperature"] = temperature
107
 
108
  prompt_tokens = 0
109
- for message in kwargs['messages']:
110
- prompt_tokens += len(Tokenizer().encode_string(message['content']))
111
- estimated_tokens = prompt_tokens + kwargs['max_tokens']
 
 
112
 
113
  if self.request_limit:
114
  await self.rpm.wait(silent=True)
115
  await self.tpm.wait(estimated_tokens, silent=True)
116
 
117
- completion = await self.client.chat.completions.create( # pylint: disable=E1125
118
- model=self.model_name,
119
- **kwargs
120
  )
121
  if hasattr(completion, "usage"):
122
- self.token_usage.append({
123
- "prompt_tokens": completion.usage.prompt_tokens,
124
- "completion_tokens": completion.usage.completion_tokens,
125
- "total_tokens": completion.usage.total_tokens,
126
- })
127
- return completion.choices[0].message.content
128
-
129
- async def generate_inputs_prob(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
 
 
 
 
130
  raise NotImplementedError
 
1
  import math
2
+ import re
3
  from dataclasses import dataclass, field
4
+ from typing import Dict, List, Optional
5
+
6
  import openai
7
+ from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError
8
  from tenacity import (
9
  retry,
10
+ retry_if_exception_type,
11
  stop_after_attempt,
12
  wait_exponential,
 
13
  )
14
 
 
 
15
  from graphgen.models.llm.limitter import RPM, TPM
16
+ from graphgen.models.llm.tokenizer import Tokenizer
17
+ from graphgen.models.llm.topk_token_model import Token, TopkTokenModel
18
+
19
 
20
  def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
21
  token_logprobs = response.choices[0].logprobs.content
 
23
  for token_prob in token_logprobs:
24
  prob = math.exp(token_prob.logprob)
25
  candidate_tokens = [
26
+ Token(t.token, math.exp(t.logprob)) for t in token_prob.top_logprobs
 
27
  ]
28
  token = Token(token_prob.token, prob, top_candidates=candidate_tokens)
29
  tokens.append(token)
30
  return tokens
31
 
32
+
33
def filter_think_tags(text: str) -> str:
    """Strip ``<think>...</think>`` spans from *text*.

    If deleting the spans would leave an empty string, the original text is
    returned (stripped) so callers never receive an empty answer.
    """
    cleaned = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
    return cleaned if cleaned else text.strip()
41
+
42
+
43
  @dataclass
44
  class OpenAIModel(TopkTokenModel):
45
  model_name: str = "gpt-4o-mini"
 
55
  rpm: RPM = field(default_factory=lambda: RPM(rpm=1000))
56
  tpm: TPM = field(default_factory=lambda: TPM(tpm=50000))
57
 
58
+ tokenizer_instance: Tokenizer = field(default_factory=Tokenizer)
59
 
60
  def __post_init__(self):
61
  assert self.api_key is not None, "Please provide api key to access openai api."
62
+ self.client = AsyncOpenAI(
63
+ api_key=self.api_key or "dummy", base_url=self.base_url
64
+ )
65
 
66
  def _pre_generate(self, text: str, history: List[str]) -> Dict:
67
  kwargs = {
 
83
  assert len(history) % 2 == 0, "History should have even number of elements."
84
  messages = history + messages
85
 
86
+ kwargs["messages"] = messages
87
  return kwargs
88
 
 
89
  @retry(
90
  stop=stop_after_attempt(5),
91
  wait=wait_exponential(multiplier=1, min=4, max=10),
92
+ retry=retry_if_exception_type(
93
+ (RateLimitError, APIConnectionError, APITimeoutError)
94
+ ),
95
  )
96
+ async def generate_topk_per_token(
97
+ self, text: str, history: Optional[List[str]] = None
98
+ ) -> List[Token]:
99
  kwargs = self._pre_generate(text, history)
100
  if self.topk_per_token > 0:
101
  kwargs["logprobs"] = True
 
104
  # Limit max_tokens to 1 to avoid long completions
105
  kwargs["max_tokens"] = 1
106
 
107
+ completion = await self.client.chat.completions.create( # pylint: disable=E1125
108
+ model=self.model_name, **kwargs
 
109
  )
110
 
111
  tokens = get_top_response_tokens(completion)
 
115
  @retry(
116
  stop=stop_after_attempt(5),
117
  wait=wait_exponential(multiplier=1, min=4, max=10),
118
+ retry=retry_if_exception_type(
119
+ (RateLimitError, APIConnectionError, APITimeoutError)
120
+ ),
121
  )
122
+ async def generate_answer(
123
+ self, text: str, history: Optional[List[str]] = None, temperature: int = 0
124
+ ) -> str:
125
  kwargs = self._pre_generate(text, history)
126
  kwargs["temperature"] = temperature
127
 
128
  prompt_tokens = 0
129
+ for message in kwargs["messages"]:
130
+ prompt_tokens += len(
131
+ self.tokenizer_instance.encode_string(message["content"])
132
+ )
133
+ estimated_tokens = prompt_tokens + kwargs["max_tokens"]
134
 
135
  if self.request_limit:
136
  await self.rpm.wait(silent=True)
137
  await self.tpm.wait(estimated_tokens, silent=True)
138
 
139
+ completion = await self.client.chat.completions.create( # pylint: disable=E1125
140
+ model=self.model_name, **kwargs
 
141
  )
142
  if hasattr(completion, "usage"):
143
+ self.token_usage.append(
144
+ {
145
+ "prompt_tokens": completion.usage.prompt_tokens,
146
+ "completion_tokens": completion.usage.completion_tokens,
147
+ "total_tokens": completion.usage.total_tokens,
148
+ }
149
+ )
150
+ return filter_think_tags(completion.choices[0].message.content)
151
+
152
+ async def generate_inputs_prob(
153
+ self, text: str, history: Optional[List[str]] = None
154
+ ) -> List[Token]:
155
  raise NotImplementedError
graphgen/models/search/db/__init__.py ADDED
File without changes
graphgen/models/search/db/uniprot_search.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from dataclasses import dataclass

import requests
from fastapi import HTTPException

from graphgen.utils import logger

# Root of the UniProtKB REST API; individual entries live directly under it.
UNIPROT_KB_BASE = "https://rest.uniprot.org/uniprotkb"
# Search endpoint (name kept for backward compatibility with existing imports).
UNIPROT_BASE = f"{UNIPROT_KB_BASE}/search"


@dataclass
class UniProtSearch:
    """
    UniProt Search client to search with UniProt.
    1) Get the protein by accession number.
    2) Search with keywords or protein names.
    """

    def get_entry(self, accession: str) -> dict:
        """
        Get the UniProt entry by accession number (e.g., P04637).

        FIX: entries are served from ``/uniprotkb/{accession}.json``; the
        previous URL (``/uniprotkb/search/{accession}.json``) pointed at the
        search endpoint and could never resolve an entry.
        """
        url = f"{UNIPROT_KB_BASE}/{accession}.json"
        return self._safe_get(url).json()

    def search(
        self,
        query: str,
        *,
        size: int = 10,
        cursor: str = None,
        fields: list[str] = None,
    ) -> dict:
        """
        Search UniProt with a query string.
        :param query: The search query.
        :param size: The number of results to return.
        :param cursor: The cursor for pagination.
        :param fields: The fields to return in the response.
        :return: A dictionary containing the search results.
        """
        params = {
            "query": query,
            "size": size,
        }
        if cursor:
            params["cursor"] = cursor
        if fields:
            params["fields"] = ",".join(fields)
        return self._safe_get(UNIPROT_BASE, params=params).json()

    @staticmethod
    def _safe_get(url: str, params: dict = None) -> requests.Response:
        """GET *url* and raise an HTTPException on any non-2xx response."""
        r = requests.get(
            url,
            params=params,
            headers={"Accept": "application/json"},
            timeout=10,
        )
        if not r.ok:
            logger.error("Search engine error: %s", r.text)
            raise HTTPException(r.status_code, "Search engine error.")
        return r
graphgen/models/search/kg/__init__.py ADDED
File without changes
graphgen/models/search/{wiki_search.py → kg/wiki_search.py} RENAMED
@@ -1,8 +1,9 @@
1
- from typing import List, Union
2
  from dataclasses import dataclass
 
3
 
4
  import wikipedia
5
  from wikipedia import set_lang
 
6
  from graphgen.utils import detect_main_language, logger
7
 
8
 
@@ -13,9 +14,9 @@ class WikiSearch:
13
  assert language in ["en", "zh"], "Only support English and Chinese"
14
  set_lang(language)
15
 
16
- async def search(self, query: str) -> Union[List[str], None]:
17
  self.set_language(detect_main_language(query))
18
- return wikipedia.search(query)
19
 
20
  async def summary(self, query: str) -> Union[str, None]:
21
  self.set_language(detect_main_language(query))
 
 
1
  from dataclasses import dataclass
2
+ from typing import List, Union
3
 
4
  import wikipedia
5
  from wikipedia import set_lang
6
+
7
  from graphgen.utils import detect_main_language, logger
8
 
9
 
 
14
  assert language in ["en", "zh"], "Only support English and Chinese"
15
  set_lang(language)
16
 
17
+ async def search(self, query: str, num_results: int = 1) -> Union[List[str], None]:
18
  self.set_language(detect_main_language(query))
19
+ return wikipedia.search(query, results=num_results, suggestion=False)
20
 
21
  async def summary(self, query: str) -> Union[str, None]:
22
  self.set_language(detect_main_language(query))
graphgen/models/search/web/__init__.py ADDED
File without changes
graphgen/models/search/web/bing_search.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from dataclasses import dataclass

import requests
from fastapi import HTTPException

from graphgen.utils import logger

BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"
BING_MKT = "en-US"


@dataclass
class BingSearch:
    """
    Bing Search client to search with Bing.
    """

    # Azure subscription key for the Bing Web Search v7 API.
    subscription_key: str

    def search(self, query: str, num_results: int = 1):
        """
        Search with Bing and return the contexts.
        :param query: The search query.
        :param num_results: The number of results to return.
        :return: A list of search results (empty on a malformed response).
        """
        headers = {"Ocp-Apim-Subscription-Key": self.subscription_key}
        query_params = {"q": query, "mkt": BING_MKT, "count": num_results}
        resp = requests.get(
            BING_SEARCH_V7_ENDPOINT,
            headers=headers,
            params=query_params,
            timeout=10,
        )
        if not resp.ok:
            logger.error("Search engine error: %s", resp.text)
            raise HTTPException(resp.status_code, "Search engine error.")
        payload = resp.json()
        try:
            return payload["webPages"]["value"][:num_results]
        except KeyError:
            logger.error("Error encountered: %s", payload)
            return []
graphgen/models/search/web/google_search.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from dataclasses import dataclass

import requests
from fastapi import HTTPException

from graphgen.utils import logger

GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1"


@dataclass
class GoogleSearch:
    """
    Google Custom Search client to search with Google.

    Fields (the dataclass-generated ``__init__(subscription_key, cx)``
    keeps the original constructor signature):

    :param subscription_key: Your Google API subscription key.
    :param cx: Your custom search engine ID.
    """

    # Previously this class combined @dataclass with a hand-written
    # __init__, which made the decorator a misleading no-op. Declaring
    # fields lets the dataclass generate the identical constructor and
    # matches the sibling BingSearch client.
    subscription_key: str
    cx: str

    def search(self, query: str, num_results: int = 1):
        """
        Search with Google and return the contexts.
        :param query: The search query.
        :param num_results: The number of results to return.
        :return: A list of search results.
        """
        params = {
            "key": self.subscription_key,
            "cx": self.cx,
            "q": query,
            "num": num_results,
        }
        response = requests.get(GOOGLE_SEARCH_ENDPOINT, params=params, timeout=10)
        if not response.ok:
            logger.error("Search engine error: %s", response.text)
            raise HTTPException(response.status_code, "Search engine error.")
        json_content = response.json()
        try:
            # Custom Search JSON API returns results under "items".
            contexts = json_content["items"][:num_results]
        except KeyError:
            logger.error("Error encountered: %s", json_content)
            return []
        return contexts
graphgen/models/storage/base_storage.py CHANGED
@@ -1,9 +1,11 @@
1
  from dataclasses import dataclass
2
- from typing import Union, Generic, TypeVar
 
3
  from graphgen.models.embed.embedding import EmbeddingFunc
4
 
5
  T = TypeVar("T")
6
 
 
7
  @dataclass
8
  class StorageNameSpace:
9
  working_dir: str = None
@@ -17,9 +19,25 @@ class StorageNameSpace:
17
 
18
 
19
  @dataclass
20
- class BaseKVStorage(Generic[T], StorageNameSpace):
21
- embedding_func: EmbeddingFunc = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
 
 
23
  async def all_keys(self) -> list[str]:
24
  raise NotImplementedError
25
 
@@ -41,6 +59,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
41
  async def drop(self):
42
  raise NotImplementedError
43
 
 
44
  @dataclass
45
  class BaseGraphStorage(StorageNameSpace):
46
  embedding_func: EmbeddingFunc = None
@@ -71,7 +90,9 @@ class BaseGraphStorage(StorageNameSpace):
71
  ) -> Union[dict, None]:
72
  raise NotImplementedError
73
 
74
- async def update_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
 
 
75
  raise NotImplementedError
76
 
77
  async def get_all_edges(self) -> Union[list[dict], None]:
 
1
  from dataclasses import dataclass
2
+ from typing import Generic, TypeVar, Union
3
+
4
  from graphgen.models.embed.embedding import EmbeddingFunc
5
 
6
  T = TypeVar("T")
7
 
8
+
9
  @dataclass
10
  class StorageNameSpace:
11
  working_dir: str = None
 
19
 
20
 
21
@dataclass
class BaseListStorage(Generic[T], StorageNameSpace):
    """Abstract list-backed storage interface; backends override all methods."""

    async def all_items(self) -> list[T]:
        """Return every stored item."""
        raise NotImplementedError

    async def get_by_index(self, index: int) -> Union[T, None]:
        """Return the item at *index*, or None when out of range."""
        raise NotImplementedError

    async def append(self, data: T):
        """Append a single item to the list."""
        raise NotImplementedError

    async def upsert(self, data: list[T]):
        """Insert the items from *data* that are not already stored."""
        raise NotImplementedError

    async def drop(self):
        """Remove all stored items."""
        raise NotImplementedError
37
+
38
 
39
+ @dataclass
40
+ class BaseKVStorage(Generic[T], StorageNameSpace):
41
  async def all_keys(self) -> list[str]:
42
  raise NotImplementedError
43
 
 
59
  async def drop(self):
60
  raise NotImplementedError
61
 
62
+
63
  @dataclass
64
  class BaseGraphStorage(StorageNameSpace):
65
  embedding_func: EmbeddingFunc = None
 
90
  ) -> Union[dict, None]:
91
  raise NotImplementedError
92
 
93
+ async def update_edge(
94
+ self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
95
+ ):
96
  raise NotImplementedError
97
 
98
  async def get_all_edges(self) -> Union[list[dict], None]:
graphgen/models/storage/json_storage.py CHANGED
@@ -1,8 +1,8 @@
1
  import os
2
-
3
  from dataclasses import dataclass
4
- from graphgen.utils import logger, load_json, write_json
5
- from graphgen.models.storage.base_storage import BaseKVStorage
 
6
 
7
 
8
  @dataclass
@@ -49,3 +49,39 @@ class JsonKVStorage(BaseKVStorage):
49
 
50
  async def drop(self):
51
  self._data = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
2
  from dataclasses import dataclass
3
+
4
+ from graphgen.models.storage.base_storage import BaseKVStorage, BaseListStorage
5
+ from graphgen.utils import load_json, logger, write_json
6
 
7
 
8
  @dataclass
 
49
 
50
  async def drop(self):
51
  self._data = {}
52
+
53
+
54
@dataclass
class JsonListStorage(BaseListStorage):
    """List storage persisted as a JSON array in the working directory."""

    _data: list = None

    def __post_init__(self):
        # Backing file is <working_dir>/<namespace>.json; start empty if absent.
        self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
        self._data = load_json(self._file_name) or []
        logger.info("Load List %s with %d data", self.namespace, len(self._data))

    @property
    def data(self):
        return self._data

    async def all_items(self) -> list:
        return self._data

    async def index_done_callback(self):
        # Flush the in-memory list back to disk.
        write_json(self._data, self._file_name)

    async def get_by_index(self, index: int):
        # Out-of-range (including negative) indices yield None rather than
        # Python's wrap-around indexing.
        if 0 <= index < len(self._data):
            return self._data[index]
        return None

    async def append(self, data):
        self._data.append(data)

    async def upsert(self, data: list):
        # Keep only items not already present; return what was added.
        new_items = [item for item in data if item not in self._data]
        self._data.extend(new_items)
        return new_items

    async def drop(self):
        self._data = []
graphgen/models/vis/__init__.py ADDED
File without changes
graphgen/models/vis/community_visualizer.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from dataclasses import dataclass
from typing import Dict

import matplotlib.pyplot as plt
import networkx as nx


@dataclass
class Visualizer:
    """
    Class for visualizing graphs using NetworkX and Matplotlib.
    """

    graph: nx.Graph = None  # graph to draw
    communities: Dict[str, int] = None  # node -> community id, used for colors
    layout: str = "spring"  # only "spring" is currently supported
    max_nodes: int = 1000  # NOTE(review): currently unused by visualize - confirm intent
    node_size: int = 10
    alpha: float = 0.6

    def visualize(self, save_path: str = None):
        """Draw the graph with nodes colored by community.

        :param save_path: when given, write a 300-dpi image there instead of
            showing the figure interactively.
        :raises ValueError: if ``self.layout`` is not a supported layout.
        """
        n = self.graph.number_of_nodes()
        if self.layout == "spring":
            # Scale repulsion with graph size; fixed seed for reproducibility.
            k = max(0.1, 1.0 / (n**0.5))
            pos = nx.spring_layout(self.graph, k=k, seed=42)
        else:
            raise ValueError(f"Unknown layout: {self.layout}")

        fig = plt.figure(figsize=(10, 10))

        # Nodes without a community entry fall into color bucket 0.
        node_colors = [self.communities.get(node, 0) for node in self.graph.nodes()]

        nx.draw_networkx_nodes(
            self.graph,
            pos,
            node_size=self.node_size,
            node_color=node_colors,
            cmap=plt.cm.tab20,
            alpha=self.alpha,
        )
        nx.draw_networkx_edges(self.graph, pos, alpha=0.3, width=0.2)
        plt.axis("off")

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
            print("Saved to", save_path)
            # Close the figure so repeated calls don't leak Matplotlib
            # figures (the original left every figure open).
            plt.close(fig)
        else:
            plt.show()
graphgen/operators/__init__.py CHANGED
@@ -1,16 +1,22 @@
1
- from .extract_kg import extract_kg
 
 
 
 
2
  from .quiz import quiz
3
- from .judge import judge_statement, skip_judge_statement
4
- from .search_wikipedia import search_wikipedia
5
- from .traverse_graph import traverse_graph_by_edge, traverse_graph_atomically, traverse_graph_for_multi_hop
 
 
6
 
7
  __all__ = [
8
  "extract_kg",
9
  "quiz",
10
  "judge_statement",
11
- "skip_judge_statement",
12
- "search_wikipedia",
13
  "traverse_graph_by_edge",
14
  "traverse_graph_atomically",
15
- "traverse_graph_for_multi_hop"
 
16
  ]
 
1
+ from graphgen.operators.generate.generate_cot import generate_cot
2
+ from graphgen.operators.kg.extract_kg import extract_kg
3
+ from graphgen.operators.search.search_all import search_all
4
+
5
+ from .judge import judge_statement
6
  from .quiz import quiz
7
+ from .traverse_graph import (
8
+ traverse_graph_atomically,
9
+ traverse_graph_by_edge,
10
+ traverse_graph_for_multi_hop,
11
+ )
12
 
13
  __all__ = [
14
  "extract_kg",
15
  "quiz",
16
  "judge_statement",
17
+ "search_all",
 
18
  "traverse_graph_by_edge",
19
  "traverse_graph_atomically",
20
+ "traverse_graph_for_multi_hop",
21
+ "generate_cot",
22
  ]
graphgen/operators/generate/__init__.py ADDED
File without changes
graphgen/operators/generate/generate_cot.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
from typing import Dict, List, Tuple

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIModel
from graphgen.templates import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
from graphgen.utils import compute_content_hash, detect_main_language


async def generate_cot(
    graph_storage: NetworkXStorage,
    synthesizer_llm_client: OpenAIModel,
    method_params: Dict = None,
):
    """Generate chain-of-thought (CoT) QA data from graph communities.

    Detects communities in the stored knowledge graph, then for each
    community asks the synthesizer LLM to (1) design a question and a
    reasoning-path template from the community's entities/relations and
    (2) produce the final CoT answer following that template.

    :param graph_storage: graph storage holding node/edge descriptions.
    :param synthesizer_llm_client: LLM client used for both generation steps.
    :param method_params: community-detection options; ``method`` selects the
        algorithm (default ``"leiden"``) and the whole dict is forwarded to
        the detector.
    :return: dict keyed by the question's content hash; each value has
        ``question``, ``reasoning_path`` and ``answer``.
    :raises ValueError: when the LLM's template reply lacks the expected
        section markers.
    """
    method = method_params.get("method", "leiden")
    detector = CommunityDetector(
        graph_storage=graph_storage, method=method, method_params=method_params
    )

    results = await detector.detect_communities()

    # Convert results to a format suitable for summarization:
    # invert node -> community_id into community_id -> [nodes].
    communities = {}
    for node, community_id in results.items():
        if community_id not in communities:
            communities[community_id] = []
        communities[community_id].append(node)

    if not communities:
        return {}

    # Cap the number of concurrent LLM calls.
    semaphore = asyncio.Semaphore(value=1000)

    async def _generate_from_single_community(
        c_id: int, nodes: List[str]
    ) -> Tuple[int, Tuple[str, str, str]]:
        """Summarize a single community."""
        async with semaphore:
            entities: List[str] = []
            relationships: List[str] = []

            # Collect node descriptions; keep only intra-community edges.
            for n in nodes:
                node_data = await graph_storage.get_node(n)
                if node_data is not None:
                    entities.append(f"({n}: {node_data.get('description')})")

                edges = await graph_storage.get_node_edges(n)
                for edge in edges:
                    target = edge[1]
                    if target in nodes:
                        edge_data = await graph_storage.get_edge(n, target)
                        relationships.append(
                            f"({n}) - [{edge_data['description']}] -> ({target})"
                        )

            entities_str = "\n".join(entities)
            relationships_str = "\n".join(relationships)

            # Prompt language follows the dominant language of the content.
            language = (
                "English"
                if detect_main_language(entities_str + relationships_str) == "en"
                else "Chinese"
            )

            # Step 1: design a question + reasoning-path template.
            prompt = COT_TEMPLATE_DESIGN_PROMPT[language]["TEMPLATE"].format(
                entities=entities_str,
                relationships=relationships_str,
            )

            cot_template = await synthesizer_llm_client.generate_answer(prompt)

            # Parse the reply; the section markers depend on the language
            # (Chinese markers first, then English).
            if "问题:" in cot_template and "推理路径设计:" in cot_template:
                question = cot_template.split("问题:")[1].split("推理路径设计:")[0].strip()
                reasoning_path = cot_template.split("推理路径设计:")[1].strip()
            elif (
                "Question:" in cot_template and "Reasoning-Path Design:" in cot_template
            ):
                question = (
                    cot_template.split("Question:")[1]
                    .split("Reasoning-Path Design:")[0]
                    .strip()
                )
                reasoning_path = cot_template.split("Reasoning-Path Design:")[1].strip()
            else:
                raise ValueError("COT template format is incorrect.")

            # Step 2: generate the CoT answer following the designed template.
            prompt = COT_GENERATION_PROMPT[language]["TEMPLATE"].format(
                entities=entities_str,
                relationships=relationships_str,
                question=question,
                reasoning_template=reasoning_path,
            )

            cot_answer = await synthesizer_llm_client.generate_answer(prompt)

            return c_id, (question, reasoning_path, cot_answer)

    cid_nodes = list(communities.items())

    # NOTE: rebinds `results` (previously the detection mapping) as the output.
    results: Dict = {}
    async for coro in tqdm_async(
        asyncio.as_completed(
            [_generate_from_single_community(cid, nodes) for cid, nodes in cid_nodes]
        ),
        total=len(cid_nodes),
        desc="[Generating COT] Generating CoT data from communities",
        unit="community",
    ):
        cid, (q, r, a) = await coro
        # Key by content hash so identical questions collapse to one entry.
        results[compute_content_hash(q)] = {
            "question": q,
            "reasoning_path": r,
            "answer": a,
        }

    return results
graphgen/operators/judge.py CHANGED
@@ -1,17 +1,20 @@
1
- import math
2
  import asyncio
 
 
3
  from tqdm.asyncio import tqdm as tqdm_async
4
- from graphgen.models import NetworkXStorage, OpenAIModel, JsonKVStorage
5
- from graphgen.utils import logger, yes_no_loss_entropy
6
  from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
 
7
 
8
 
9
- async def judge_statement( # pylint: disable=too-many-statements
10
- trainee_llm_client: OpenAIModel,
11
- graph_storage: NetworkXStorage,
12
- rephrase_storage: JsonKVStorage,
13
- re_judge: bool = False,
14
- max_concurrent: int = 1000) -> NetworkXStorage:
 
15
  """
16
  Get all edges and nodes and judge them
17
 
@@ -34,7 +37,12 @@ async def judge_statement( # pylint: disable=too-many-statements
34
  edge_data = edge[2]
35
 
36
  if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
37
- logger.info("Edge %s -> %s already judged, loss: %s, skip", source_id, target_id, edge_data["loss"])
 
 
 
 
 
38
  return source_id, target_id, edge_data
39
 
40
  description = edge_data["description"]
@@ -47,17 +55,27 @@ async def judge_statement( # pylint: disable=too-many-statements
47
  gts = [gt for _, gt in descriptions]
48
  for description, gt in descriptions:
49
  judgement = await trainee_llm_client.generate_topk_per_token(
50
- STATEMENT_JUDGEMENT_PROMPT['TEMPLATE'].format(statement=description)
 
 
51
  )
52
  judgements.append(judgement[0].top_candidates)
53
 
54
  loss = yes_no_loss_entropy(judgements, gts)
55
 
56
- logger.info("Edge %s -> %s description: %s loss: %s", source_id, target_id, description, loss)
 
 
 
 
 
 
57
 
58
  edge_data["loss"] = loss
59
- except Exception as e: # pylint: disable=broad-except
60
- logger.error("Error in judging relation %s -> %s: %s", source_id, target_id, e)
 
 
61
  logger.info("Use default loss 0.1")
62
  edge_data["loss"] = -math.log(0.1)
63
 
@@ -68,9 +86,9 @@ async def judge_statement( # pylint: disable=too-many-statements
68
 
69
  results = []
70
  for result in tqdm_async(
71
- asyncio.as_completed([_judge_single_relation(edge) for edge in edges]),
72
- total=len(edges),
73
- desc="Judging relations"
74
  ):
75
  results.append(await result)
76
 
@@ -82,7 +100,9 @@ async def judge_statement( # pylint: disable=too-many-statements
82
  node_data = node[1]
83
 
84
  if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
85
- logger.info("Node %s already judged, loss: %s, skip", node_id, node_data["loss"])
 
 
86
  return node_id, node_data
87
 
88
  description = node_data["description"]
@@ -95,16 +115,20 @@ async def judge_statement( # pylint: disable=too-many-statements
95
  gts = [gt for _, gt in descriptions]
96
  for description, gt in descriptions:
97
  judgement = await trainee_llm_client.generate_topk_per_token(
98
- STATEMENT_JUDGEMENT_PROMPT['TEMPLATE'].format(statement=description)
 
 
99
  )
100
  judgements.append(judgement[0].top_candidates)
101
 
102
  loss = yes_no_loss_entropy(judgements, gts)
103
 
104
- logger.info("Node %s description: %s loss: %s", node_id, description, loss)
 
 
105
 
106
  node_data["loss"] = loss
107
- except Exception as e: # pylint: disable=broad-except
108
  logger.error("Error in judging entity %s: %s", node_id, e)
109
  logger.info("Use default loss 0.1")
110
  node_data["loss"] = -math.log(0.1)
@@ -116,72 +140,9 @@ async def judge_statement( # pylint: disable=too-many-statements
116
 
117
  results = []
118
  for result in tqdm_async(
119
- asyncio.as_completed([_judge_single_entity(node) for node in nodes]),
120
- total=len(nodes),
121
- desc="Judging entities"
122
- ):
123
- results.append(await result)
124
-
125
- return graph_storage
126
-
127
- async def skip_judge_statement(
128
- graph_storage: NetworkXStorage,
129
- max_concurrent: int = 1000
130
- ):
131
- """
132
- Skip the judgement of the statement
133
- :param graph_storage: graph storage instance
134
- :param max_concurrent: max concurrent
135
- :return:
136
- """
137
- semaphore = asyncio.Semaphore(max_concurrent)
138
-
139
- async def _skip_single_relation(
140
- edge: tuple,
141
- ):
142
- async with semaphore:
143
- source_id = edge[0]
144
- target_id = edge[1]
145
- edge_data = edge[2]
146
-
147
- if "loss" in edge_data and edge_data["loss"] is not None:
148
- logger.info("Edge %s -> %s already judged, loss: %s, skip", source_id, target_id, edge_data["loss"])
149
- return source_id, target_id, edge_data
150
-
151
- edge_data["loss"] = -math.log(0.1)
152
- await graph_storage.update_edge(source_id, target_id, edge_data)
153
- return source_id, target_id, edge_data
154
-
155
- edges = await graph_storage.get_all_edges()
156
- results = []
157
- for result in tqdm_async(
158
- asyncio.as_completed([_skip_single_relation(edge) for edge in edges]),
159
- total=len(edges),
160
- desc="Skipping judgement of relations"
161
- ):
162
- results.append(await result)
163
-
164
- async def _skip_single_entity(
165
- node: tuple,
166
- ):
167
- async with semaphore:
168
- node_id = node[0]
169
- node_data = node[1]
170
-
171
- if "loss" in node_data and node_data["loss"] is not None:
172
- logger.info("Node %s already judged, loss: %s, skip", node_id, node_data["loss"])
173
- return node_id, node_data
174
-
175
- node_data["loss"] = -math.log(0.1)
176
- await graph_storage.update_node(node_id, node_data)
177
- return node_id, node_data
178
-
179
- nodes = await graph_storage.get_all_nodes()
180
- results = []
181
- for result in tqdm_async(
182
- asyncio.as_completed([_skip_single_entity(node) for node in nodes]),
183
- total=len(nodes),
184
- desc="Skipping judgement of entities"
185
  ):
186
  results.append(await result)
187
 
 
 
1
  import asyncio
2
+ import math
3
+
4
  from tqdm.asyncio import tqdm as tqdm_async
5
+
6
+ from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIModel
7
  from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
8
+ from graphgen.utils import logger, yes_no_loss_entropy
9
 
10
 
11
+ async def judge_statement( # pylint: disable=too-many-statements
12
+ trainee_llm_client: OpenAIModel,
13
+ graph_storage: NetworkXStorage,
14
+ rephrase_storage: JsonKVStorage,
15
+ re_judge: bool = False,
16
+ max_concurrent: int = 1000,
17
+ ) -> NetworkXStorage:
18
  """
19
  Get all edges and nodes and judge them
20
 
 
37
  edge_data = edge[2]
38
 
39
  if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
40
+ logger.info(
41
+ "Edge %s -> %s already judged, loss: %s, skip",
42
+ source_id,
43
+ target_id,
44
+ edge_data["loss"],
45
+ )
46
  return source_id, target_id, edge_data
47
 
48
  description = edge_data["description"]
 
55
  gts = [gt for _, gt in descriptions]
56
  for description, gt in descriptions:
57
  judgement = await trainee_llm_client.generate_topk_per_token(
58
+ STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
59
+ statement=description
60
+ )
61
  )
62
  judgements.append(judgement[0].top_candidates)
63
 
64
  loss = yes_no_loss_entropy(judgements, gts)
65
 
66
+ logger.info(
67
+ "Edge %s -> %s description: %s loss: %s",
68
+ source_id,
69
+ target_id,
70
+ description,
71
+ loss,
72
+ )
73
 
74
  edge_data["loss"] = loss
75
+ except Exception as e: # pylint: disable=broad-except
76
+ logger.error(
77
+ "Error in judging relation %s -> %s: %s", source_id, target_id, e
78
+ )
79
  logger.info("Use default loss 0.1")
80
  edge_data["loss"] = -math.log(0.1)
81
 
 
86
 
87
  results = []
88
  for result in tqdm_async(
89
+ asyncio.as_completed([_judge_single_relation(edge) for edge in edges]),
90
+ total=len(edges),
91
+ desc="Judging relations",
92
  ):
93
  results.append(await result)
94
 
 
100
  node_data = node[1]
101
 
102
  if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
103
+ logger.info(
104
+ "Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
105
+ )
106
  return node_id, node_data
107
 
108
  description = node_data["description"]
 
115
  gts = [gt for _, gt in descriptions]
116
  for description, gt in descriptions:
117
  judgement = await trainee_llm_client.generate_topk_per_token(
118
+ STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
119
+ statement=description
120
+ )
121
  )
122
  judgements.append(judgement[0].top_candidates)
123
 
124
  loss = yes_no_loss_entropy(judgements, gts)
125
 
126
+ logger.info(
127
+ "Node %s description: %s loss: %s", node_id, description, loss
128
+ )
129
 
130
  node_data["loss"] = loss
131
+ except Exception as e: # pylint: disable=broad-except
132
  logger.error("Error in judging entity %s: %s", node_id, e)
133
  logger.info("Use default loss 0.1")
134
  node_data["loss"] = -math.log(0.1)
 
140
 
141
  results = []
142
  for result in tqdm_async(
143
+ asyncio.as_completed([_judge_single_entity(node) for node in nodes]),
144
+ total=len(nodes),
145
+ desc="Judging entities",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  ):
147
  results.append(await result)
148
 
graphgen/operators/kg/__init__.py ADDED
File without changes
graphgen/operators/{extract_kg.py → kg/extract_kg.py} RENAMED
@@ -1,27 +1,33 @@
1
- import re
2
  import asyncio
3
- from typing import List
4
  from collections import defaultdict
 
5
 
6
  import gradio as gr
7
  from tqdm.asyncio import tqdm as tqdm_async
 
8
  from graphgen.models import Chunk, OpenAIModel, Tokenizer
9
  from graphgen.models.storage.base_storage import BaseGraphStorage
 
10
  from graphgen.templates import KG_EXTRACTION_PROMPT
11
- from graphgen.utils import (logger, pack_history_conversations, split_string_by_multi_markers,
12
- handle_single_entity_extraction, handle_single_relationship_extraction,
13
- detect_if_chinese)
14
- from graphgen.operators.merge_kg import merge_nodes, merge_edges
 
 
 
 
15
 
16
 
17
  # pylint: disable=too-many-statements
18
  async def extract_kg(
19
- llm_client: OpenAIModel,
20
- kg_instance: BaseGraphStorage,
21
- tokenizer_instance: Tokenizer,
22
- chunks: List[Chunk],
23
- progress_bar: gr.Progress = None,
24
- max_concurrent: int = 1000
25
  ):
26
  """
27
  :param llm_client: Synthesizer LLM model to extract entities and relationships
@@ -50,25 +56,25 @@ async def extract_kg(
50
  )
51
 
52
  final_result = await llm_client.generate_answer(hint_prompt)
53
- logger.info('First result: %s', final_result)
54
 
55
  history = pack_history_conversations(hint_prompt, final_result)
56
  for loop_index in range(max_loop):
57
  if_loop_result = await llm_client.generate_answer(
58
- text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"],
59
- history=history
60
  )
61
  if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
62
  if if_loop_result != "yes":
63
  break
64
 
65
  glean_result = await llm_client.generate_answer(
66
- text=KG_EXTRACTION_PROMPT[language]["CONTINUE"],
67
- history=history
68
  )
69
- logger.info('Loop %s glean: %s', loop_index, glean_result)
70
 
71
- history += pack_history_conversations(KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result)
 
 
72
  final_result += glean_result
73
  if loop_index == max_loop - 1:
74
  break
@@ -76,8 +82,9 @@ async def extract_kg(
76
  records = split_string_by_multi_markers(
77
  final_result,
78
  [
79
- KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
80
- KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"]],
 
81
  )
82
 
83
  nodes = defaultdict(list)
@@ -87,16 +94,20 @@ async def extract_kg(
87
  record = re.search(r"\((.*)\)", record)
88
  if record is None:
89
  continue
90
- record = record.group(1) # 提取括号内的内容
91
  record_attributes = split_string_by_multi_markers(
92
  record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
93
  )
94
 
95
- entity = await handle_single_entity_extraction(record_attributes, chunk_id)
 
 
96
  if entity is not None:
97
  nodes[entity["entity_name"]].append(entity)
98
  continue
99
- relation = await handle_single_relationship_extraction(record_attributes, chunk_id)
 
 
100
  if relation is not None:
101
  edges[(relation["src_id"], relation["tgt_id"])].append(relation)
102
  return dict(nodes), dict(edges)
@@ -106,17 +117,25 @@ async def extract_kg(
106
  async for result in tqdm_async(
107
  asyncio.as_completed([_process_single_content(c) for c in chunks]),
108
  total=len(chunks),
109
- desc="[3/4]Extracting entities and relationships from chunks",
110
  unit="chunk",
111
  ):
112
  try:
113
  if progress_bar is not None:
114
- progress_bar(len(results) / chunk_number, desc="[3/4]Extracting entities and relationships from chunks")
 
 
 
115
  results.append(await result)
116
  if progress_bar is not None and len(results) == chunk_number:
117
- progress_bar(1, desc="[3/4]Extracting entities and relationships from chunks")
118
- except Exception as e: # pylint: disable=broad-except
119
- logger.error("Error occurred while extracting entities and relationships from chunks: %s", e)
 
 
 
 
 
120
 
121
  nodes = defaultdict(list)
122
  edges = defaultdict(list)
 
 
1
  import asyncio
2
+ import re
3
  from collections import defaultdict
4
+ from typing import List
5
 
6
  import gradio as gr
7
  from tqdm.asyncio import tqdm as tqdm_async
8
+
9
  from graphgen.models import Chunk, OpenAIModel, Tokenizer
10
  from graphgen.models.storage.base_storage import BaseGraphStorage
11
+ from graphgen.operators.kg.merge_kg import merge_edges, merge_nodes
12
  from graphgen.templates import KG_EXTRACTION_PROMPT
13
+ from graphgen.utils import (
14
+ detect_if_chinese,
15
+ handle_single_entity_extraction,
16
+ handle_single_relationship_extraction,
17
+ logger,
18
+ pack_history_conversations,
19
+ split_string_by_multi_markers,
20
+ )
21
 
22
 
23
  # pylint: disable=too-many-statements
24
  async def extract_kg(
25
+ llm_client: OpenAIModel,
26
+ kg_instance: BaseGraphStorage,
27
+ tokenizer_instance: Tokenizer,
28
+ chunks: List[Chunk],
29
+ progress_bar: gr.Progress = None,
30
+ max_concurrent: int = 1000,
31
  ):
32
  """
33
  :param llm_client: Synthesizer LLM model to extract entities and relationships
 
56
  )
57
 
58
  final_result = await llm_client.generate_answer(hint_prompt)
59
+ logger.info("First result: %s", final_result)
60
 
61
  history = pack_history_conversations(hint_prompt, final_result)
62
  for loop_index in range(max_loop):
63
  if_loop_result = await llm_client.generate_answer(
64
+ text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
 
65
  )
66
  if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
67
  if if_loop_result != "yes":
68
  break
69
 
70
  glean_result = await llm_client.generate_answer(
71
+ text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
 
72
  )
73
+ logger.info("Loop %s glean: %s", loop_index, glean_result)
74
 
75
+ history += pack_history_conversations(
76
+ KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
77
+ )
78
  final_result += glean_result
79
  if loop_index == max_loop - 1:
80
  break
 
82
  records = split_string_by_multi_markers(
83
  final_result,
84
  [
85
+ KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
86
+ KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
87
+ ],
88
  )
89
 
90
  nodes = defaultdict(list)
 
94
  record = re.search(r"\((.*)\)", record)
95
  if record is None:
96
  continue
97
+ record = record.group(1) # 提取括号内的内容
98
  record_attributes = split_string_by_multi_markers(
99
  record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
100
  )
101
 
102
+ entity = await handle_single_entity_extraction(
103
+ record_attributes, chunk_id
104
+ )
105
  if entity is not None:
106
  nodes[entity["entity_name"]].append(entity)
107
  continue
108
+ relation = await handle_single_relationship_extraction(
109
+ record_attributes, chunk_id
110
+ )
111
  if relation is not None:
112
  edges[(relation["src_id"], relation["tgt_id"])].append(relation)
113
  return dict(nodes), dict(edges)
 
117
  async for result in tqdm_async(
118
  asyncio.as_completed([_process_single_content(c) for c in chunks]),
119
  total=len(chunks),
120
+ desc="[2/4]Extracting entities and relationships from chunks",
121
  unit="chunk",
122
  ):
123
  try:
124
  if progress_bar is not None:
125
+ progress_bar(
126
+ len(results) / chunk_number,
127
+ desc="[3/4]Extracting entities and relationships from chunks",
128
+ )
129
  results.append(await result)
130
  if progress_bar is not None and len(results) == chunk_number:
131
+ progress_bar(
132
+ 1, desc="[3/4]Extracting entities and relationships from chunks"
133
+ )
134
+ except Exception as e: # pylint: disable=broad-except
135
+ logger.error(
136
+ "Error occurred while extracting entities and relationships from chunks: %s",
137
+ e,
138
+ )
139
 
140
  nodes = defaultdict(list)
141
  edges = defaultdict(list)
graphgen/operators/{merge_kg.py → kg/merge_kg.py} RENAMED
@@ -1,19 +1,21 @@
1
- from collections import Counter
2
  import asyncio
 
 
3
  from tqdm.asyncio import tqdm as tqdm_async
4
 
5
- from graphgen.utils.format import split_string_by_multi_markers
6
- from graphgen.utils import logger, detect_main_language
7
- from graphgen.models import TopkTokenModel, Tokenizer
8
  from graphgen.models.storage.base_storage import BaseGraphStorage
9
- from graphgen.templates import KG_SUMMARIZATION_PROMPT, KG_EXTRACTION_PROMPT
 
 
 
10
 
11
  async def _handle_kg_summary(
12
  entity_or_relation_name: str,
13
  description: str,
14
  llm_client: TopkTokenModel,
15
  tokenizer_instance: Tokenizer,
16
- max_summary_tokens: int = 200
17
  ) -> str:
18
  """
19
  处理实体或关系的描述信息
@@ -33,17 +35,19 @@ async def _handle_kg_summary(
33
  KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
34
 
35
  tokens = tokenizer_instance.encode_string(description)
36
- if len(tokens) < max_summary_tokens:
37
  return description
38
 
39
  use_description = tokenizer_instance.decode_tokens(tokens[:max_summary_tokens])
40
  prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
41
  entity_name=entity_or_relation_name,
42
- description_list=use_description.split('<SEP>'),
43
- **KG_SUMMARIZATION_PROMPT["FORMAT"]
44
  )
45
  new_description = await llm_client.generate_answer(prompt)
46
- logger.info("Entity or relation %s summary: %s", entity_or_relation_name, new_description)
 
 
47
  return new_description
48
 
49
 
@@ -52,7 +56,7 @@ async def merge_nodes(
52
  kg_instance: BaseGraphStorage,
53
  llm_client: TopkTokenModel,
54
  tokenizer_instance: Tokenizer,
55
- max_concurrent: int = 1000
56
  ):
57
  """
58
  Merge nodes
@@ -77,39 +81,34 @@ async def merge_nodes(
77
  if node is not None:
78
  entity_types.append(node["entity_type"])
79
  source_ids.extend(
80
- split_string_by_multi_markers(node["source_id"], ['<SEP>'])
81
  )
82
  descriptions.append(node["description"])
83
 
84
  # 统计当前节点数据和已有节点数据的entity_type出现次数,取出现次数最多的entity_type
85
  entity_type = sorted(
86
- Counter(
87
- [dp["entity_type"] for dp in node_data] + entity_types
88
- ).items(),
89
  key=lambda x: x[1],
90
  reverse=True,
91
  )[0][0]
92
 
93
- description = '<SEP>'.join(
94
  sorted(set([dp["description"] for dp in node_data] + descriptions))
95
  )
96
  description = await _handle_kg_summary(
97
  entity_name, description, llm_client, tokenizer_instance
98
  )
99
 
100
- source_id = '<SEP>'.join(
101
  set([dp["source_id"] for dp in node_data] + source_ids)
102
  )
103
 
104
  node_data = {
105
  "entity_type": entity_type,
106
  "description": description,
107
- "source_id": source_id
108
  }
109
- await kg_instance.upsert_node(
110
- entity_name,
111
- node_data=node_data
112
- )
113
  node_data["entity_name"] = entity_name
114
  return node_data
115
 
@@ -125,7 +124,7 @@ async def merge_nodes(
125
  ):
126
  try:
127
  entities_data.append(await result)
128
- except Exception as e: # pylint: disable=broad-except
129
  logger.error("Error occurred while inserting entities into storage: %s", e)
130
 
131
 
@@ -134,7 +133,7 @@ async def merge_edges(
134
  kg_instance: BaseGraphStorage,
135
  llm_client: TopkTokenModel,
136
  tokenizer_instance: Tokenizer,
137
- max_concurrent: int = 1000
138
  ):
139
  """
140
  Merge edges
@@ -157,14 +156,14 @@ async def merge_edges(
157
  edge = await kg_instance.get_edge(src_id, tgt_id)
158
  if edge is not None:
159
  source_ids.extend(
160
- split_string_by_multi_markers(edge["source_id"], ['<SEP>'])
161
  )
162
  descriptions.append(edge["description"])
163
 
164
- description = '<SEP>'.join(
165
  sorted(set([dp["description"] for dp in edge_data] + descriptions))
166
  )
167
- source_id = '<SEP>'.join(
168
  set([dp["source_id"] for dp in edge_data] + source_ids)
169
  )
170
 
@@ -175,8 +174,8 @@ async def merge_edges(
175
  node_data={
176
  "source_id": source_id,
177
  "description": description,
178
- "entity_type": "UNKNOWN"
179
- }
180
  )
181
 
182
  description = await _handle_kg_summary(
@@ -186,24 +185,20 @@ async def merge_edges(
186
  await kg_instance.upsert_edge(
187
  src_id,
188
  tgt_id,
189
- edge_data={
190
- "source_id": source_id,
191
- "description": description
192
- }
193
  )
194
 
195
- edge_data = {
196
- "src_id": src_id,
197
- "tgt_id": tgt_id,
198
- "description": description
199
- }
200
  return edge_data
201
 
202
  logger.info("Inserting relationships into storage...")
203
  relationships_data = []
204
  for result in tqdm_async(
205
  asyncio.as_completed(
206
- [process_single_edge(src_id, tgt_id, v) for (src_id, tgt_id), v in edges_data.items()]
 
 
 
207
  ),
208
  total=len(edges_data),
209
  desc="Inserting relationships into storage",
@@ -211,5 +206,7 @@ async def merge_edges(
211
  ):
212
  try:
213
  relationships_data.append(await result)
214
- except Exception as e: # pylint: disable=broad-except
215
- logger.error("Error occurred while inserting relationships into storage: %s", e)
 
 
 
 
1
  import asyncio
2
+ from collections import Counter
3
+
4
  from tqdm.asyncio import tqdm as tqdm_async
5
 
6
+ from graphgen.models import Tokenizer, TopkTokenModel
 
 
7
  from graphgen.models.storage.base_storage import BaseGraphStorage
8
+ from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
9
+ from graphgen.utils import detect_main_language, logger
10
+ from graphgen.utils.format import split_string_by_multi_markers
11
+
12
 
13
  async def _handle_kg_summary(
14
  entity_or_relation_name: str,
15
  description: str,
16
  llm_client: TopkTokenModel,
17
  tokenizer_instance: Tokenizer,
18
+ max_summary_tokens: int = 200,
19
  ) -> str:
20
  """
21
  处理实体或关系的描述信息
 
35
  KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
36
 
37
  tokens = tokenizer_instance.encode_string(description)
38
+ if len(tokens) < max_summary_tokens:
39
  return description
40
 
41
  use_description = tokenizer_instance.decode_tokens(tokens[:max_summary_tokens])
42
  prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
43
  entity_name=entity_or_relation_name,
44
+ description_list=use_description.split("<SEP>"),
45
+ **KG_SUMMARIZATION_PROMPT["FORMAT"],
46
  )
47
  new_description = await llm_client.generate_answer(prompt)
48
+ logger.info(
49
+ "Entity or relation %s summary: %s", entity_or_relation_name, new_description
50
+ )
51
  return new_description
52
 
53
 
 
56
  kg_instance: BaseGraphStorage,
57
  llm_client: TopkTokenModel,
58
  tokenizer_instance: Tokenizer,
59
+ max_concurrent: int = 1000,
60
  ):
61
  """
62
  Merge nodes
 
81
  if node is not None:
82
  entity_types.append(node["entity_type"])
83
  source_ids.extend(
84
+ split_string_by_multi_markers(node["source_id"], ["<SEP>"])
85
  )
86
  descriptions.append(node["description"])
87
 
88
  # 统计当前节点数据和已有节点数据的entity_type出现次数,取出现次数最多的entity_type
89
  entity_type = sorted(
90
+ Counter([dp["entity_type"] for dp in node_data] + entity_types).items(),
 
 
91
  key=lambda x: x[1],
92
  reverse=True,
93
  )[0][0]
94
 
95
+ description = "<SEP>".join(
96
  sorted(set([dp["description"] for dp in node_data] + descriptions))
97
  )
98
  description = await _handle_kg_summary(
99
  entity_name, description, llm_client, tokenizer_instance
100
  )
101
 
102
+ source_id = "<SEP>".join(
103
  set([dp["source_id"] for dp in node_data] + source_ids)
104
  )
105
 
106
  node_data = {
107
  "entity_type": entity_type,
108
  "description": description,
109
+ "source_id": source_id,
110
  }
111
+ await kg_instance.upsert_node(entity_name, node_data=node_data)
 
 
 
112
  node_data["entity_name"] = entity_name
113
  return node_data
114
 
 
124
  ):
125
  try:
126
  entities_data.append(await result)
127
+ except Exception as e: # pylint: disable=broad-except
128
  logger.error("Error occurred while inserting entities into storage: %s", e)
129
 
130
 
 
133
  kg_instance: BaseGraphStorage,
134
  llm_client: TopkTokenModel,
135
  tokenizer_instance: Tokenizer,
136
+ max_concurrent: int = 1000,
137
  ):
138
  """
139
  Merge edges
 
156
  edge = await kg_instance.get_edge(src_id, tgt_id)
157
  if edge is not None:
158
  source_ids.extend(
159
+ split_string_by_multi_markers(edge["source_id"], ["<SEP>"])
160
  )
161
  descriptions.append(edge["description"])
162
 
163
+ description = "<SEP>".join(
164
  sorted(set([dp["description"] for dp in edge_data] + descriptions))
165
  )
166
+ source_id = "<SEP>".join(
167
  set([dp["source_id"] for dp in edge_data] + source_ids)
168
  )
169
 
 
174
  node_data={
175
  "source_id": source_id,
176
  "description": description,
177
+ "entity_type": "UNKNOWN",
178
+ },
179
  )
180
 
181
  description = await _handle_kg_summary(
 
185
  await kg_instance.upsert_edge(
186
  src_id,
187
  tgt_id,
188
+ edge_data={"source_id": source_id, "description": description},
 
 
 
189
  )
190
 
191
+ edge_data = {"src_id": src_id, "tgt_id": tgt_id, "description": description}
 
 
 
 
192
  return edge_data
193
 
194
  logger.info("Inserting relationships into storage...")
195
  relationships_data = []
196
  for result in tqdm_async(
197
  asyncio.as_completed(
198
+ [
199
+ process_single_edge(src_id, tgt_id, v)
200
+ for (src_id, tgt_id), v in edges_data.items()
201
+ ]
202
  ),
203
  total=len(edges_data),
204
  desc="Inserting relationships into storage",
 
206
  ):
207
  try:
208
  relationships_data.append(await result)
209
+ except Exception as e: # pylint: disable=broad-except
210
+ logger.error(
211
+ "Error occurred while inserting relationships into storage: %s", e
212
+ )
graphgen/operators/{split_graph.py → kg/split_kg.py} RENAMED
@@ -1,14 +1,16 @@
1
  import random
2
  from collections import defaultdict
 
3
  from tqdm.asyncio import tqdm as tqdm_async
4
- from graphgen.utils import logger
5
 
6
  from graphgen.models import NetworkXStorage, TraverseStrategy
 
 
7
 
8
  async def _get_node_info(
9
  node_id: str,
10
  graph_storage: NetworkXStorage,
11
- )-> dict:
12
  """
13
  Get node info
14
 
@@ -17,10 +19,7 @@ async def _get_node_info(
17
  :return: node info
18
  """
19
  node_data = await graph_storage.get_node(node_id)
20
- return {
21
- "node_id": node_id,
22
- **node_data
23
- }
24
 
25
 
26
  def _get_level_n_edges_by_max_width(
@@ -33,7 +32,7 @@ def _get_level_n_edges_by_max_width(
33
  bidirectional: bool,
34
  max_extra_edges: int,
35
  edge_sampling: str,
36
- loss_strategy: str = "only_edge"
37
  ) -> list:
38
  """
39
  Get level n edges for an edge.
@@ -71,10 +70,17 @@ def _get_level_n_edges_by_max_width(
71
 
72
  if len(candidate_edges) >= max_extra_edges:
73
  if loss_strategy == "both":
74
- er_tuples = [([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) for edge in candidate_edges]
75
- candidate_edges = _sort_tuples(er_tuples, edge_sampling)[:max_extra_edges]
 
 
 
 
 
76
  elif loss_strategy == "only_edge":
77
- candidate_edges = _sort_edges(candidate_edges, edge_sampling)[:max_extra_edges]
 
 
78
  else:
79
  raise ValueError(f"Invalid loss strategy: {loss_strategy}")
80
 
@@ -101,16 +107,16 @@ def _get_level_n_edges_by_max_width(
101
 
102
 
103
  def _get_level_n_edges_by_max_tokens(
104
- edge_adj_list: dict,
105
- node_dict: dict,
106
- edges: list,
107
- nodes: list,
108
- src_edge: tuple,
109
- max_depth: int,
110
- bidirectional: bool,
111
- max_tokens: int,
112
- edge_sampling: str,
113
- loss_strategy: str = "only_edge"
114
  ) -> list:
115
  """
116
  Get level n edges for an edge.
@@ -129,8 +135,11 @@ def _get_level_n_edges_by_max_tokens(
129
  """
130
  src_id, tgt_id, src_edge_data = src_edge
131
 
132
- max_tokens -= (src_edge_data["length"] + nodes[node_dict[src_id]][1]["length"]
133
- + nodes[node_dict[tgt_id]][1]["length"])
 
 
 
134
 
135
  level_n_edges = []
136
 
@@ -151,7 +160,10 @@ def _get_level_n_edges_by_max_tokens(
151
  break
152
 
153
  if loss_strategy == "both":
154
- er_tuples = [([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) for edge in candidate_edges]
 
 
 
155
  candidate_edges = _sort_tuples(er_tuples, edge_sampling)
156
  elif loss_strategy == "only_edge":
157
  candidate_edges = _sort_edges(candidate_edges, edge_sampling)
@@ -196,15 +208,22 @@ def _sort_tuples(er_tuples: list, edge_sampling: str) -> list:
196
  if edge_sampling == "random":
197
  er_tuples = random.sample(er_tuples, len(er_tuples))
198
  elif edge_sampling == "min_loss":
199
- er_tuples = sorted(er_tuples, key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"])
 
 
 
200
  elif edge_sampling == "max_loss":
201
- er_tuples = sorted(er_tuples, key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
202
- reverse=True)
 
 
 
203
  else:
204
  raise ValueError(f"Invalid edge sampling: {edge_sampling}")
205
  edges = [edge for _, edge in er_tuples]
206
  return edges
207
 
 
208
  def _sort_edges(edges: list, edge_sampling: str) -> list:
209
  """
210
  Sort edges with edge sampling strategy
@@ -223,11 +242,12 @@ def _sort_edges(edges: list, edge_sampling: str) -> list:
223
  raise ValueError(f"Invalid edge sampling: {edge_sampling}")
224
  return edges
225
 
226
- async def get_batches_with_strategy( # pylint: disable=too-many-branches
 
227
  nodes: list,
228
  edges: list,
229
  graph_storage: NetworkXStorage,
230
- traverse_strategy: TraverseStrategy
231
  ):
232
  expand_method = traverse_strategy.expand_method
233
  if expand_method == "max_width":
@@ -256,7 +276,10 @@ async def get_batches_with_strategy( # pylint: disable=too-many-branches
256
  node_dict[node_name] = i
257
 
258
  if traverse_strategy.loss_strategy == "both":
259
- er_tuples = [([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) for edge in edges]
 
 
 
260
  edges = _sort_tuples(er_tuples, edge_sampling)
261
  elif traverse_strategy.loss_strategy == "only_edge":
262
  edges = _sort_edges(edges, edge_sampling)
@@ -279,21 +302,36 @@ async def get_batches_with_strategy( # pylint: disable=too-many-branches
279
  src_id = edge[0]
280
  tgt_id = edge[1]
281
 
282
- _process_nodes.extend([await get_cached_node_info(src_id),
283
- await get_cached_node_info(tgt_id)])
 
284
  _process_edges.append(edge)
285
 
286
  if expand_method == "max_width":
287
  level_n_edges = _get_level_n_edges_by_max_width(
288
- edge_adj_list, node_dict, edges, nodes, edge, max_depth,
289
- traverse_strategy.bidirectional, traverse_strategy.max_extra_edges,
290
- edge_sampling, traverse_strategy.loss_strategy
 
 
 
 
 
 
 
291
  )
292
  else:
293
  level_n_edges = _get_level_n_edges_by_max_tokens(
294
- edge_adj_list, node_dict, edges, nodes, edge, max_depth,
295
- traverse_strategy.bidirectional, traverse_strategy.max_tokens,
296
- edge_sampling, traverse_strategy.loss_strategy
 
 
 
 
 
 
 
297
  )
298
 
299
  for _edge in level_n_edges:
@@ -302,8 +340,12 @@ async def get_batches_with_strategy( # pylint: disable=too-many-branches
302
  _process_edges.append(_edge)
303
 
304
  # 去重
305
- _process_nodes = list({node['node_id']: node for node in _process_nodes}.values())
306
- _process_edges = list({(edge[0], edge[1]): edge for edge in _process_edges}.values())
 
 
 
 
307
 
308
  processing_batches.append((_process_nodes, _process_edges))
309
 
@@ -312,15 +354,21 @@ async def get_batches_with_strategy( # pylint: disable=too-many-branches
312
  # isolate nodes
313
  isolated_node_strategy = traverse_strategy.isolated_node_strategy
314
  if isolated_node_strategy == "add":
315
- processing_batches = await _add_isolated_nodes(nodes, processing_batches, graph_storage)
316
- logger.info("Processing batches after adding isolated nodes: %d", len(processing_batches))
 
 
 
 
 
317
 
318
  return processing_batches
319
 
 
320
  async def _add_isolated_nodes(
321
- nodes: list,
322
- processing_batches: list,
323
- graph_storage: NetworkXStorage,
324
  ) -> list:
325
  visited_nodes = set()
326
  for _process_nodes, _process_edges in processing_batches:
 
1
  import random
2
  from collections import defaultdict
3
+
4
  from tqdm.asyncio import tqdm as tqdm_async
 
5
 
6
  from graphgen.models import NetworkXStorage, TraverseStrategy
7
+ from graphgen.utils import logger
8
+
9
 
10
  async def _get_node_info(
11
  node_id: str,
12
  graph_storage: NetworkXStorage,
13
+ ) -> dict:
14
  """
15
  Get node info
16
 
 
19
  :return: node info
20
  """
21
  node_data = await graph_storage.get_node(node_id)
22
+ return {"node_id": node_id, **node_data}
 
 
 
23
 
24
 
25
  def _get_level_n_edges_by_max_width(
 
32
  bidirectional: bool,
33
  max_extra_edges: int,
34
  edge_sampling: str,
35
+ loss_strategy: str = "only_edge",
36
  ) -> list:
37
  """
38
  Get level n edges for an edge.
 
70
 
71
  if len(candidate_edges) >= max_extra_edges:
72
  if loss_strategy == "both":
73
+ er_tuples = [
74
+ ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
75
+ for edge in candidate_edges
76
+ ]
77
+ candidate_edges = _sort_tuples(er_tuples, edge_sampling)[
78
+ :max_extra_edges
79
+ ]
80
  elif loss_strategy == "only_edge":
81
+ candidate_edges = _sort_edges(candidate_edges, edge_sampling)[
82
+ :max_extra_edges
83
+ ]
84
  else:
85
  raise ValueError(f"Invalid loss strategy: {loss_strategy}")
86
 
 
107
 
108
 
109
  def _get_level_n_edges_by_max_tokens(
110
+ edge_adj_list: dict,
111
+ node_dict: dict,
112
+ edges: list,
113
+ nodes: list,
114
+ src_edge: tuple,
115
+ max_depth: int,
116
+ bidirectional: bool,
117
+ max_tokens: int,
118
+ edge_sampling: str,
119
+ loss_strategy: str = "only_edge",
120
  ) -> list:
121
  """
122
  Get level n edges for an edge.
 
135
  """
136
  src_id, tgt_id, src_edge_data = src_edge
137
 
138
+ max_tokens -= (
139
+ src_edge_data["length"]
140
+ + nodes[node_dict[src_id]][1]["length"]
141
+ + nodes[node_dict[tgt_id]][1]["length"]
142
+ )
143
 
144
  level_n_edges = []
145
 
 
160
  break
161
 
162
  if loss_strategy == "both":
163
+ er_tuples = [
164
+ ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
165
+ for edge in candidate_edges
166
+ ]
167
  candidate_edges = _sort_tuples(er_tuples, edge_sampling)
168
  elif loss_strategy == "only_edge":
169
  candidate_edges = _sort_edges(candidate_edges, edge_sampling)
 
208
  if edge_sampling == "random":
209
  er_tuples = random.sample(er_tuples, len(er_tuples))
210
  elif edge_sampling == "min_loss":
211
+ er_tuples = sorted(
212
+ er_tuples,
213
+ key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
214
+ )
215
  elif edge_sampling == "max_loss":
216
+ er_tuples = sorted(
217
+ er_tuples,
218
+ key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
219
+ reverse=True,
220
+ )
221
  else:
222
  raise ValueError(f"Invalid edge sampling: {edge_sampling}")
223
  edges = [edge for _, edge in er_tuples]
224
  return edges
225
 
226
+
227
  def _sort_edges(edges: list, edge_sampling: str) -> list:
228
  """
229
  Sort edges with edge sampling strategy
 
242
  raise ValueError(f"Invalid edge sampling: {edge_sampling}")
243
  return edges
244
 
245
+
246
+ async def get_batches_with_strategy( # pylint: disable=too-many-branches
247
  nodes: list,
248
  edges: list,
249
  graph_storage: NetworkXStorage,
250
+ traverse_strategy: TraverseStrategy,
251
  ):
252
  expand_method = traverse_strategy.expand_method
253
  if expand_method == "max_width":
 
276
  node_dict[node_name] = i
277
 
278
  if traverse_strategy.loss_strategy == "both":
279
+ er_tuples = [
280
+ ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
281
+ for edge in edges
282
+ ]
283
  edges = _sort_tuples(er_tuples, edge_sampling)
284
  elif traverse_strategy.loss_strategy == "only_edge":
285
  edges = _sort_edges(edges, edge_sampling)
 
302
  src_id = edge[0]
303
  tgt_id = edge[1]
304
 
305
+ _process_nodes.extend(
306
+ [await get_cached_node_info(src_id), await get_cached_node_info(tgt_id)]
307
+ )
308
  _process_edges.append(edge)
309
 
310
  if expand_method == "max_width":
311
  level_n_edges = _get_level_n_edges_by_max_width(
312
+ edge_adj_list,
313
+ node_dict,
314
+ edges,
315
+ nodes,
316
+ edge,
317
+ max_depth,
318
+ traverse_strategy.bidirectional,
319
+ traverse_strategy.max_extra_edges,
320
+ edge_sampling,
321
+ traverse_strategy.loss_strategy,
322
  )
323
  else:
324
  level_n_edges = _get_level_n_edges_by_max_tokens(
325
+ edge_adj_list,
326
+ node_dict,
327
+ edges,
328
+ nodes,
329
+ edge,
330
+ max_depth,
331
+ traverse_strategy.bidirectional,
332
+ traverse_strategy.max_tokens,
333
+ edge_sampling,
334
+ traverse_strategy.loss_strategy,
335
  )
336
 
337
  for _edge in level_n_edges:
 
340
  _process_edges.append(_edge)
341
 
342
  # 去重
343
+ _process_nodes = list(
344
+ {node["node_id"]: node for node in _process_nodes}.values()
345
+ )
346
+ _process_edges = list(
347
+ {(edge[0], edge[1]): edge for edge in _process_edges}.values()
348
+ )
349
 
350
  processing_batches.append((_process_nodes, _process_edges))
351
 
 
354
  # isolate nodes
355
  isolated_node_strategy = traverse_strategy.isolated_node_strategy
356
  if isolated_node_strategy == "add":
357
+ processing_batches = await _add_isolated_nodes(
358
+ nodes, processing_batches, graph_storage
359
+ )
360
+ logger.info(
361
+ "Processing batches after adding isolated nodes: %d",
362
+ len(processing_batches),
363
+ )
364
 
365
  return processing_batches
366
 
367
+
368
  async def _add_isolated_nodes(
369
+ nodes: list,
370
+ processing_batches: list,
371
+ graph_storage: NetworkXStorage,
372
  ) -> list:
373
  visited_nodes = set()
374
  for _process_nodes, _process_edges in processing_batches:
graphgen/operators/preprocess/__init__.py ADDED
File without changes
graphgen/operators/{resolute_coreference.py → preprocess/resolute_coreference.py} RENAMED
@@ -1,12 +1,13 @@
1
  from typing import List
2
- from graphgen.models import Chunk
3
- from graphgen.models import OpenAIModel
4
- from graphgen.templates import COREFERENCE_RESOLUTION_TEMPLATE
5
  from graphgen.utils import detect_main_language
6
 
 
7
  async def resolute_coreference(
8
- llm_client: OpenAIModel,
9
- chunks: List[Chunk]) -> List[Chunk]:
10
  """
11
  Resolute conference
12
 
@@ -23,9 +24,8 @@ async def resolute_coreference(
23
  for _, chunk in enumerate(chunks[1:]):
24
  language = detect_main_language(chunk.content)
25
  result = await llm_client.generate_answer(
26
- COREFERENCE_RESOLUTION_TEMPLATE[language].format(
27
- reference = results[0].content,
28
- input_sentence = chunk.content
29
  )
30
  )
31
  results.append(Chunk(id=chunk.id, content=result))
 
1
  from typing import List
2
+
3
+ from graphgen.models import Chunk, OpenAIModel
4
+ from graphgen.templates import COREFERENCE_RESOLUTION_PROMPT
5
  from graphgen.utils import detect_main_language
6
 
7
+
8
  async def resolute_coreference(
9
+ llm_client: OpenAIModel, chunks: List[Chunk]
10
+ ) -> List[Chunk]:
11
  """
12
  Resolute conference
13
 
 
24
  for _, chunk in enumerate(chunks[1:]):
25
  language = detect_main_language(chunk.content)
26
  result = await llm_client.generate_answer(
27
+ COREFERENCE_RESOLUTION_PROMPT[language].format(
28
+ reference=results[0].content, input_sentence=chunk.content
 
29
  )
30
  )
31
  results.append(Chunk(id=chunk.id, content=result))
graphgen/operators/search/__init__.py ADDED
File without changes
graphgen/operators/search/db/__init__.py ADDED
File without changes
graphgen/operators/search/db/search_uniprot.py ADDED
File without changes
graphgen/operators/search/kg/__init__.py ADDED
File without changes
graphgen/operators/search/kg/search_wikipedia.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm.asyncio import tqdm_asyncio as tqdm_async
2
+
3
+ from graphgen.models import WikiSearch
4
+ from graphgen.utils import logger
5
+
6
+
7
+ async def _process_single_entity(
8
+ entity_name: str,
9
+ wiki_search_client: WikiSearch,
10
+ ) -> str | None:
11
+ """
12
+ Process single entity by searching Wikipedia
13
+ :param entity_name
14
+ :param wiki_search_client
15
+ :return: summary of the entity or None if not found
16
+ """
17
+ search_results = await wiki_search_client.search(entity_name)
18
+ if not search_results:
19
+ return None
20
+
21
+ summary = None
22
+ try:
23
+ summary = await wiki_search_client.summary(search_results[-1])
24
+ logger.info(
25
+ "Entity %s search result: %s summary: %s",
26
+ entity_name,
27
+ str(search_results),
28
+ summary,
29
+ )
30
+ except Exception as e: # pylint: disable=broad-except
31
+ logger.error("Error processing entity %s: %s", entity_name, str(e))
32
+
33
+ return summary
34
+
35
+
36
+ async def search_wikipedia(
37
+ wiki_search_client: WikiSearch,
38
+ entities: set[str],
39
+ ) -> dict:
40
+ """
41
+ Search wikipedia for entities
42
+
43
+ :param wiki_search_client: wiki search client
44
+ :param entities: list of entities to search
45
+ :return: nodes with search results
46
+ """
47
+ wiki_data = {}
48
+
49
+ async for entity in tqdm_async(
50
+ entities, desc="Searching Wikipedia", total=len(entities)
51
+ ):
52
+ try:
53
+ summary = await _process_single_entity(entity, wiki_search_client)
54
+ if summary:
55
+ wiki_data[entity] = summary
56
+ except Exception as e: # pylint: disable=broad-except
57
+ logger.error("Error processing entity %s: %s", entity, str(e))
58
+ return wiki_data
graphgen/operators/search/search_all.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ To use Google Web Search API,
3
+ follow the instructions [here](https://developers.google.com/custom-search/v1/overview)
4
+ to get your Google search api key.
5
+
6
+ To use Bing Web Search API,
7
+ follow the instructions [here](https://www.microsoft.com/en-us/bing/apis/bing-web-search-api)
8
+ and obtain your Bing subscription key.
9
+ """
10
+
11
+ import os
12
+
13
+ from graphgen.utils import logger
14
+
15
+
16
+ async def search_all(
17
+ search_types: dict, search_entities: set[str]
18
+ ) -> dict[str, dict[str, str]]:
19
+ """
20
+ :param search_types
21
+ :param search_entities: list of entities to search
22
+ :return: nodes with search results
23
+ """
24
+
25
+ results = {}
26
+
27
+ for search_type in search_types:
28
+ if search_type == "wikipedia":
29
+ from graphgen.models import WikiSearch
30
+ from graphgen.operators.search.kg.search_wikipedia import search_wikipedia
31
+
32
+ wiki_search_client = WikiSearch()
33
+
34
+ wiki_results = await search_wikipedia(wiki_search_client, search_entities)
35
+ for entity_name, description in wiki_results.items():
36
+ if description:
37
+ results[entity_name] = {"wikipedia": description}
38
+ elif search_type == "google":
39
+ from graphgen.models import GoogleSearch
40
+ from graphgen.operators.search.web.search_google import search_google
41
+
42
+ google_search_client = GoogleSearch(
43
+ subscription_key=os.environ["GOOGLE_SEARCH_API_KEY"],
44
+ cx=os.environ["GOOGLE_SEARCH_CX"],
45
+ )
46
+
47
+ google_results = await search_google(google_search_client, search_entities)
48
+ for entity_name, description in google_results.items():
49
+ if description:
50
+ results[entity_name] = results.get(entity_name, {})
51
+ results[entity_name]["google"] = description
52
+ elif search_type == "bing":
53
+ from graphgen.models import BingSearch
54
+ from graphgen.operators.search.web.search_bing import search_bing
55
+
56
+ bing_search_client = BingSearch(
57
+ subscription_key=os.environ["BING_SEARCH_API_KEY"]
58
+ )
59
+
60
+ bing_results = await search_bing(bing_search_client, search_entities)
61
+ for entity_name, description in bing_results.items():
62
+ if description:
63
+ results[entity_name] = results.get(entity_name, {})
64
+ results[entity_name]["bing"] = description
65
+ elif search_type == "uniprot":
66
+ # from graphgen.models import UniProtSearch
67
+ # from graphgen.operators.search.db.search_uniprot import search_uniprot
68
+ #
69
+ # uniprot_search_client = UniProtSearch()
70
+ #
71
+ # uniprot_results = await search_uniprot(
72
+ # uniprot_search_client, search_entities
73
+ # )
74
+ raise NotImplementedError(
75
+ "Processing of UniProt search results is not implemented yet."
76
+ )
77
+
78
+ else:
79
+ logger.error("Search type %s is not supported yet.", search_type)
80
+ continue
81
+
82
+ return results
graphgen/operators/search/web/__init__.py ADDED
File without changes
graphgen/operators/search/web/search_bing.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import trafilatura
2
+ from tqdm.asyncio import tqdm_asyncio as tqdm_async
3
+
4
+ from graphgen.models import BingSearch
5
+ from graphgen.utils import logger
6
+
7
+
8
+ async def _process_single_entity(
9
+ entity_name: str, bing_search_client: BingSearch
10
+ ) -> str | None:
11
+ """
12
+ Process single entity by searching Bing.
13
+ :param entity_name: The name of the entity to search.
14
+ :param bing_search_client: The Bing search client.
15
+ :return: Summary of the entity or None if not found.
16
+ """
17
+ search_results = bing_search_client.search(entity_name)
18
+ if not search_results:
19
+ return None
20
+
21
+ # Get more details from the first search result
22
+ first_result = search_results[0]
23
+ content = trafilatura.fetch_url(first_result["url"])
24
+ summary = trafilatura.extract(content, include_comments=False, include_links=False)
25
+ summary = summary.strip()
26
+ logger.info(
27
+ "Entity %s search result: %s",
28
+ entity_name,
29
+ summary,
30
+ )
31
+ return summary
32
+
33
+
34
+ async def search_bing(
35
+ bing_search_client: BingSearch,
36
+ entities: set[str],
37
+ ) -> dict[str, str]:
38
+ """
39
+ Search with Bing and return the contexts.
40
+ :return:
41
+ """
42
+ bing_data = {}
43
+
44
+ async for entity in tqdm_async(
45
+ entities, desc="Searching Bing", total=len(entities)
46
+ ):
47
+ try:
48
+ summary = await _process_single_entity(entity, bing_search_client)
49
+ if summary:
50
+ bing_data[entity] = summary
51
+ except Exception as e: # pylint: disable=broad-except
52
+ logger.error("Error processing entity %s: %s", entity, str(e))
53
+ return bing_data
graphgen/operators/search/web/search_google.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import trafilatura
2
+ from tqdm.asyncio import tqdm_asyncio as tqdm_async
3
+
4
+ from graphgen.models import GoogleSearch
5
+ from graphgen.utils import logger
6
+
7
+
8
+ async def _process_single_entity(
9
+ entity_name: str, google_search_client: GoogleSearch
10
+ ) -> str | None:
11
+ search_results = google_search_client.search(entity_name)
12
+ if not search_results:
13
+ return None
14
+
15
+ # Get more details from the first search result
16
+ first_result = search_results[0]
17
+ content = trafilatura.fetch_url(first_result["link"])
18
+ summary = trafilatura.extract(content, include_comments=False, include_links=False)
19
+ summary = summary.strip()
20
+ logger.info(
21
+ "Entity %s search result: %s",
22
+ entity_name,
23
+ summary,
24
+ )
25
+ return summary
26
+
27
+
28
+ async def search_google(
29
+ google_search_client: GoogleSearch,
30
+ entities: set[str],
31
+ ) -> dict:
32
+ """
33
+ Search with Google and return the contexts.
34
+ :param google_search_client: Google search client
35
+ :param entities: list of entities to search
36
+ :return:
37
+ """
38
+ google_data = {}
39
+
40
+ async for entity in tqdm_async(
41
+ entities, desc="Searching Google", total=len(entities)
42
+ ):
43
+ try:
44
+ summary = await _process_single_entity(entity, google_search_client)
45
+ if summary:
46
+ google_data[entity] = summary
47
+ except Exception as e: # pylint: disable=broad-except
48
+ logger.error("Error processing entity %s: %s", entity, str(e))
49
+ return google_data
graphgen/operators/search_wikipedia.py DELETED
@@ -1,71 +0,0 @@
1
- import asyncio
2
- from graphgen.models import WikiSearch, OpenAIModel
3
- from graphgen.models.storage.base_storage import BaseGraphStorage
4
- from graphgen.templates import SEARCH_JUDGEMENT_PROMPT
5
- from graphgen.utils import logger
6
-
7
-
8
- async def _process_single_entity(entity_name: str,
9
- description: str,
10
- llm_client: OpenAIModel,
11
- wiki_search_client: WikiSearch) -> tuple[str, None] | tuple[str, str]:
12
- """
13
- Process single entity
14
-
15
- """
16
- search_results = await wiki_search_client.search(entity_name)
17
- if not search_results:
18
- return entity_name, None
19
- examples = "\n".join(SEARCH_JUDGEMENT_PROMPT["EXAMPLES"])
20
- search_results.append("None of the above")
21
-
22
- search_results_str = "\n".join([f"{i + 1}. {sr}" for i, sr in enumerate(search_results)])
23
- prompt = SEARCH_JUDGEMENT_PROMPT["TEMPLATE"].format(
24
- examples=examples,
25
- entity_name=entity_name,
26
- description=description,
27
- search_results=search_results_str,
28
- )
29
- response = await llm_client.generate_answer(prompt)
30
-
31
- try:
32
- response = response.strip()
33
- response = int(response)
34
- if response < 1 or response >= len(search_results):
35
- response = None
36
- else:
37
- response = await wiki_search_client.summary(search_results[response - 1])
38
- except ValueError:
39
- response = None
40
-
41
- logger.info("Entity %s search result: %s response: %s", entity_name, str(search_results), response)
42
-
43
- return entity_name, response
44
-
45
- async def search_wikipedia(llm_client: OpenAIModel,
46
- wiki_search_client: WikiSearch,
47
- knowledge_graph_instance: BaseGraphStorage,) -> dict:
48
- """
49
- Search wikipedia for entities
50
-
51
- :param llm_client: LLM model
52
- :param wiki_search_client: wiki search client
53
- :param knowledge_graph_instance: knowledge graph instance
54
- :return: nodes with search results
55
- """
56
-
57
-
58
- nodes = await knowledge_graph_instance.get_all_nodes()
59
- nodes = list(nodes)
60
- wiki_data = {}
61
-
62
- tasks = [
63
- _process_single_entity(node[0].strip('"'), node[1]["description"], llm_client, wiki_search_client)
64
- for node in nodes
65
- ]
66
-
67
- for task in asyncio.as_completed(tasks):
68
- result = await task
69
- wiki_data[result[0]] = result[1]
70
-
71
- return wiki_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/operators/traverse_graph.py CHANGED
@@ -1,49 +1,67 @@
1
  import asyncio
2
- import gradio as gr
3
 
 
4
  from tqdm.asyncio import tqdm as tqdm_async
5
 
6
- from graphgen.models import OpenAIModel, NetworkXStorage, TraverseStrategy, Tokenizer, JsonKVStorage
7
- from graphgen.templates import ANSWER_REPHRASING_PROMPT, QUESTION_GENERATION_PROMPT, MULTI_HOP_GENERATION_PROMPT
8
- from graphgen.utils import detect_main_language, compute_content_hash, logger
9
- from graphgen.operators.split_graph import get_batches_with_strategy
10
-
11
-
12
- async def _pre_tokenize(graph_storage: NetworkXStorage,
13
- tokenizer: Tokenizer,
14
- edges: list,
15
- nodes: list) -> tuple:
 
 
 
 
 
 
 
 
 
16
 
17
  sem = asyncio.Semaphore(1000)
 
18
  async def handle_edge(edge: tuple) -> tuple:
19
  async with sem:
20
- if 'length' not in edge[2]:
21
- edge[2]['length'] = len(
22
- await asyncio.get_event_loop().run_in_executor(None,
23
- tokenizer.encode_string,
24
- edge[2]['description']))
 
25
  return edge
26
 
27
  async def handle_node(node: dict) -> dict:
28
  async with sem:
29
- if 'length' not in node[1]:
30
- node[1]['length'] = len(
31
- await asyncio.get_event_loop().run_in_executor(None,
32
- tokenizer.encode_string,
33
- node[1]['description']))
 
34
  return node
35
 
36
  new_edges = []
37
  new_nodes = []
38
 
39
- for result in tqdm_async(asyncio.as_completed([handle_edge(edge) for edge in edges]),
40
- total=len(edges), desc="Pre-tokenizing edges"):
 
 
 
41
  new_edge = await result
42
  await graph_storage.update_edge(new_edge[0], new_edge[1], new_edge[2])
43
  new_edges.append(new_edge)
44
 
45
- for result in tqdm_async(asyncio.as_completed([handle_node(node) for node in nodes]),
46
- total=len(nodes), desc="Pre-tokenizing nodes"):
 
 
 
47
  new_node = await result
48
  await graph_storage.update_node(new_node[0], new_node[1])
49
  new_nodes.append(new_node)
@@ -51,60 +69,75 @@ async def _pre_tokenize(graph_storage: NetworkXStorage,
51
  await graph_storage.index_done_callback()
52
  return new_edges, new_nodes
53
 
54
- async def _construct_rephrasing_prompt(_process_nodes: list,
55
- _process_edges: list,
56
- text_chunks_storage: JsonKVStorage,
57
- add_context: bool = False
58
- ) -> str:
 
 
59
  entities = [
60
- f"{_process_node['node_id']}: {_process_node['description']}" for _process_node in _process_nodes
 
61
  ]
62
  relations = [
63
  f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
64
  for _process_edge in _process_edges
65
  ]
66
 
67
- entities_str = "\n".join([f"{index + 1}. {entity}" for index, entity in enumerate(entities)])
68
- relations_str = "\n".join([f"{index + 1}. {relation}" for index, relation in enumerate(relations)])
69
- language = "Chinese" if detect_main_language(entities_str + relations_str) == "zh" else "English"
 
 
 
 
 
 
 
 
70
 
71
  if add_context:
72
- original_ids = ([node['source_id'].split('<SEP>')[0] for node in _process_nodes] +
73
- [edge[2]['source_id'].split('<SEP>')[0] for edge in _process_edges])
 
74
 
75
  original_ids = list(set(original_ids))
76
  original_text = await text_chunks_storage.get_by_ids(original_ids)
77
- original_text = "\n".join([f"{index + 1}. {text['content']}" for index, text in enumerate(original_text)])
 
 
 
 
 
78
 
79
- prompt = ANSWER_REPHRASING_PROMPT[language]['CONTEXT_TEMPLATE'].format(
80
  language=language,
81
  original_text=original_text,
82
  entities=entities_str,
83
- relationships=relations_str
84
  )
85
  return prompt
86
 
87
- prompt = ANSWER_REPHRASING_PROMPT[language]['TEMPLATE'].format(
88
- language=language,
89
- entities=entities_str,
90
- relationships=relations_str
91
  )
92
  return prompt
93
 
94
- def get_loss_tercile(losses: list) -> (float, float):
95
- losses = sorted(losses)
96
- q1_index = int(len(losses) * (1 / 3))
97
- q2_index = int(len(losses) * (2 / 3))
98
-
99
- return losses[q1_index], losses[q2_index]
100
 
101
  def get_average_loss(batch: tuple, loss_strategy: str) -> float:
102
- if loss_strategy == "only_edge":
103
- return sum(edge[2]['loss'] for edge in batch[1]) / len(batch[1])
104
- if loss_strategy == "both":
105
- return sum(edge[2]['loss'] for edge in batch[1]) + sum(node['loss'] for node in batch[0]) / \
106
- (len(batch[0]) + len(batch[1]))
107
- raise ValueError("Invalid loss strategy")
 
 
 
 
 
 
108
 
109
  def _post_process_synthetic_data(data):
110
  block = data.split("\n\n")
@@ -113,26 +146,18 @@ def _post_process_synthetic_data(data):
113
  if "Question:" in line and "Answer:" in line:
114
  question = line.split("Question:")[1].split("Answer:")[0].strip()
115
  answer = line.split("Answer:")[1].strip()
116
- qas.append({
117
- "question": question,
118
- "answer": answer
119
- })
120
  elif "问题:" in line and "答案:" in line:
121
  question = line.split("问题:")[1].split("答案:")[0].strip()
122
  answer = line.split("答案:")[1].strip()
123
- qas.append({
124
- "question": question,
125
- "answer": answer
126
- })
127
  elif "问题:" in line and "回答:" in line:
128
  question = line.split("问题:")[1].split("回答:")[0].strip()
129
  answer = line.split("回答:")[1].strip()
130
- qas.append({
131
- "question": question,
132
- "answer": answer
133
- })
134
  return qas
135
 
 
136
  async def traverse_graph_by_edge(
137
  llm_client: OpenAIModel,
138
  tokenizer: Tokenizer,
@@ -140,7 +165,7 @@ async def traverse_graph_by_edge(
140
  traverse_strategy: TraverseStrategy,
141
  text_chunks_storage: JsonKVStorage,
142
  progress_bar: gr.Progress = None,
143
- max_concurrent: int = 1000
144
  ) -> dict:
145
  """
146
  Traverse the graph
@@ -158,28 +183,24 @@ async def traverse_graph_by_edge(
158
  semaphore = asyncio.Semaphore(max_concurrent)
159
 
160
  async def _process_nodes_and_edges(
161
- _process_nodes: list,
162
- _process_edges: list,
163
  ) -> str:
164
  prompt = await _construct_rephrasing_prompt(
165
- _process_nodes,
166
- _process_edges,
167
- text_chunks_storage,
168
- add_context = False
169
  )
170
  context = await llm_client.generate_answer(prompt)
171
 
172
  # post-process the context
173
  if context.startswith("Rephrased Text:"):
174
- context = context[len("Rephrased Text:"):].strip()
175
  elif context.startswith("重述文本:"):
176
- context = context[len("重述文本:"):].strip()
177
 
178
  return context
179
 
180
  async def _process_single_batch(
181
- _process_batch: tuple,
182
- question_type: str = "single"
183
  ) -> dict:
184
  async with semaphore:
185
  context = await _process_nodes_and_edges(
@@ -188,21 +209,26 @@ async def traverse_graph_by_edge(
188
  )
189
 
190
  language = "Chinese" if detect_main_language(context) == "zh" else "English"
191
- pre_length = sum(node['length'] for node in _process_batch[0]) \
192
- + sum(edge[2]['length'] for edge in _process_batch[1])
 
193
 
194
  if question_type == "single":
195
  question = await llm_client.generate_answer(
196
- QUESTION_GENERATION_PROMPT[language]['SINGLE_TEMPLATE'].format(
197
  answer=context
198
  )
199
  )
200
  if question.startswith("Question:"):
201
- question = question[len("Question:"):].strip()
202
  elif question.startswith("问题:"):
203
- question = question[len("问题:"):].strip()
204
 
205
- logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
 
 
 
 
206
  logger.info("Pre-length: %s", pre_length)
207
  logger.info("Question: %s", question)
208
  logger.info("Answer: %s", context)
@@ -211,12 +237,14 @@ async def traverse_graph_by_edge(
211
  compute_content_hash(context): {
212
  "question": question,
213
  "answer": context,
214
- "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy)
 
 
215
  }
216
  }
217
 
218
  content = await llm_client.generate_answer(
219
- QUESTION_GENERATION_PROMPT[language]['MULTI_TEMPLATE'].format(
220
  doc=context
221
  )
222
  )
@@ -224,19 +252,27 @@ async def traverse_graph_by_edge(
224
 
225
  if len(qas) == 0:
226
  print(content)
227
- logger.error("Error occurred while processing batch, question or answer is None")
 
 
228
  return {}
229
 
230
  final_results = {}
231
- logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
 
 
 
 
232
  logger.info("Pre-length: %s", pre_length)
233
  for qa in qas:
234
- logger.info("Question: %s", qa['question'])
235
- logger.info("Answer: %s", qa['answer'])
236
- final_results[compute_content_hash(qa['question'])] = {
237
- "question": qa['question'],
238
- "answer": qa['answer'],
239
- "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy)
 
 
240
  }
241
  return final_results
242
 
@@ -247,22 +283,25 @@ async def traverse_graph_by_edge(
247
  edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
248
 
249
  processing_batches = await get_batches_with_strategy(
250
- nodes,
251
- edges,
252
- graph_storage,
253
- traverse_strategy
254
  )
255
 
256
- for result in tqdm_async(asyncio.as_completed(
257
- [_process_single_batch(batch) for batch in processing_batches]
258
- ), total=len(processing_batches), desc="[4/4]Generating QAs"):
 
 
 
 
259
  try:
260
  if progress_bar is not None:
261
- progress_bar(len(results) / len(processing_batches), desc="[4/4]Generating QAs")
 
 
262
  results.update(await result)
263
  if progress_bar is not None and len(results) == len(processing_batches):
264
  progress_bar(1, desc="[4/4]Generating QAs")
265
- except Exception as e: # pylint: disable=broad-except
266
  logger.error("Error occurred while generating QA: %s", e)
267
 
268
  return results
@@ -275,7 +314,7 @@ async def traverse_graph_atomically(
275
  traverse_strategy: TraverseStrategy,
276
  text_chunks_storage: JsonKVStorage,
277
  progress_bar: gr.Progress = None,
278
- max_concurrent: int = 1000
279
  ) -> dict:
280
  """
281
  Traverse the graph atomicly
@@ -292,22 +331,21 @@ async def traverse_graph_atomically(
292
  assert traverse_strategy.qa_form == "atomic"
293
 
294
  semaphore = asyncio.Semaphore(max_concurrent)
295
- async def _generate_question(
296
- node_or_edge: tuple
297
- ):
298
  if len(node_or_edge) == 2:
299
- des = node_or_edge[0] + ": " + node_or_edge[1]['description']
300
- loss = node_or_edge[1]['loss']
301
  else:
302
- des = node_or_edge[2]['description']
303
- loss = node_or_edge[2]['loss']
304
 
305
  async with semaphore:
306
  try:
307
  language = "Chinese" if detect_main_language(des) == "zh" else "English"
308
 
309
  qa = await llm_client.generate_answer(
310
- QUESTION_GENERATION_PROMPT[language]['SINGLE_QA_TEMPLATE'].format(
311
  doc=des
312
  )
313
  )
@@ -321,8 +359,8 @@ async def traverse_graph_atomically(
321
  else:
322
  return {}
323
 
324
- question = question.strip("\"")
325
- answer = answer.strip("\"")
326
 
327
  logger.info("Question: %s", question)
328
  logger.info("Answer: %s", answer)
@@ -330,10 +368,10 @@ async def traverse_graph_atomically(
330
  compute_content_hash(question): {
331
  "question": question,
332
  "answer": answer,
333
- "loss": loss
334
  }
335
  }
336
- except Exception as e: # pylint: disable=broad-except
337
  logger.error("Error occurred while generating question: %s", e)
338
  return {}
339
 
@@ -345,24 +383,26 @@ async def traverse_graph_atomically(
345
 
346
  tasks = []
347
  for node in nodes:
348
- if "<SEP>" in node[1]['description']:
349
- description_list = node[1]['description'].split("<SEP>")
350
  for item in description_list:
351
- tasks.append((node[0], {"description": item, 'loss': node[1]['loss']}))
352
  else:
353
  tasks.append((node[0], node[1]))
354
  for edge in edges:
355
- if "<SEP>" in edge[2]['description']:
356
- description_list = edge[2]['description'].split("<SEP>")
357
  for item in description_list:
358
- tasks.append((edge[0], edge[1], {"description": item, 'loss': edge[2]['loss']}))
 
 
359
  else:
360
  tasks.append((edge[0], edge[1], edge[2]))
361
 
362
  for result in tqdm_async(
363
  asyncio.as_completed([_generate_question(task) for task in tasks]),
364
  total=len(tasks),
365
- desc="[4/4]Generating QAs"
366
  ):
367
  try:
368
  if progress_bar is not None:
@@ -370,10 +410,11 @@ async def traverse_graph_atomically(
370
  results.update(await result)
371
  if progress_bar is not None and len(results) == len(tasks):
372
  progress_bar(1, desc="[4/4]Generating QAs")
373
- except Exception as e: # pylint: disable=broad-except
374
  logger.error("Error occurred while generating QA: %s", e)
375
  return results
376
 
 
377
  async def traverse_graph_for_multi_hop(
378
  llm_client: OpenAIModel,
379
  tokenizer: Tokenizer,
@@ -381,7 +422,7 @@ async def traverse_graph_for_multi_hop(
381
  traverse_strategy: TraverseStrategy,
382
  text_chunks_storage: JsonKVStorage,
383
  progress_bar: gr.Progress = None,
384
- max_concurrent: int = 1000
385
  ) -> dict:
386
  """
387
  Traverse the graph for multi-hop
@@ -395,8 +436,6 @@ async def traverse_graph_for_multi_hop(
395
  :param max_concurrent
396
  :return: question and answer
397
  """
398
- assert traverse_strategy.qa_form == "multi_hop"
399
-
400
  semaphore = asyncio.Semaphore(max_concurrent)
401
 
402
  results = {}
@@ -406,24 +445,24 @@ async def traverse_graph_for_multi_hop(
406
  edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
407
 
408
  processing_batches = await get_batches_with_strategy(
409
- nodes,
410
- edges,
411
- graph_storage,
412
- traverse_strategy
413
  )
414
 
415
- async def _process_single_batch(
416
- _process_batch: tuple
417
- ) -> dict:
418
  async with semaphore:
419
  try:
420
- language = "Chinese" if detect_main_language(_process_batch[0][0]['description']) == "zh" else "English"
 
 
 
 
421
 
422
  _process_nodes = _process_batch[0]
423
  _process_edges = _process_batch[1]
424
 
425
  entities = [
426
- f"{_process_node['node_id']}: {_process_node['description']}" for _process_node in _process_nodes
 
427
  ]
428
 
429
  relations = [
@@ -431,12 +470,18 @@ async def traverse_graph_for_multi_hop(
431
  for _process_edge in _process_edges
432
  ]
433
 
434
- entities_str = "\n".join([f"{index + 1}. {entity}" for index, entity in enumerate(entities)])
435
- relations_str = "\n".join([f"{index + 1}. {relation}" for index, relation in enumerate(relations)])
 
 
 
 
 
 
 
436
 
437
  prompt = MULTI_HOP_GENERATION_PROMPT[language].format(
438
- entities=entities_str,
439
- relationships=relations_str
440
  )
441
 
442
  context = await llm_client.generate_answer(prompt)
@@ -451,8 +496,8 @@ async def traverse_graph_for_multi_hop(
451
  else:
452
  return {}
453
 
454
- question = question.strip("\"")
455
- answer = answer.strip("\"")
456
 
457
  logger.info("Question: %s", question)
458
  logger.info("Answer: %s", answer)
@@ -461,25 +506,31 @@ async def traverse_graph_for_multi_hop(
461
  compute_content_hash(question): {
462
  "question": question,
463
  "answer": answer,
464
- "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy),
 
 
465
  }
466
  }
467
 
468
- except Exception as e: # pylint: disable=broad-except
469
  logger.error("Error occurred while processing batch: %s", e)
470
  return {}
471
 
472
  async for result in tqdm_async(
473
- asyncio.as_completed([_process_single_batch(batch) for batch in processing_batches]),
 
 
474
  total=len(processing_batches),
475
- desc="[4/4]Generating QAs"
476
  ):
477
  try:
478
  if progress_bar is not None:
479
- progress_bar(len(results) / len(processing_batches), desc="[4/4]Generating QAs")
 
 
480
  results.update(await result)
481
  if progress_bar is not None and len(results) == len(processing_batches):
482
  progress_bar(1, desc="[4/4]Generating QAs")
483
- except Exception as e: # pylint: disable=broad-except
484
  logger.error("Error occurred while generating QA: %s", e)
485
  return results
 
1
  import asyncio
 
2
 
3
+ import gradio as gr
4
  from tqdm.asyncio import tqdm as tqdm_async
5
 
6
+ from graphgen.models import (
7
+ JsonKVStorage,
8
+ NetworkXStorage,
9
+ OpenAIModel,
10
+ Tokenizer,
11
+ TraverseStrategy,
12
+ )
13
+ from graphgen.operators.kg.split_kg import get_batches_with_strategy
14
+ from graphgen.templates import (
15
+ ANSWER_REPHRASING_PROMPT,
16
+ MULTI_HOP_GENERATION_PROMPT,
17
+ QUESTION_GENERATION_PROMPT,
18
+ )
19
+ from graphgen.utils import compute_content_hash, detect_main_language, logger
20
+
21
+
22
+ async def _pre_tokenize(
23
+ graph_storage: NetworkXStorage, tokenizer: Tokenizer, edges: list, nodes: list
24
+ ) -> tuple:
25
 
26
  sem = asyncio.Semaphore(1000)
27
+
28
  async def handle_edge(edge: tuple) -> tuple:
29
  async with sem:
30
+ if "length" not in edge[2]:
31
+ edge[2]["length"] = len(
32
+ await asyncio.get_event_loop().run_in_executor(
33
+ None, tokenizer.encode_string, edge[2]["description"]
34
+ )
35
+ )
36
  return edge
37
 
38
  async def handle_node(node: dict) -> dict:
39
  async with sem:
40
+ if "length" not in node[1]:
41
+ node[1]["length"] = len(
42
+ await asyncio.get_event_loop().run_in_executor(
43
+ None, tokenizer.encode_string, node[1]["description"]
44
+ )
45
+ )
46
  return node
47
 
48
  new_edges = []
49
  new_nodes = []
50
 
51
+ for result in tqdm_async(
52
+ asyncio.as_completed([handle_edge(edge) for edge in edges]),
53
+ total=len(edges),
54
+ desc="Pre-tokenizing edges",
55
+ ):
56
  new_edge = await result
57
  await graph_storage.update_edge(new_edge[0], new_edge[1], new_edge[2])
58
  new_edges.append(new_edge)
59
 
60
+ for result in tqdm_async(
61
+ asyncio.as_completed([handle_node(node) for node in nodes]),
62
+ total=len(nodes),
63
+ desc="Pre-tokenizing nodes",
64
+ ):
65
  new_node = await result
66
  await graph_storage.update_node(new_node[0], new_node[1])
67
  new_nodes.append(new_node)
 
69
  await graph_storage.index_done_callback()
70
  return new_edges, new_nodes
71
 
72
+
73
+ async def _construct_rephrasing_prompt(
74
+ _process_nodes: list,
75
+ _process_edges: list,
76
+ text_chunks_storage: JsonKVStorage,
77
+ add_context: bool = False,
78
+ ) -> str:
79
  entities = [
80
+ f"{_process_node['node_id']}: {_process_node['description']}"
81
+ for _process_node in _process_nodes
82
  ]
83
  relations = [
84
  f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
85
  for _process_edge in _process_edges
86
  ]
87
 
88
+ entities_str = "\n".join(
89
+ [f"{index + 1}. {entity}" for index, entity in enumerate(entities)]
90
+ )
91
+ relations_str = "\n".join(
92
+ [f"{index + 1}. {relation}" for index, relation in enumerate(relations)]
93
+ )
94
+ language = (
95
+ "Chinese"
96
+ if detect_main_language(entities_str + relations_str) == "zh"
97
+ else "English"
98
+ )
99
 
100
  if add_context:
101
+ original_ids = [
102
+ node["source_id"].split("<SEP>")[0] for node in _process_nodes
103
+ ] + [edge[2]["source_id"].split("<SEP>")[0] for edge in _process_edges]
104
 
105
  original_ids = list(set(original_ids))
106
  original_text = await text_chunks_storage.get_by_ids(original_ids)
107
+ original_text = "\n".join(
108
+ [
109
+ f"{index + 1}. {text['content']}"
110
+ for index, text in enumerate(original_text)
111
+ ]
112
+ )
113
 
114
+ prompt = ANSWER_REPHRASING_PROMPT[language]["CONTEXT_TEMPLATE"].format(
115
  language=language,
116
  original_text=original_text,
117
  entities=entities_str,
118
+ relationships=relations_str,
119
  )
120
  return prompt
121
 
122
+ prompt = ANSWER_REPHRASING_PROMPT[language]["TEMPLATE"].format(
123
+ language=language, entities=entities_str, relationships=relations_str
 
 
124
  )
125
  return prompt
126
 
 
 
 
 
 
 
127
 
128
  def get_average_loss(batch: tuple, loss_strategy: str) -> float:
129
+ try:
130
+ if loss_strategy == "only_edge":
131
+ return sum(edge[2]["loss"] for edge in batch[1]) / len(batch[1])
132
+ if loss_strategy == "both":
133
+ return sum(edge[2]["loss"] for edge in batch[1]) + sum(
134
+ node["loss"] for node in batch[0]
135
+ ) / (len(batch[0]) + len(batch[1]))
136
+ raise ValueError("Invalid loss strategy")
137
+ except Exception as e: # pylint: disable=broad-except
138
+ logger.error("Error calculating average loss: %s", e)
139
+ return -1.0
140
+
141
 
142
  def _post_process_synthetic_data(data):
143
  block = data.split("\n\n")
 
146
  if "Question:" in line and "Answer:" in line:
147
  question = line.split("Question:")[1].split("Answer:")[0].strip()
148
  answer = line.split("Answer:")[1].strip()
149
+ qas.append({"question": question, "answer": answer})
 
 
 
150
  elif "问题:" in line and "答案:" in line:
151
  question = line.split("问题:")[1].split("答案:")[0].strip()
152
  answer = line.split("答案:")[1].strip()
153
+ qas.append({"question": question, "answer": answer})
 
 
 
154
  elif "问题:" in line and "回答:" in line:
155
  question = line.split("问题:")[1].split("回答:")[0].strip()
156
  answer = line.split("回答:")[1].strip()
157
+ qas.append({"question": question, "answer": answer})
 
 
 
158
  return qas
159
 
160
+
161
  async def traverse_graph_by_edge(
162
  llm_client: OpenAIModel,
163
  tokenizer: Tokenizer,
 
165
  traverse_strategy: TraverseStrategy,
166
  text_chunks_storage: JsonKVStorage,
167
  progress_bar: gr.Progress = None,
168
+ max_concurrent: int = 1000,
169
  ) -> dict:
170
  """
171
  Traverse the graph
 
183
  semaphore = asyncio.Semaphore(max_concurrent)
184
 
185
  async def _process_nodes_and_edges(
186
+ _process_nodes: list,
187
+ _process_edges: list,
188
  ) -> str:
189
  prompt = await _construct_rephrasing_prompt(
190
+ _process_nodes, _process_edges, text_chunks_storage, add_context=False
 
 
 
191
  )
192
  context = await llm_client.generate_answer(prompt)
193
 
194
  # post-process the context
195
  if context.startswith("Rephrased Text:"):
196
+ context = context[len("Rephrased Text:") :].strip()
197
  elif context.startswith("重述文本:"):
198
+ context = context[len("重述文本:") :].strip()
199
 
200
  return context
201
 
202
  async def _process_single_batch(
203
+ _process_batch: tuple, question_type: str = "single"
 
204
  ) -> dict:
205
  async with semaphore:
206
  context = await _process_nodes_and_edges(
 
209
  )
210
 
211
  language = "Chinese" if detect_main_language(context) == "zh" else "English"
212
+ pre_length = sum(node["length"] for node in _process_batch[0]) + sum(
213
+ edge[2]["length"] for edge in _process_batch[1]
214
+ )
215
 
216
  if question_type == "single":
217
  question = await llm_client.generate_answer(
218
+ QUESTION_GENERATION_PROMPT[language]["SINGLE_TEMPLATE"].format(
219
  answer=context
220
  )
221
  )
222
  if question.startswith("Question:"):
223
+ question = question[len("Question:") :].strip()
224
  elif question.startswith("问题:"):
225
+ question = question[len("问题:") :].strip()
226
 
227
+ logger.info(
228
+ "%d nodes and %d edges processed",
229
+ len(_process_batch[0]),
230
+ len(_process_batch[1]),
231
+ )
232
  logger.info("Pre-length: %s", pre_length)
233
  logger.info("Question: %s", question)
234
  logger.info("Answer: %s", context)
 
237
  compute_content_hash(context): {
238
  "question": question,
239
  "answer": context,
240
+ "loss": get_average_loss(
241
+ _process_batch, traverse_strategy.loss_strategy
242
+ ),
243
  }
244
  }
245
 
246
  content = await llm_client.generate_answer(
247
+ QUESTION_GENERATION_PROMPT[language]["MULTI_TEMPLATE"].format(
248
  doc=context
249
  )
250
  )
 
252
 
253
  if len(qas) == 0:
254
  print(content)
255
+ logger.error(
256
+ "Error occurred while processing batch, question or answer is None"
257
+ )
258
  return {}
259
 
260
  final_results = {}
261
+ logger.info(
262
+ "%d nodes and %d edges processed",
263
+ len(_process_batch[0]),
264
+ len(_process_batch[1]),
265
+ )
266
  logger.info("Pre-length: %s", pre_length)
267
  for qa in qas:
268
+ logger.info("Question: %s", qa["question"])
269
+ logger.info("Answer: %s", qa["answer"])
270
+ final_results[compute_content_hash(qa["question"])] = {
271
+ "question": qa["question"],
272
+ "answer": qa["answer"],
273
+ "loss": get_average_loss(
274
+ _process_batch, traverse_strategy.loss_strategy
275
+ ),
276
  }
277
  return final_results
278
 
 
283
  edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
284
 
285
  processing_batches = await get_batches_with_strategy(
286
+ nodes, edges, graph_storage, traverse_strategy
 
 
 
287
  )
288
 
289
+ for result in tqdm_async(
290
+ asyncio.as_completed(
291
+ [_process_single_batch(batch) for batch in processing_batches]
292
+ ),
293
+ total=len(processing_batches),
294
+ desc="[4/4]Generating QAs",
295
+ ):
296
  try:
297
  if progress_bar is not None:
298
+ progress_bar(
299
+ len(results) / len(processing_batches), desc="[4/4]Generating QAs"
300
+ )
301
  results.update(await result)
302
  if progress_bar is not None and len(results) == len(processing_batches):
303
  progress_bar(1, desc="[4/4]Generating QAs")
304
+ except Exception as e: # pylint: disable=broad-except
305
  logger.error("Error occurred while generating QA: %s", e)
306
 
307
  return results
 
314
  traverse_strategy: TraverseStrategy,
315
  text_chunks_storage: JsonKVStorage,
316
  progress_bar: gr.Progress = None,
317
+ max_concurrent: int = 1000,
318
  ) -> dict:
319
  """
320
  Traverse the graph atomicly
 
331
  assert traverse_strategy.qa_form == "atomic"
332
 
333
  semaphore = asyncio.Semaphore(max_concurrent)
334
+
335
+ async def _generate_question(node_or_edge: tuple):
 
336
  if len(node_or_edge) == 2:
337
+ des = node_or_edge[0] + ": " + node_or_edge[1]["description"]
338
+ loss = node_or_edge[1]["loss"]
339
  else:
340
+ des = node_or_edge[2]["description"]
341
+ loss = node_or_edge[2]["loss"]
342
 
343
  async with semaphore:
344
  try:
345
  language = "Chinese" if detect_main_language(des) == "zh" else "English"
346
 
347
  qa = await llm_client.generate_answer(
348
+ QUESTION_GENERATION_PROMPT[language]["SINGLE_QA_TEMPLATE"].format(
349
  doc=des
350
  )
351
  )
 
359
  else:
360
  return {}
361
 
362
+ question = question.strip('"')
363
+ answer = answer.strip('"')
364
 
365
  logger.info("Question: %s", question)
366
  logger.info("Answer: %s", answer)
 
368
  compute_content_hash(question): {
369
  "question": question,
370
  "answer": answer,
371
+ "loss": loss,
372
  }
373
  }
374
+ except Exception as e: # pylint: disable=broad-except
375
  logger.error("Error occurred while generating question: %s", e)
376
  return {}
377
 
 
383
 
384
  tasks = []
385
  for node in nodes:
386
+ if "<SEP>" in node[1]["description"]:
387
+ description_list = node[1]["description"].split("<SEP>")
388
  for item in description_list:
389
+ tasks.append((node[0], {"description": item, "loss": node[1]["loss"]}))
390
  else:
391
  tasks.append((node[0], node[1]))
392
  for edge in edges:
393
+ if "<SEP>" in edge[2]["description"]:
394
+ description_list = edge[2]["description"].split("<SEP>")
395
  for item in description_list:
396
+ tasks.append(
397
+ (edge[0], edge[1], {"description": item, "loss": edge[2]["loss"]})
398
+ )
399
  else:
400
  tasks.append((edge[0], edge[1], edge[2]))
401
 
402
  for result in tqdm_async(
403
  asyncio.as_completed([_generate_question(task) for task in tasks]),
404
  total=len(tasks),
405
+ desc="[4/4]Generating QAs",
406
  ):
407
  try:
408
  if progress_bar is not None:
 
410
  results.update(await result)
411
  if progress_bar is not None and len(results) == len(tasks):
412
  progress_bar(1, desc="[4/4]Generating QAs")
413
+ except Exception as e: # pylint: disable=broad-except
414
  logger.error("Error occurred while generating QA: %s", e)
415
  return results
416
 
417
+
418
  async def traverse_graph_for_multi_hop(
419
  llm_client: OpenAIModel,
420
  tokenizer: Tokenizer,
 
422
  traverse_strategy: TraverseStrategy,
423
  text_chunks_storage: JsonKVStorage,
424
  progress_bar: gr.Progress = None,
425
+ max_concurrent: int = 1000,
426
  ) -> dict:
427
  """
428
  Traverse the graph for multi-hop
 
436
  :param max_concurrent
437
  :return: question and answer
438
  """
 
 
439
  semaphore = asyncio.Semaphore(max_concurrent)
440
 
441
  results = {}
 
445
  edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
446
 
447
  processing_batches = await get_batches_with_strategy(
448
+ nodes, edges, graph_storage, traverse_strategy
 
 
 
449
  )
450
 
451
+ async def _process_single_batch(_process_batch: tuple) -> dict:
 
 
452
  async with semaphore:
453
  try:
454
+ language = (
455
+ "Chinese"
456
+ if detect_main_language(_process_batch[0][0]["description"]) == "zh"
457
+ else "English"
458
+ )
459
 
460
  _process_nodes = _process_batch[0]
461
  _process_edges = _process_batch[1]
462
 
463
  entities = [
464
+ f"{_process_node['node_id']}: {_process_node['description']}"
465
+ for _process_node in _process_nodes
466
  ]
467
 
468
  relations = [
 
470
  for _process_edge in _process_edges
471
  ]
472
 
473
+ entities_str = "\n".join(
474
+ [f"{index + 1}. {entity}" for index, entity in enumerate(entities)]
475
+ )
476
+ relations_str = "\n".join(
477
+ [
478
+ f"{index + 1}. {relation}"
479
+ for index, relation in enumerate(relations)
480
+ ]
481
+ )
482
 
483
  prompt = MULTI_HOP_GENERATION_PROMPT[language].format(
484
+ entities=entities_str, relationships=relations_str
 
485
  )
486
 
487
  context = await llm_client.generate_answer(prompt)
 
496
  else:
497
  return {}
498
 
499
+ question = question.strip('"')
500
+ answer = answer.strip('"')
501
 
502
  logger.info("Question: %s", question)
503
  logger.info("Answer: %s", answer)
 
506
  compute_content_hash(question): {
507
  "question": question,
508
  "answer": answer,
509
+ "loss": get_average_loss(
510
+ _process_batch, traverse_strategy.loss_strategy
511
+ ),
512
  }
513
  }
514
 
515
+ except Exception as e: # pylint: disable=broad-except
516
  logger.error("Error occurred while processing batch: %s", e)
517
  return {}
518
 
519
  async for result in tqdm_async(
520
+ asyncio.as_completed(
521
+ [_process_single_batch(batch) for batch in processing_batches]
522
+ ),
523
  total=len(processing_batches),
524
+ desc="[4/4]Generating QAs",
525
  ):
526
  try:
527
  if progress_bar is not None:
528
+ progress_bar(
529
+ len(results) / len(processing_batches), desc="[4/4]Generating QAs"
530
+ )
531
  results.update(await result)
532
  if progress_bar is not None and len(results) == len(processing_batches):
533
  progress_bar(1, desc="[4/4]Generating QAs")
534
+ except Exception as e: # pylint: disable=broad-except
535
  logger.error("Error occurred while generating QA: %s", e)
536
  return results