import pandas as pd
import torch
import os
import re

import gradio as gr
import transformers
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_classic.chains import RetrievalQA
from langchain_core.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline
from langchain_core.documents import Document

MODEL_ID = "abnuel/MedGemma-4b-ICD"

SYSTEM_PROMPT = (
    "You are an expert medical coder. Your task is to analyze the clinical "
    "description provided and output only the single, most appropriate "
    "ICD-10-CM code. Do not include any text or justification other than "
    "the code itself."
)

# 4-bit NF4 quantization so the model fits in modest GPU memory
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Layers that do not fit in GPU memory are offloaded to disk here
OFFLOAD_FOLDER = "model_offload_dir"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    offload_folder=OFFLOAD_FOLDER,
)


def generate_response(clinical_note):
    """Generate an ICD-10-CM code directly from the fine-tuned model (no RAG)."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Code the following: {clinical_note}"},
    ]

    # Apply the chat template and tokenize
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    input_len = inputs["input_ids"].shape[-1]

    # Generate the response. Greedy decoding gives deterministic output;
    # passing a temperature alongside do_sample=False only triggers warnings,
    # so no temperature is set here.
    with torch.inference_mode():
        generation = model.generate(
            **inputs,
            max_new_tokens=200,  # generous cap for the generated ICD code(s)
            do_sample=False,
        )

    # Decode only the newly generated tokens
    generation = generation[0][input_len:]
    decoded_output = tokenizer.decode(generation, skip_special_tokens=True)
    return decoded_output.strip()


# --- Example usage ---
# test_note = "Sudden onset chest pain and shortness of breath. Initial diagnosis points towards unstable angina."
# print(f"Clinical Note: {test_note}")
# response = generate_response(test_note)
# print(f"Generated ICD Codes: {response}")

# Wrap the model in a text-generation pipeline so LangChain can drive it
pipe = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=50,
    do_sample=False,  # greedy decoding; a temperature would be ignored here
    pad_token_id=tokenizer.eos_token_id,
)
hf_llm = HuggingFacePipeline(pipeline=pipe)

# 1. Load the training notes and turn each row into a Document
df = pd.read_csv("./medical_coding_train_1.csv")
documents = [
    Document(
        page_content=f"note: {row['note']}\nicd_code: {row['icd_codes']}",
        metadata={"icd_code": row["icd_codes"]},
    )
    for _, row in df.iterrows()
]

# 2. Chunk documents
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = text_splitter.split_documents(documents)

# 3. Create embeddings and vector store (FAISS)
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
db = FAISS.from_documents(docs, embeddings)
retriever = db.as_retriever(search_kwargs={"k": 2})
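# Optional retrieval sanity check: a minimal sketch, gated behind a
# hypothetical RAG_DEBUG environment variable (not part of the original app).
# It uses the raw FAISS scores, which for the default index are L2 distances,
# so LOWER values mean CLOSER matches.
if os.environ.get("RAG_DEBUG"):
    _probe = "Sudden onset chest pain and shortness of breath."
    for _doc, _dist in db.similarity_search_with_score(_probe, k=2):
        print(f"L2 distance {_dist:.3f} | {_doc.page_content[:80]!r}")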
# 4. RAG prompt
RAG_PROMPT_TEMPLATE = """
You are an expert medical coder. Your task is to determine the most accurate ICD-10-CM code for the given clinical note.
Use ONLY the following context (which may include ICD codes from similar cases).
If you cannot determine a match from the provided context, respond exactly with: "I cannot find the code in the provided documents."
Return ONLY the ICD-10-CM code itself: no explanation, no extra text, no punctuation.

Context:
{context}

Clinical Note:
{question}

ICD-10-CM Code:
"""

rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)

# 5. Create the QA chain
qa_chain = RetrievalQA.from_chain_type(
    llm=hf_llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=False,
    chain_type_kwargs={"prompt": rag_prompt},
)


def extract_icd_code(text):
    """Pull the code out of the LLM output, e.g. '... ICD-10-CM Code: I20.0' -> 'I20.0'.

    The text-generation pipeline returns the full prompt plus the completion,
    so the code appears right after the prompt's trailing 'ICD-10-CM Code:' label.
    """
    match = re.search(r"ICD-10-CM Code:\s*([A-Z0-9.]+)", text)
    if match:
        return match.group(1)
    return None


def generate_code_rag(clinical_note, threshold=0.35):
    """Generate the ICD code via RAG, falling back to the bare model."""
    query = f"Code the following: {clinical_note}"

    # Step 1: Retrieve docs with normalized relevance scores in [0, 1]
    # (higher = more similar). The raw similarity_search_with_score variant
    # returns L2 distances, where a LARGER score means a WORSE match, so it
    # cannot be thresholded with `score > threshold` as intended here.
    docs_and_scores = db.similarity_search_with_relevance_scores(query, k=2)

    # Step 2: Keep only documents above the similarity threshold
    relevant_docs = [doc for doc, score in docs_and_scores if score > threshold]

    # Step 3: If anything relevant was retrieved, answer with the RAG chain
    if relevant_docs:
        result = qa_chain.invoke({"query": query})["result"]
        icd_code = extract_icd_code(result)
        if icd_code is None:
            # The RAG output contained no parsable code; fall back to the model
            return generate_response(clinical_note)
        return icd_code

    # Step 4: Otherwise, use the LLM directly (no context)
    return generate_response(clinical_note)


# Create the Gradio interface
demo = gr.Interface(
    fn=generate_code_rag,
    inputs=gr.Textbox(
        lines=5,
        label="Enter Clinical Note Here",
        placeholder="e.g., Patient presented with simple laceration of the left hand.",
    ),
    outputs=gr.Textbox(label="Predicted ICD-10 Code"),
    title="ClaimSwift Medical Coding",
    description="Retrieval-augmented ICD-10-CM code prediction from clinical notes.",
    examples=[
        ["Benign neoplasm of peripheral nerves and autonomic nervous system of face, head, and neck"],
        ["Sudden onset chest pain and shortness of breath. Initial diagnosis points towards unstable angina."],
        ["Simple laceration of the left hand without foreign body."],
    ],
)

demo.launch(server_name="0.0.0.0", server_port=7860)
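# --- Headless smoke test (sketch) ---
# To exercise the RAG path without the web UI, comment out demo.launch()
# above and run, for example:
#
#   print(generate_code_rag("Simple laceration of the left hand without foreign body."))
#
# The expected output is a bare ICD-10-CM code (likely one from the S61.41-
# "laceration without foreign body of hand" family); the exact code depends
# on the fine-tuned model and the retrieved context.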