File size: 6,174 Bytes
af7f5dd fa6baed af7f5dd cdf717a fa6baed 9590c46 af7f5dd 9590c46 af7f5dd 9590c46 af7f5dd 9590c46 af7f5dd 9590c46 af7f5dd 9590c46 af7f5dd 9590c46 af7f5dd 9590c46 af7f5dd ab4c4d1 af7f5dd 9590c46 af7f5dd 9590c46 af7f5dd 9590c46 af7f5dd 9590c46 af7f5dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import os
import json
import xml.etree.ElementTree as ET
from neo4j import GraphDatabase
# Constants
# Directory that main() joins with the GraphML input and JSON output file names.
WORKING_DIR = "./dickens"
# Number of records sent to Neo4j per query execution.
BATCH_SIZE_NODES = 500
BATCH_SIZE_EDGES = 100
# Neo4j connection credentials.
# Each value can be overridden through an environment variable of the same
# name so that real credentials never have to be written into this file; the
# fallbacks are the original placeholder values.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "your_password")
def _data_text(element, key, namespace, default=""):
    """Return the text of ``element``'s ``<data key="...">`` child.

    ``default`` is returned both when the child is missing and when it is
    present but empty (``.text is None``), so callers never receive ``None``
    unless they ask for it explicitly.
    """
    child = element.find(f"./data[@key='{key}']", namespace)
    if child is None or child.text is None:
        return default
    return child.text


def xml_to_json(xml_file):
    """Parse a GraphML file into a ``{"nodes": [...], "edges": [...]}`` dict.

    Args:
        xml_file: Path (or file object) accepted by ``ET.parse``.

    Returns:
        A dict with two lists:

        * ``nodes`` - dicts with ``id``, ``entity_type`` (key ``d1``),
          ``description`` (``d2``) and ``source_id`` (``d3``).
        * ``edges`` - dicts with ``source``, ``target``, ``weight`` (``d5``,
          as float), ``description`` (``d6``), ``keywords`` (``d7``) and
          ``source_id`` (``d8``).

        Missing or empty ``<data>`` elements yield ``""`` (``0.0`` for the
        weight). Surrounding double quotes are stripped from ids, edge
        endpoints and the entity type.

        ``None`` is returned if the file cannot be parsed or any other error
        occurs (for example a node without an ``id`` attribute or a
        non-numeric weight); the error is printed.
    """
    try:
        tree = ET.parse(xml_file)
        root = tree.getroot()

        # Print the root element's tag and attributes to confirm the file
        # has been correctly loaded.
        print(f"Root element: {root.tag}")
        print(f"Root attributes: {root.attrib}")

        data = {"nodes": [], "edges": []}

        # Map the empty prefix to the GraphML namespace so unprefixed tag
        # names in the paths below match namespaced elements.
        namespace = {"": "http://graphml.graphdrawing.org/xmlns"}

        for node in root.findall(".//node", namespace):
            data["nodes"].append(
                {
                    "id": node.get("id").strip('"'),
                    "entity_type": _data_text(node, "d1", namespace).strip('"'),
                    "description": _data_text(node, "d2", namespace),
                    "source_id": _data_text(node, "d3", namespace),
                }
            )

        for edge in root.findall(".//edge", namespace):
            # An absent or blank weight falls back to 0.0 instead of making
            # float() raise and discarding the whole graph.
            weight_text = _data_text(edge, "d5", namespace, default=None)
            has_weight = weight_text is not None and weight_text.strip() != ""
            data["edges"].append(
                {
                    "source": edge.get("source").strip('"'),
                    "target": edge.get("target").strip('"'),
                    "weight": float(weight_text) if has_weight else 0.0,
                    "description": _data_text(edge, "d6", namespace),
                    "keywords": _data_text(edge, "d7", namespace),
                    "source_id": _data_text(edge, "d8", namespace),
                }
            )

        # Print the number of nodes and edges found
        print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
        return data
    except ET.ParseError as e:
        print(f"Error parsing XML file: {e}")
        return None
    except Exception as e:
        # Top-level boundary for this conversion step: report and signal
        # failure to the caller with None.
        print(f"An error occurred: {e}")
        return None
def convert_xml_to_json(xml_path, output_path):
    """Convert the GraphML file at ``xml_path`` to JSON saved at ``output_path``.

    Returns the parsed graph data on success. Returns ``None`` (after
    printing a message) when ``xml_path`` does not exist or when
    ``xml_to_json`` yields no data.
    """
    # Guard: nothing to convert if the input file is absent.
    if not os.path.exists(xml_path):
        print(f"Error: File not found - {xml_path}")
        return None

    graph = xml_to_json(xml_path)
    # Guard: the parser reports failure with a falsy result.
    if not graph:
        print("Failed to create JSON data")
        return None

    with open(output_path, "w", encoding="utf-8") as out_file:
        json.dump(graph, out_file, ensure_ascii=False, indent=2)
    print(f"JSON file created: {output_path}")
    return graph
def process_in_batches(tx, query, data, batch_size, param_name=None):
    """Run ``query`` once per ``batch_size``-sized slice of ``data``.

    Args:
        tx: Object exposing ``run(query, parameters)`` (e.g. a transaction).
        query: Query text that consumes the batch through a ``$nodes`` or
            ``$edges`` parameter.
        data: Sequence of records to send.
        batch_size: Maximum number of records per ``tx.run`` call; must be
            at least 1.
        param_name: Name of the query parameter that receives each batch.
            When omitted it is inferred: ``"nodes"`` if the query contains
            the ``$nodes`` placeholder, otherwise ``"edges"``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    if param_name is None:
        # Look for the placeholder itself rather than the bare word "nodes",
        # which also appears in unrelated places such as Cypher's nodes()
        # function inside an edge query.
        param_name = "nodes" if "$nodes" in query else "edges"

    for start in range(0, len(data), batch_size):
        tx.run(query, {param_name: data[start : start + batch_size]})
def main():
    """Convert the GraphML graph to JSON, then load it into Neo4j.

    Steps:
        1. Convert ``graph_chunk_entity_relation.graphml`` in ``WORKING_DIR``
           to ``graph_data.json``; stop if the conversion returns ``None``.
        2. Send the nodes, then the edges, to Neo4j through
           ``process_in_batches``.
        3. Run a final query that sets ``displayName`` and calls
           ``apoc.create.setLabels`` with each node's ``entity_type``.

    Any exception raised while talking to Neo4j is printed, not re-raised.
    The queries rely on APOC procedures being available on the server.
    """
    # --- Cypher statements -------------------------------------------------
    # Nodes: MERGE on Entity.id, copy the properties, drop the Entity label
    # and add a label equal to the node id via APOC.
    create_nodes_query = """
UNWIND $nodes AS node
MERGE (e:Entity {id: node.id})
SET e.entity_type = node.entity_type,
e.description = node.description,
e.source_id = node.source_id,
e.displayName = node.id
REMOVE e:Entity
WITH e, node
CALL apoc.create.addLabels(e, [node.id]) YIELD node AS labeledNode
RETURN count(*)
"""
    # Edges: match both endpoints by id, derive the relationship type from
    # the keywords (first comma-separated keyword as fallback) and create
    # the relationship via APOC.
    create_edges_query = """
UNWIND $edges AS edge
MATCH (source {id: edge.source})
MATCH (target {id: edge.target})
WITH source, target, edge,
CASE
WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
WHEN edge.keywords CONTAINS 'located' THEN 'located'
WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
END AS relType
CALL apoc.create.relationship(source, relType, {
weight: edge.weight,
description: edge.description,
keywords: edge.keywords,
source_id: edge.source_id
}, target) YIELD rel
RETURN count(*)
"""
    # Post-processing over every node in the database.
    set_displayname_and_labels_query = """
MATCH (n)
SET n.displayName = n.id
WITH n
CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
RETURN count(*)
"""

    # --- GraphML -> JSON ---------------------------------------------------
    graphml_path = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
    json_path = os.path.join(WORKING_DIR, "graph_data.json")
    graph = convert_xml_to_json(graphml_path, json_path)
    if graph is None:
        return

    node_records = graph.get("nodes", [])
    edge_records = graph.get("edges", [])

    # --- JSON -> Neo4j -----------------------------------------------------
    # The driver is closed when the with-block exits, whether or not the
    # import succeeded.
    with GraphDatabase.driver(
        NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
    ) as driver:
        try:
            with driver.session() as session:
                # Each execute_write call hands the full list to
                # process_in_batches, which slices it into batches.
                session.execute_write(
                    process_in_batches,
                    create_nodes_query,
                    node_records,
                    BATCH_SIZE_NODES,
                )
                session.execute_write(
                    process_in_batches,
                    create_edges_query,
                    edge_records,
                    BATCH_SIZE_EDGES,
                )
                session.run(set_displayname_and_labels_query)
        except Exception as e:
            print(f"Error occurred: {e}")
# Run the conversion and import only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|