Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import torch | |
| import os | |
| from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM | |
| from huggingface_hub import login | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_community.embeddings import SentenceTransformerEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_classic.chains import RetrievalQA | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain_community.llms import HuggingFacePipeline | |
| from langchain_community.document_loaders.csv_loader import CSVLoader | |
| import transformers | |
| from langchain_core.documents import Document | |
| import gradio as gr | |
| import re | |
# Hugging Face Hub id of the fine-tuned checkpoint. Kept under its own name so
# that `model` below always refers to the loaded model object and is never
# rebound from a string.
MODEL_ID = "abnuel/MedGemma-4b-ICD"

# System turn sent with every direct (non-RAG) request in generate_response.
SYSTEM_PROMPT = (
    "You are an expert medical coder. Your task is to analyze the clinical "
    "description provided and output only the single, most appropriate "
    "ICD-10-CM code. Do not include any text, justification other than the "
    "code itself."
)

# 4-bit NF4 quantization with double quantization; compute dtype is bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # Reuse EOS as the padding token when the tokenizer defines none.
    tokenizer.pad_token = tokenizer.eos_token

# Directory handed to from_pretrained as `offload_folder`.
OFFLOAD_FOLDER = "model_offload_dir"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    offload_folder=OFFLOAD_FOLDER,
)
def generate_response(clinical_note):
    """Ask the model directly (no retrieved context) to code a clinical note.

    Builds a two-message chat (SYSTEM_PROMPT as the system turn, the note as
    the user turn), renders it with the tokenizer's chat template, generates
    greedily, and returns only the text produced after the prompt.

    Args:
        clinical_note: Free-text clinical description to code.

    Returns:
        The decoded completion with special tokens skipped and surrounding
        whitespace stripped.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Code the following: {clinical_note}"},
    ]

    # Render the chat template straight to tensors on the model's device.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(
            **inputs,
            max_new_tokens=200,  # Upper bound on the generated answer length
            # Greedy decoding for predictable output. No `temperature` is
            # passed: it only applies when sampling, so combining it with
            # do_sample=False has no effect beyond triggering warnings.
            do_sample=False,
        )

    # Keep only the newly generated tokens (drop the echoed prompt).
    new_tokens = generation[0][prompt_len:]
    decoded_output = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return decoded_output.strip()
| # --- Example Usage --- | |
| #test_note = "Sudden onset chest pain and shortness of breath. Initial diagnosis points towards unstable angina." | |
| #print(f"Clinical Note: {test_note}") | |
| #response = generate_response(test_note) | |
| #print(f"Generated ICD Codes: {response}") | |
# Text-generation pipeline that backs the RAG chain. Decoding is greedy
# (do_sample=False), so no `temperature` is passed: it is a sampling-only
# setting and would be ignored here.
pipe = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=50,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
)

# LangChain wrapper so the pipeline can be used as the chain's LLM.
hf_llm = HuggingFacePipeline(pipeline=pipe)
def _row_to_document(row):
    """Wrap one CSV row as a Document holding its note text and ICD code."""
    return Document(
        page_content=f"note: {row['note']}\nicd_code: {row['icd_codes']}",
        metadata={"icd_code": row["icd_codes"]},
    )


# 1. Load the examples; each CSV row becomes one Document.
df = pd.read_csv("./medical_coding_train_1.csv")
documents = [_row_to_document(row) for _, row in df.iterrows()]

# 2. Split the documents into chunks (chunk_size=500, chunk_overlap=50).
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = text_splitter.split_documents(documents)

# 3. Embed the chunks and index them in a FAISS vector store.
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
db = FAISS.from_documents(docs, embeddings)

# Retriever over the index, configured to return 2 results per query.
retriever = db.as_retriever(search_kwargs={"k": 2})
# Prompt for the retrieval-augmented path. `{context}` receives the retrieved
# chunks and `{question}` the user query. The template ends with the
# "ICD-10-CM Code:" label that extract_icd_code searches for.
RAG_PROMPT_TEMPLATE = """
You are an expert medical coder.
Your task is to determine the most accurate ICD-10-CM code for the given clinical note.
Use ONLY the following context (which may include ICD codes from similar cases).
If you cannot determine a match from the provided context, respond exactly with:
"I cannot find the code in the provided documents."
Return ONLY the ICD-10-CM code itself — no explanation, no text, no punctuation.
Context:
{context}
Clinical Note:
{question}
ICD-10-CM Code:
"""
rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)

# QA chain: the "stuff" chain type places the retrieved documents into the
# single prompt above; only the answer is returned (no source documents).
qa_chain = RetrievalQA.from_chain_type(
    llm=hf_llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=False,
    chain_type_kwargs={"prompt": rag_prompt},
)
# "ICD-10-CM Code:" label followed by a token shaped like an ICD-10-CM code:
# a letter, a digit, one alphanumeric, then optionally up to four more
# alphanumerics with an optional dot in front of them. The trailing \b stops
# the match from ending in the middle of a longer token.
_ICD_CODE_PATTERN = re.compile(
    r"ICD-10-CM Code:\s*([A-Z][0-9][0-9A-Z](?:\.?[0-9A-Z]{1,4})?)\b"
)


def extract_icd_code(text):
    """Pull the ICD-10-CM code that follows an "ICD-10-CM Code:" label.

    Only a token with the shape of an ICD-10-CM code is accepted, so a
    refusal such as "I cannot find the code ..." is not mistaken for the
    code "I", and sentence punctuation after the code is not included.

    Args:
        text: Model output to scan; may be empty or None.

    Returns:
        The matched code string, or None when the text is empty or contains
        no labelled code.
    """
    if not text:
        return None
    match = _ICD_CODE_PATTERN.search(text)
    return match.group(1) if match else None
def generate_code_rag(clinical_note, retriever=None, threshold=0.35):
    """Predict an ICD-10-CM code, preferring RAG and falling back to the LLM.

    The vector store is searched for the two nearest chunks. If any hit
    passes the score filter, the retrieval QA chain is run and a labelled
    code is extracted from its answer. When nothing passes the filter, or no
    code can be extracted, the note is sent to the model directly.

    Args:
        clinical_note: Free-text clinical description to code.
        retriever: Unused; retrieval goes through the module-level `db` and
            `qa_chain`. Optional so the function can be called with the note
            alone (as the single-input Gradio interface does).
        threshold: Score cut-off used to decide whether retrieval produced
            usable context.

    Returns:
        The extracted ICD-10-CM code, or the direct model response when the
        RAG path yields nothing.
    """
    query = f"Code the following: {clinical_note}"

    # Step 1: retrieve candidate chunks together with their scores.
    docs_and_scores = db.similarity_search_with_score(query, k=2)

    # Step 2: gate on the score threshold.
    # NOTE(review): LangChain's FAISS store documents this score as a
    # distance (lower = more similar), so `score > threshold` would keep the
    # LESS similar hits. Left unchanged; confirm the intended direction.
    relevant_docs = [doc for doc, score in docs_and_scores if score > threshold]

    if relevant_docs:
        # Step 3: answer with retrieved context and pull out the code.
        result = qa_chain.invoke({"query": query})["result"]
        icd_code = extract_icd_code(result)
        if icd_code is not None:
            return icd_code
        print("I got here")

    # Step 4: no usable context or no parsable code; use the LLM directly.
    return generate_response(clinical_note)
# Gradio UI: one multi-line textbox for the note, one textbox for the code.
note_input = gr.Textbox(
    lines=5,
    label="Enter Clinical Note Here",
    placeholder="e.g., Patient presented with simple laceration of the left hand.",
)
code_output = gr.Textbox(label="Predicted ICD-10 Code")

# Sample notes offered as clickable examples in the interface.
example_notes = [
    ["Benign neoplasm of peripheral nerves and autonomic nervous system of face, head, and neck"],
    ["Sudden onset chest pain and shortness of breath. Initial diagnosis points towards unstable angina."],
    ["Simple laceration of the left hand without foreign body."],
]

demo = gr.Interface(
    fn=generate_code_rag,
    inputs=note_input,
    outputs=code_output,
    title="ClaimSwift Medical Coding",
    description="",
    examples=example_notes,
)

# Serve on all network interfaces, port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)