Spaces:
Sleeping
Sleeping
File size: 6,404 Bytes
1d53ee2 89affd8 b804e50 03aaeb7 06698e0 aa75487 03aaeb7 1d53ee2 aa75487 1d53ee2 be85ca1 1d53ee2 89b9ccd 1d53ee2 89b9ccd 1d53ee2 89affd8 1d53ee2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
import pandas as pd
import torch
import os
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
from huggingface_hub import login
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_classic.chains import RetrievalQA
from langchain_core.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline
from langchain_community.document_loaders.csv_loader import CSVLoader
import transformers
from langchain_core.documents import Document
import gradio as gr
import re
# Hugging Face Hub id of the fine-tuned MedGemma checkpoint used for coding.
# Kept separate from `model` (the loaded network) so the name is not rebound
# from a string to a model object.
MODEL_ID = "abnuel/MedGemma-4b-ICD"
# System turn for direct (non-RAG) prompting in generate_response.
SYSTEM_PROMPT = "You are an expert medical coder. Your task is to analyze the clinical description provided and output only the single, most appropriate ICD-10-CM code. Do not include any text, justification other than the code itself."
# 4-bit NF4 quantisation with double quantisation; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # Fall back to EOS so a pad token is always defined.
    tokenizer.pad_token = tokenizer.eos_token
# Directory passed to from_pretrained for weights that device_map="auto"
# may offload to disk.
OFFLOAD_FOLDER = "model_offload_dir"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    offload_folder=OFFLOAD_FOLDER,
)
def generate_response(clinical_note):
    """Ask the fine-tuned model directly (no retrieval) for an ICD-10-CM code.

    Builds a system + user chat, renders it with the tokenizer's chat
    template, and decodes greedily.

    Args:
        clinical_note: Free-text clinical description to code.

    Returns:
        The decoded completion with surrounding whitespace stripped. The
        system prompt asks for a bare code, but the text is not validated.

    Example::

        test_note = "Sudden onset chest pain and shortness of breath. Initial diagnosis points towards unstable angina."
        print(f"Clinical Note: {test_note}")
        print(f"Generated ICD Codes: {generate_response(test_note)}")
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Code the following: {clinical_note}"},
    ]
    # Render the chat with the model's template and tokenize it;
    # add_generation_prompt=True ends the prompt at the assistant turn.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(
            **inputs,
            max_new_tokens=200,  # generous upper bound for the code text
            # Greedy decoding for deterministic output. No `temperature` is
            # passed: it only applies when sampling, and combining it with
            # do_sample=False just triggers a transformers warning.
            do_sample=False,
        )

    # generate() returns prompt + completion; keep only the new tokens.
    new_tokens = generation[0][input_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# LangChain LLM wrapper around a transformers text-generation pipeline that
# reuses the already-loaded quantised model and tokenizer.
# NOTE(review): return_full_text is left at its default, so the generated text
# presumably echoes the prompt (ending in "ICD-10-CM Code:"); extract_icd_code
# searches for that label, so confirm the echo before changing this setting.
pipe = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=50,
    # Greedy decoding. `temperature` is omitted: it is ignored when
    # do_sample=False, and setting it only produces a warning.
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
)
hf_llm = HuggingFacePipeline(pipeline=pipe)
# --- Retrieval index ---------------------------------------------------------
# 1. One Document per training row: the note text plus its label(s).
# NOTE(review): relies on the CSV having "note" and "icd_codes" columns.
df = pd.read_csv("./medical_coding_train_1.csv")
documents = [
    Document(
        page_content=f"note: {note}\nicd_code: {code}",
        metadata={"icd_code": code},
    )
    # Column-wise zip is much cheaper than DataFrame.iterrows().
    for note, code in zip(df["note"], df["icd_codes"])
]

# 2. Chunk documents (~500 characters, 50 overlap).
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = text_splitter.split_documents(documents)
# Splitting a long note leaves the "icd_code:" line only in its last chunk.
# The RAG prompt reads codes from the chunk text, so re-attach the label to
# every chunk that lost it (the code is still in each chunk's metadata).
for chunk in docs:
    label = f"icd_code: {chunk.metadata['icd_code']}"
    if label not in chunk.page_content:
        chunk.page_content = f"{chunk.page_content}\n{label}"

# 3. Embed chunks and build the FAISS store; the retriever returns the top 2.
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
db = FAISS.from_documents(docs, embeddings)
retriever = db.as_retriever(search_kwargs={"k": 2})
# Prompt for the retrieval-augmented chain. `{context}` is filled with the
# retrieved chunks and `{question}` with the user's query.
RAG_PROMPT_TEMPLATE = """
You are an expert medical coder.
Your task is to determine the most accurate ICD-10-CM code for the given clinical note.
Use ONLY the following context (which may include ICD codes from similar cases).
If you cannot determine a match from the provided context, respond exactly with:
"I cannot find the code in the provided documents."
Return ONLY the ICD-10-CM code itself — no explanation, no text, no punctuation.
Context:
{context}
Clinical Note:
{question}
ICD-10-CM Code:
"""
rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)

# RetrievalQA with the "stuff" strategy: all retrieved documents go into the
# single {context} slot of rag_prompt; only the answer text is returned.
qa_chain = RetrievalQA.from_chain_type(
    llm=hf_llm,
    retriever=retriever,
    chain_type="stuff",
    chain_type_kwargs={"prompt": rag_prompt},
    return_source_documents=False,
)
# The "ICD-10-CM Code:" label followed by something shaped like an ICD-10-CM
# code: a letter, a digit, an alphanumeric, then optionally a (possibly
# omitted) period and up to four more alphanumerics - e.g. "I20.0",
# "S61.412A", "D3612". Requiring the digit in position 2 stops plain words
# such as the refusal "I cannot find ..." from being read as the code "I".
_ICD_CODE_AFTER_LABEL = re.compile(
    r"ICD-10-CM Code:\s*([A-Z][0-9][0-9A-Z](?:\.?[0-9A-Z]{1,4})?)"
)


def extract_icd_code(text):
    """Return the ICD-10-CM code that follows an "ICD-10-CM Code:" label.

    Args:
        text: LLM output to search (may include the echoed prompt).

    Returns:
        The first code-shaped token found after the label, or None when the
        text is empty or contains no such code (e.g. a refusal message).
    """
    if not text:
        return None
    match = _ICD_CODE_AFTER_LABEL.search(text)
    return match.group(1) if match else None
def generate_code_rag(clinical_note, retriever=None, threshold=0.35):
    """Predict an ICD-10-CM code for a clinical note, using RAG when it helps.

    The query is scored against the module-level FAISS store ``db``. If either
    of the two nearest chunks lies within ``threshold``, the ``qa_chain``
    RetrievalQA chain is run and the code is parsed from its answer with
    ``extract_icd_code``. When no chunk is close enough, or no code can be
    parsed from the RAG answer, the model is asked directly via
    ``generate_response``.

    Args:
        clinical_note: Free-text clinical description entered by the user.
        retriever: Unused - ``qa_chain`` already holds its own retriever.
            Optional now so one-argument callers (the Gradio UI) work; the
            old two-argument form is still accepted.
        threshold: Maximum score from ``db.similarity_search_with_score``
            for a hit to count as relevant. LangChain's FAISS store returns
            distances, so LOWER scores mean MORE similar.

    Returns:
        The parsed code from the RAG answer, otherwise the stripped text from
        ``generate_response``; ``""`` for a blank note.
    """
    if not clinical_note or not clinical_note.strip():
        return ""

    # Same phrasing as the user turn in generate_response.
    query = f"Code the following: {clinical_note}"

    # Scores are distances: a hit is relevant when it is at or below the
    # threshold (the old `score > threshold` kept only the farthest hits).
    # NOTE(review): 0.35 was chosen while the comparison was inverted;
    # retune it on held-out notes.
    docs_and_scores = db.similarity_search_with_score(query, k=2)
    has_relevant_context = any(score <= threshold for _, score in docs_and_scores)

    if has_relevant_context:
        rag_answer = qa_chain.invoke({"query": query})["result"]
        icd_code = extract_icd_code(rag_answer)
        if icd_code is not None:
            return icd_code

    # No close neighbour, or no parseable code in the RAG answer:
    # use the fine-tuned model directly, without retrieved context.
    return generate_response(clinical_note)
def _predict_icd_code(clinical_note):
    """Gradio callback: return the predicted ICD-10-CM code for one note."""
    # gr.Interface passes exactly one value per input component, so the
    # second positional parameter of generate_code_rag is bound here.
    return generate_code_rag(clinical_note, retriever)


# Create the Gradio Interface
demo = gr.Interface(
    fn=_predict_icd_code,
    inputs=gr.Textbox(
        lines=5,
        label="Enter Clinical Note Here",
        placeholder="e.g., Patient presented with simple laceration of the left hand.",
    ),
    outputs=gr.Textbox(label="Predicted ICD-10 Code"),
    title="ClaimSwift Medical Coding",
    description="",
    examples=[
        ["Benign neoplasm of peripheral nerves and autonomic nervous system of face, head, and neck"],
        ["Sudden onset chest pain and shortness of breath. Initial diagnosis points towards unstable angina."],
        ["Simple laceration of the left hand without foreign body."],
    ],
)
# Listen on all interfaces at port 7860 so the app is reachable from outside.
demo.launch(server_name="0.0.0.0", server_port=7860)
|