mtyrrell commited on
Commit
f2fa19f
·
1 Parent(s): 6fcf71e

ingestor node added

Browse files
Files changed (3) hide show
  1. Dockerfile +2 -0
  2. app/main.py +98 -5
  3. params.cfg +3 -0
Dockerfile CHANGED
@@ -1,3 +1,5 @@
 
 
1
  FROM python:3.10-slim
2
 
3
  WORKDIR /app
 
1
+ #CHATFED_ORCHESTRATOR
2
+
3
  FROM python:3.10-slim
4
 
5
  WORKDIR /app
app/main.py CHANGED
@@ -1,9 +1,10 @@
1
- # Gradio UI not currenlty working.
 
2
  import gradio as gr
3
- from fastapi import FastAPI
4
  from langserve import add_routes
5
  from langgraph.graph import StateGraph, START, END
6
- from typing import Optional, Dict, Any
7
  from typing_extensions import TypedDict
8
  from pydantic import BaseModel
9
  from gradio_client import Client
@@ -14,12 +15,14 @@ import logging
14
  from contextlib import asynccontextmanager
15
  import threading
16
  from langchain_core.runnables import RunnableLambda
 
17
 
18
  from utils import getconfig
19
 
20
  config = getconfig("params.cfg")
21
  RETRIEVER = config.get("retriever", "RETRIEVER")
22
  GENERATOR = config.get("generator", "GENERATOR")
 
23
 
24
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
25
  logger = logging.getLogger(__name__)
@@ -29,11 +32,15 @@ logger = logging.getLogger(__name__)
29
  class GraphState(TypedDict):
30
  query: str
31
  context: str
 
32
  result: str
33
  reports_filter: str
34
  sources_filter: str
35
  subtype_filter: str
36
  year_filter: str
 
 
 
37
  metadata: Optional[Dict[str, Any]]
38
 
39
  class ChatFedInput(TypedDict):
@@ -44,6 +51,8 @@ class ChatFedInput(TypedDict):
44
  year_filter: Optional[str]
45
  session_id: Optional[str]
46
  user_id: Optional[str]
 
 
47
 
48
  class ChatFedOutput(TypedDict):
49
  result: str
@@ -53,6 +62,76 @@ class ChatUIInput(BaseModel):
53
  text: str
54
 
55
  # Module functions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def retrieve_node(state: GraphState) -> GraphState:
57
  start_time = datetime.now()
58
  logger.info(f"Retrieval: {state['query'][:50]}...")
@@ -95,10 +174,22 @@ def generate_node(state: GraphState) -> GraphState:
95
  logger.info(f"Generation: {state['query'][:50]}...")
96
 
97
  try:
 
 
 
 
 
 
 
 
 
 
 
 
98
  client = Client(GENERATOR)
99
  result = client.predict(
100
  query=state["query"],
101
- context=state["context"],
102
  api_name="/generate"
103
  )
104
 
@@ -126,9 +217,11 @@ def generate_node(state: GraphState) -> GraphState:
126
 
127
  # start the graph
128
  workflow = StateGraph(GraphState)
 
129
  workflow.add_node("retrieve", retrieve_node)
130
  workflow.add_node("generate", generate_node)
131
- workflow.add_edge(START, "retrieve")
 
132
  workflow.add_edge("retrieve", "generate")
133
  workflow.add_edge("generate", END)
134
  compiled_graph = workflow.compile()
 
1
+ #CHATFED_ORCHESTRATOR
2
+
3
  import gradio as gr
4
+ from fastapi import FastAPI, UploadFile, File, Form
5
  from langserve import add_routes
6
  from langgraph.graph import StateGraph, START, END
7
+ from typing import Optional, Dict, Any, List
8
  from typing_extensions import TypedDict
9
  from pydantic import BaseModel
10
  from gradio_client import Client
 
15
  from contextlib import asynccontextmanager
16
  import threading
17
  from langchain_core.runnables import RunnableLambda
18
+ import tempfile
19
 
20
  from utils import getconfig
21
 
22
  config = getconfig("params.cfg")
23
  RETRIEVER = config.get("retriever", "RETRIEVER")
24
  GENERATOR = config.get("generator", "GENERATOR")
25
+ INGESTOR = config.get("ingestor", "INGESTOR")
26
 
27
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
28
  logger = logging.getLogger(__name__)
 
32
  class GraphState(TypedDict):
33
  query: str
34
  context: str
35
+ ingestor_context: str
36
  result: str
37
  reports_filter: str
38
  sources_filter: str
39
  subtype_filter: str
40
  year_filter: str
41
+ file_content: Optional[bytes]
42
+ filename: Optional[str]
43
+ doc_id: Optional[str]
44
  metadata: Optional[Dict[str, Any]]
45
 
46
  class ChatFedInput(TypedDict):
 
51
  year_filter: Optional[str]
52
  session_id: Optional[str]
53
  user_id: Optional[str]
54
+ file_content: Optional[bytes]
55
+ filename: Optional[str]
56
 
57
  class ChatFedOutput(TypedDict):
58
  result: str
 
62
  text: str
63
 
64
  # Module functions
65
+ def ingest_node(state: GraphState) -> GraphState:
66
+ """Process file through ingestor if file is provided"""
67
+ start_time = datetime.now()
68
+
69
+ # If no file provided, skip this step
70
+ if not state.get("file_content") or not state.get("filename"):
71
+ logger.info("No file provided, skipping ingestion")
72
+ return {"ingestor_context": "", "metadata": state.get("metadata", {})}
73
+
74
+ logger.info(f"Ingesting file: {state['filename']}")
75
+
76
+ try:
77
+ client = Client(INGESTOR)
78
+
79
+ # Create a temporary file to upload
80
+ with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(state["filename"])[1]) as tmp_file:
81
+ tmp_file.write(state["file_content"])
82
+ tmp_file_path = tmp_file.name
83
+
84
+ try:
85
+ # Call the ingestor's /ingest endpoint
86
+ ingest_result = client.predict(
87
+ file=tmp_file_path,
88
+ api_name="/ingest"
89
+ )
90
+
91
+ # Extract doc_id from result
92
+ # The ingest endpoint returns an IngestResponse object
93
+ doc_id = ingest_result.get("doc_id") if isinstance(ingest_result, dict) else ingest_result
94
+
95
+ # Get processed context using doc_id
96
+ context_result = client.predict(
97
+ doc_id=doc_id,
98
+ max_chunks=10, # configurable
99
+ api_name="/context"
100
+ )
101
+
102
+ ingestor_context = context_result.get("context", "") if isinstance(context_result, dict) else str(context_result)
103
+
104
+ finally:
105
+ # Clean up temporary file
106
+ os.unlink(tmp_file_path)
107
+
108
+ duration = (datetime.now() - start_time).total_seconds()
109
+ metadata = state.get("metadata", {})
110
+ metadata.update({
111
+ "ingestion_duration": duration,
112
+ "doc_id": doc_id,
113
+ "ingestor_context_length": len(ingestor_context) if ingestor_context else 0,
114
+ "ingestion_success": True
115
+ })
116
+
117
+ return {
118
+ "ingestor_context": ingestor_context,
119
+ "doc_id": doc_id,
120
+ "metadata": metadata
121
+ }
122
+
123
+ except Exception as e:
124
+ duration = (datetime.now() - start_time).total_seconds()
125
+ logger.error(f"Ingestion failed: {str(e)}")
126
+
127
+ metadata = state.get("metadata", {})
128
+ metadata.update({
129
+ "ingestion_duration": duration,
130
+ "ingestion_success": False,
131
+ "ingestion_error": str(e)
132
+ })
133
+ return {"ingestor_context": "", "metadata": metadata}
134
+
135
  def retrieve_node(state: GraphState) -> GraphState:
136
  start_time = datetime.now()
137
  logger.info(f"Retrieval: {state['query'][:50]}...")
 
174
  logger.info(f"Generation: {state['query'][:50]}...")
175
 
176
  try:
177
+ # Combine retriever context with ingestor context
178
+ retrieved_context = state.get("context", "")
179
+ ingestor_context = state.get("ingestor_context", "")
180
+
181
+ combined_context = ""
182
+ if ingestor_context and retrieved_context:
183
+ combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context}\n\n=== RETRIEVED CONTEXT ===\n{retrieved_context}"
184
+ elif ingestor_context:
185
+ combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context}"
186
+ elif retrieved_context:
187
+ combined_context = retrieved_context
188
+
189
  client = Client(GENERATOR)
190
  result = client.predict(
191
  query=state["query"],
192
+ context=combined_context,
193
  api_name="/generate"
194
  )
195
 
 
217
 
218
  # start the graph
219
  workflow = StateGraph(GraphState)
220
+ workflow.add_node("ingest", ingest_node)
221
  workflow.add_node("retrieve", retrieve_node)
222
  workflow.add_node("generate", generate_node)
223
+ workflow.add_edge(START, "ingest")
224
+ workflow.add_edge("ingest", "retrieve")
225
  workflow.add_edge("retrieve", "generate")
226
  workflow.add_edge("generate", END)
227
  compiled_graph = workflow.compile()
params.cfg CHANGED
@@ -3,3 +3,6 @@ RETRIEVER = giz/chatfed_retriever
3
 
4
  [generator]
5
  GENERATOR = giz/chatfed_generator
 
 
 
 
3
 
4
  [generator]
5
  GENERATOR = giz/chatfed_generator
6
+
7
+ [ingestor]
8
+ INGESTOR = mtyrrell/chatfed_ingestor