File size: 6,404 Bytes
1d53ee2
 
 
 
 
89affd8
b804e50
03aaeb7
06698e0
aa75487
03aaeb7
1d53ee2
 
aa75487
1d53ee2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be85ca1
1d53ee2
 
 
89b9ccd
 
1d53ee2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89b9ccd
1d53ee2
 
 
 
 
 
 
 
 
 
 
89affd8
1d53ee2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import pandas as pd
import torch
import os
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
from huggingface_hub import login
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_classic.chains import RetrievalQA
from langchain_core.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline
from langchain_community.document_loaders.csv_loader import CSVLoader
import transformers
from langchain_core.documents import Document
import gradio as gr
import re


model = "abnuel/MedGemma-4b-ICD"

#tokenizer = AutoTokenizer.from_pretrained("abnuel/MedGemma-4b-ICD")

SYSTEM_PROMPT = "You are an expert medical coder. Your task is to analyze the clinical description provided and output only the single, most appropriate ICD-10-CM code. Do not include any text, justification other than the code itself."

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

OFFLOAD_FOLDER = "model_offload_dir"
model = AutoModelForCausalLM.from_pretrained(
    model,
    quantization_config=bnb_config,
    device_map="auto",
    offload_folder=OFFLOAD_FOLDER
)


def generate_response(clinical_note):
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Code the following: {clinical_note}"},
    ]

    # 3. Apply chat template and tokenize

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    input_len = inputs["input_ids"].shape[-1]

    # 4. Generate the response
    with torch.inference_mode():
        generation = model.generate(
            **inputs,
            max_new_tokens=200,      # Max length of the generated ICD codes
            do_sample=False,         # Use greedy decoding for predictable output
            temperature=0.0,         # Zero temperature for deterministic results
        )

    # 5. Decode the output
    # Extract only the newly generated tokens
    generation = generation[0][input_len:]
    decoded_output = tokenizer.decode(generation, skip_special_tokens=True)

    return decoded_output.strip()

# --- Example Usage ---
#test_note = "Sudden onset chest pain and shortness of breath. Initial diagnosis points towards unstable angina."


#print(f"Clinical Note: {test_note}")
#response = generate_response(test_note)
#print(f"Generated ICD Codes: {response}")



pipe = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=50,
    temperature=0.1,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,

)
hf_llm = HuggingFacePipeline(pipeline=pipe)


df = pd.read_csv("./medical_coding_train_1.csv")

documents = [
    Document(
        page_content=f"note: {row['note']}\nicd_code: {row['icd_codes']}",
        metadata={"icd_code": row["icd_codes"]}
    )
    for _, row in df.iterrows()
]


# 2. Chunk Documents
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = text_splitter.split_documents(documents)

# 3. Create Embeddings and Vector Store (FAISS)
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
db = FAISS.from_documents(docs, embeddings)
retriever = db.as_retriever(search_kwargs={"k": 2}) 


RAG_PROMPT_TEMPLATE = """
You are an expert medical coder.
Your task is to determine the most accurate ICD-10-CM code for the given clinical note.

Use ONLY the following context (which may include ICD codes from similar cases).
If you cannot determine a match from the provided context, respond exactly with:
"I cannot find the code in the provided documents."

Return ONLY the ICD-10-CM code itself — no explanation, no text, no punctuation.

Context:
{context}

Clinical Note:
{question}

ICD-10-CM Code:
"""


rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
#direct_chain = LLMChain(llm=hf_llm, prompt=rag_prompt)


# 5. Create the QA Chain
qa_chain = RetrievalQA.from_chain_type(
    llm=hf_llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=False,
    chain_type_kwargs={"prompt": rag_prompt}
)

def extract_icd_code(text):
    # Pattern to match "ICD-10-CM Code:" followed by the code
    pattern = r'ICD-10-CM Code:\s*([A-Z0-9.]+)'

    match = re.search(pattern, text)

    if match:
        return match.group(1)
    return None

def generate_code_rag(clinical_note, retriever, threshold=0.35):
    """Generates the ICD code using RAG."""

    # Format the user question for the RAG prompt template
    query = f"Code the following: {clinical_note}"

     # Step 1: Retrieve docs
    docs_and_scores = db.similarity_search_with_score(query, k=2)

    # Step 2: Filter by similarity threshold
    relevant_docs = [doc for doc, score in docs_and_scores if score > threshold]

    

    if relevant_docs:
        #print(qa_chain)
        result = qa_chain({"query": query})["result"]
        
        #answer = result['result']
        icd_code = extract_icd_code(result)
        #print(icd_code)
        if icd_code == None:
          print("I got here")
          result = generate_response(clinical_note)
          return result
        else:
          return icd_code
    
        # Step 4: Otherwise, use LLM directly (no context)
        

# Create the Gradio Interface
gr.Interface(
    fn=generate_code_rag,
    inputs=gr.Textbox(lines=5, label="Enter Clinical Note Here", placeholder="e.g., Patient presented with simple laceration of the left hand."),
    outputs=gr.Textbox(label="Predicted ICD-10 Code"),
    title="ClaimSwift Medical Coding",
    description="",
    examples=[
        ["Benign neoplasm of peripheral nerves and autonomic nervous system of face, head, and neck"],
        ["Sudden onset chest pain and shortness of breath. Initial diagnosis points towards unstable angina."],
        ["Simple laceration of the left hand without foreign body."],
    ]
).launch(server_name="0.0.0.0", server_port=7860)