import functools
import gc
import json
import logging
import os
from pathlib import Path

try:
    import spaces
except ModuleNotFoundError:
    # Outside Hugging Face Spaces the `spaces` package is not installed; provide a
    # stand-in object whose `GPU` decorator returns the wrapped function unchanged.
    spaces = lambda: None
    spaces.GPU = lambda fn: fn

import gradio as gr
import tiktoken
import torch
from openai import OpenAI
from transformers import AutoModelForCausalLM, AutoTokenizer

from igcs import grounding
from igcs.entities import Doc, Selection
from igcs.utils import log
from igcs.utils.diskcache import disk_cache

logger = logging.getLogger("igcs-demo")

_EXAMPLES_DIR = Path(__file__).parent

# In this simulation, we store only a single document although multi-document is possible.
# taken from https://en.wikipedia.org/wiki/Barack_Obama
with open(_EXAMPLES_DIR / "barack_obama_wiki.txt", encoding="utf8") as fp:
    DEFAULT_TEXT = fp.read().strip()  # This is the global doc in this demo

# Example instructions offered in the UI; responses for these on DEFAULT_TEXT are
# cached on disk (see `perform_igcs`).
DEFAULT_PROMPTS = (
    "Select content that details Obama's initiatives",
    "Select content that discusses Obama's personal life",
    "Select content that details Obama's education",
    "Select content with Obama's financial data",
)

# see src/igcs/prompting.py for more info
PROMPT_TEMPLATE = (
    "Given the following document(s), {selection_instruction}. "
    "Output the exact text phrases from the given document(s) as a valid json array of strings. "
    "Do not change the copied text.\n\n"
    "Document #0:\n{doc.text}\n"
)

# (display label, model id) pairs for the model dropdown. Entries whose id is None are
# section headers; selecting one is rejected in `process_igcs_request`. Ids prefixed
# with "api:" are routed to OpenRouter, everything else is loaded locally.
MODELS_LIST = [
    # local models:
    ("====== IGCS Fine-tuned SLMs ======", None),
    ("Qwen2.5-3b-GenCS-union (local)", "shmuelamar/Qwen2.5-3b-GenCS-union"),
    ("Qwen2.5-3b-GenCS-majority (local)", "shmuelamar/Qwen2.5-3b-GenCS-majority"),
    ("Qwen2.5-7b-GenCS-union (local)", "shmuelamar/Qwen2.5-7b-GenCS-union"),
    ("Qwen2.5-7b-GenCS-majority (local)", "shmuelamar/Qwen2.5-7b-GenCS-majority"),
    ("Llama-3-8B-GenCS-union (local)", "shmuelamar/Llama-3-8B-GenCS-union"),
    ("Llama-3-8B-GenCS-majority (local)", "shmuelamar/Llama-3-8B-GenCS-majority"),
    ("SmolLM2-1.7B-GenCS-union (local)", "shmuelamar/SmolLM2-1.7B-GenCS-union"),
    ("SmolLM2-1.7B-GenCS-majority (local)", "shmuelamar/SmolLM2-1.7B-GenCS-majority"),
    ("====== Zero-shot SLMs ======", None),
    ("Qwen/Qwen2.5-3B-Instruct (local)", "Qwen/Qwen2.5-3B-Instruct"),
    ("Qwen/Qwen2.5-7B-Instruct (local)", "Qwen/Qwen2.5-7B-Instruct"),
    ("meta-llama/Meta-Llama-3-8B-Instruct (local)", "meta-llama/Meta-Llama-3-8B-Instruct"),
    ("HuggingFaceTB/SmolLM2-1.7B-Instruct (local)", "HuggingFaceTB/SmolLM2-1.7B-Instruct"),
    ("====== API-based Models (OpenRouter) ======", None),
    ("qwen/qwen3-14b (API)", "api:qwen/qwen3-14b:free"),
    ("moonshotai/kimi-k2 (API)", "api:moonshotai/kimi-k2:free"),
    ("deepseek/deepseek-chat-v3-0324 (API)", "api:deepseek/deepseek-chat-v3-0324:free"),
    ("meta-llama/llama-3.3-70b-instruct (API)", "api:meta-llama/llama-3.3-70b-instruct:free"),
    ("meta-llama/llama-3.1-405b-instruct (API)", "api:meta-llama/llama-3.1-405b-instruct:free"),
]
# First real (non-header) model in the list.
DEFAULT_MODEL = MODELS_LIST[1][1]

# Limits are measured with the tiktoken GPT-4 encoding (see `count_tokens`).
MAX_INPUT_TOKENS = 4500
MAX_PROMPT_TOKENS = 256

INTRO_TEXT = """
## 🚀 Welcome to the IGCS Live Demo!

This is a demo for the paper titled [**“A Unifying Scheme for Extractive Content Selection Tasks”**][arxiv-paper] — try Instruction‑Guided Content Selection on **any** text or code: use the demo text or upload your document, enter an instruction, choose a model, and hit **Submit** to see the most relevant spans highlighted!
🔍 Learn more in our [paper][arxiv-paper] and explore the full [GitHub repo](https://github.com/shmuelamar/igcs) ⭐. Enjoy! 🎉

[arxiv-paper]: http://arxiv.org/abs/2507.16922 "A Unifying Scheme for Extractive Content Selection Tasks"
"""


@spaces.GPU
def completion(prompt: str, model_id: str) -> str:
    """Generate a greedy completion for ``prompt`` with a locally loaded HF model.

    The model and tokenizer are loaded on every call and released before returning,
    so at most one local model is held in memory at a time.

    Args:
        prompt: The user message; it is sent after a fixed system message.
        model_id: Hugging Face Hub id of the causal LM to load.

    Returns:
        The decoded newly generated text, without the prompt tokens and without
        special tokens.
    """
    # load model and tokenizer
    logger.info(f"loading local model and tokenizer for {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # bfloat16 on CUDA, float16 otherwise (NOTE(review): presumably aimed at Apple MPS,
    # float16 may be slow or unsupported for some ops on plain CPU — confirm).
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map="auto")
    logger.info(f"done loading {model_id}")

    # tokenize
    input_ids = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # MPS (on Mac) requires manual attention mask
    attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=model.device)

    logger.info(f"generating completion with model_id: {model.name_or_path} and prompt: {prompt!r}")
    # Greedy decoding; sampling knobs are explicitly unset so generation-config
    # defaults shipped with the model do not trigger "unused flag" warnings.
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=2048,
        do_sample=False,
        top_k=None,
        top_p=None,
        temperature=None,
    )

    # decode only the newly generated tokens
    resp = tokenizer.decode(outputs[0][input_ids.shape[-1] :], skip_special_tokens=True)

    # Cleanup memory: drop every reference to the model and device tensors first, then
    # collect garbage, and only then ask the CUDA caching allocator to release its
    # now-unused blocks (empty_cache cannot free memory that is still referenced).
    del model, tokenizer, outputs, input_ids, attention_mask
    gc.collect()
    torch.cuda.empty_cache()
    return resp


def completion_openrouter(prompt: str, model_id: str) -> str:
    """Get a deterministic (temperature 0) chat completion from OpenRouter.

    Args:
        prompt: The single user message to send.
        model_id: OpenRouter model id (without the local "api:" prefix).

    Returns:
        The assistant message text; an empty string if the API returned no content.
    """
    logger.info(f"calling openrouter with model_id: {model_id} and prompt: {prompt!r}")
    client = load_openrouter_client()
    resp = client.chat.completions.create(
        model=model_id,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
    )
    # `message.content` is Optional in the OpenAI SDK; normalize to str so parsing,
    # logging and the UI code block always receive a string.
    return resp.choices[0].message.content or ""


def load_openrouter_client() -> OpenAI:
    """Create an OpenAI-compatible client pointed at the OpenRouter API.

    The API key is read from the ``OPENROUTER_API_KEY`` environment variable.
    """
    logger.info("connecting to OpenRouter")
    return OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=os.environ.get("OPENROUTER_API_KEY"),
    )


@disk_cache(cache_dir=_EXAMPLES_DIR / "models-cache")
def get_completion_cache(*, prompt: str, model_id: str) -> str:
    """Disk-cached variant of `get_completion`, used for the built-in examples."""
    return get_completion(prompt=prompt, model_id=model_id)


@functools.lru_cache(maxsize=2048)
def get_completion(*, prompt: str, model_id: str) -> str:
    """Dispatch to OpenRouter for "api:"-prefixed ids, otherwise run the model locally.

    Results are memoized in-process (bounded LRU) per (prompt, model_id).
    """
    if model_id.startswith("api:"):
        return completion_openrouter(prompt, model_id.removeprefix("api:"))
    return completion(prompt, model_id)


TIKTOKEN_TOKENIZER = tiktoken.encoding_for_model("gpt-4")


def count_tokens(text: str) -> int:
    """Return the number of GPT-4 (tiktoken) tokens in ``text``, used for input limits."""
    return len(TIKTOKEN_TOKENIZER.encode(text))


def perform_igcs(
    doc: Doc, selection_instruction: str, model_id: str
) -> tuple[list[Selection] | None, str]:
    """Run instruction-guided content selection on a single document.

    Args:
        doc: The document to select from.
        selection_instruction: Natural-language selection instruction.
        model_id: Model id from `MODELS_LIST` ("api:"-prefixed for OpenRouter).

    Returns:
        ``(selections, raw_response)``. ``selections`` is None when the model response
        could not be parsed as a list of selected spans.
    """
    logger.info(f"performing selection with {selection_instruction!r} using {model_id!r}")
    prompt = PROMPT_TEMPLATE.format(doc=doc, selection_instruction=selection_instruction)

    # For the example inputs - we cache from disk as they are more popular
    if doc.text == DEFAULT_TEXT and selection_instruction in DEFAULT_PROMPTS:
        logger.info("using disk_cache mode")
        resp = get_completion_cache(prompt=prompt, model_id=model_id)
    else:
        resp = get_completion(prompt=prompt, model_id=model_id)
    logger.info(f"Got response from model: {model_id}: {resp!r}")

    # First, parse the selections as json array of strings
    selection_spans = grounding.parse_selection(resp)
    if selection_spans is None:
        # Unparseable response: report it to the caller instead of trying to ground it.
        logger.warning(f"could not parse selections from response of {model_id!r}")
        return None, resp

    # Next, ground them to specific character positions in the source documents
    selections = grounding.ground_selections(selection_spans, docs=[doc])
    logger.info(f"model selections: {selections!r}")
    return selections, resp


def convert_selections_to_gradio_highlights(selections, doc) -> list[tuple[str, str | None]]:
    """Turn grounded selections into ``(text, label)`` pairs for `gr.HighlightedText`.

    Selections with ``doc_id == -1`` (not found in the document) are listed first
    under a "hallucination" header. The document then follows as a sequence of
    unlabeled gaps and labeled selections (label = ``sel.metadata["mode"]``).

    Every document character is emitted exactly once and in order, so concatenating
    the non-hallucination tokens reproduces ``doc.text`` — `process_igcs_request`
    relies on this to rebuild the document for the next request. Overlapping
    selections are clipped so the overlap is only shown with the earlier span.

    The input ``selections`` list is not modified.
    """
    highlights: list[tuple[str, str | None]] = []

    # add hallucinations outside the text itself:
    hallucinations = [sel for sel in selections if sel.doc_id == -1]
    if hallucinations:
        highlights.append(
            ("\n\nHallucinated selections (not found in the document):\n\n", "hallucination")
        )
        highlights.extend((sel.content + "\n", "hallucination") for sel in hallucinations)

    grounded = sorted(
        (sel for sel in selections if sel.doc_id != -1),
        key=lambda sel: (sel.end_pos, sel.start_pos),
    )

    pos = 0  # first document character not yet emitted
    for sel in grounded:
        if sel.end_pos <= pos:
            continue  # fully covered by previously emitted text
        if pos < sel.start_pos:
            highlights.append((doc.text[pos : sel.start_pos], None))  # outside selection
        # Clip partial overlaps so no character is emitted twice.
        start = max(pos, sel.start_pos)
        highlights.append((doc.text[start : sel.end_pos], sel.metadata["mode"]))
        pos = sel.end_pos

    if pos < len(doc.text):
        highlights.append((doc.text[pos:], None))  # end of the text
    return highlights


def process_igcs_request(
    selection_instruction: str, model_id: str, doc_data: list[dict]
) -> tuple[list[tuple[str, str | None]], str, str]:
    """Gradio submit handler: validate inputs, run IGCS and format the outputs.

    Args:
        selection_instruction: Instruction text from the dropdown (may be custom).
        model_id: Selected model id; None when a section header is selected.
        doc_data: Current `HighlightedText` value as a list of dicts with "token" and
            "class_or_confidence" keys; hallucination tokens are excluded when the
            document text is rebuilt.

    Returns:
        ``(highlights, raw_model_response, grounded_selections_json)``.

    Raises:
        gr.Error: On invalid model, empty/too-long instruction, too-large document,
            or an unparseable model response.
    """
    if model_id is None:
        raise gr.Error("Please select a valid model from the list.")
    if not selection_instruction or not selection_instruction.strip():
        raise gr.Error("Please enter a content selection instruction.")

    doc_text = "".join(
        [item["token"] for item in doc_data if item["class_or_confidence"] != "hallucination"]
    )

    if count_tokens(doc_text) > MAX_INPUT_TOKENS:
        raise gr.Error(
            f"File too large! currently only up-to {MAX_INPUT_TOKENS} tokens are supported"
        )

    if count_tokens(selection_instruction) > MAX_PROMPT_TOKENS:
        raise gr.Error(f"Prompt is too long! only supports up-to {MAX_PROMPT_TOKENS} tokens.")

    # Perform content selection (example inputs are disk-cached inside perform_igcs)
    doc = Doc(id=0, text=doc_text)
    selections, model_resp = perform_igcs(doc, selection_instruction, model_id)

    if selections is None:
        raise gr.Error(
            "Cannot parse selections, model response is invalid. please try another instruction or model."
        )

    # Post-process selections for display as highlighted spans
    highlights = convert_selections_to_gradio_highlights(selections, doc)
    selections_text = json.dumps([s.model_dump(mode="json") for s in selections], indent=2)
    return highlights, model_resp, selections_text


def get_app() -> gr.Blocks:
    """Build the Gradio Blocks UI and wire its event handlers."""
    with gr.Blocks(title="Instruction-guided content selection", theme="ocean", head="") as app:
        with gr.Row():
            gr.Markdown(INTRO_TEXT)

        with gr.Row(equal_height=True):
            with gr.Column(scale=2, min_width=300):
                prompt_text = gr.Dropdown(
                    label="Content Selection Instruction:",
                    info='Choose an existing instruction or write a short one, starting with "Select content" or "Select code".',
                    value=DEFAULT_PROMPTS[0],
                    choices=DEFAULT_PROMPTS,
                    multiselect=False,
                    allow_custom_value=True,
                )
            with gr.Column(scale=1, min_width=200):
                model_selector = gr.Dropdown(
                    label="Choose a Model",
                    info="Choose a model from the predefined list below.",
                    value=DEFAULT_MODEL,
                    choices=MODELS_LIST,
                    multiselect=False,
                    allow_custom_value=False,
                )

        with gr.Row():
            submit_button = gr.Button("Submit", variant="primary")
            upload_button = gr.UploadButton("Upload a text or code file", file_count="single")
            reset_button = gr.Button("Default text")

        with gr.Row():
            with gr.Accordion("Detailed response", open=False):
                model_resp_text = gr.Code(
                    label="Model's raw response",
                    interactive=False,
                    value="No response yet",
                    lines=5,
                    language="json",
                )
                model_selections_text = gr.Code(
                    label="Grounded selections",
                    interactive=False,
                    value="No response yet",
                    lines=10,
                    language="json",
                )

        with gr.Row():
            # NOTE(review): the empty ("", "exact_match") entry presumably exists so the
            # legend is populated before any selection is made — confirm.
            highlighted_text = gr.HighlightedText(
                label="Selected Content",
                value=[(DEFAULT_TEXT, None), ("", "exact_match")],
                combine_adjacent=False,
                show_legend=True,
                interactive=False,
                color_map={
                    "exact_match": "lightgreen",
                    "normalized_match": "green",
                    "fuzzy_match": "yellow",
                    "hallucination": "red",
                },
            )

        def upload_file(filepath):
            """Load an uploaded UTF-8 file as the new (unhighlighted) document."""
            try:
                with open(filepath, "r", encoding="utf8") as fp:
                    text = fp.read().strip()
            except UnicodeDecodeError as err:
                raise gr.Error(
                    "Cannot read file! please upload a UTF-8 encoded text or code file."
                ) from err

            if count_tokens(text) > MAX_INPUT_TOKENS:
                raise gr.Error(
                    f"File too large! currently only up-to {MAX_INPUT_TOKENS} tokens are supported"
                )
            return [(text, None), ("", "exact_match")]

        def reset_text(*args):
            """Restore the default demo document."""
            return [(DEFAULT_TEXT, None), ("", "exact_match")]

        upload_button.upload(upload_file, upload_button, outputs=[highlighted_text])
        submit_button.click(
            process_igcs_request,
            inputs=[prompt_text, model_selector, highlighted_text],
            outputs=[highlighted_text, model_resp_text, model_selections_text],
        )
        reset_button.click(reset_text, reset_button, outputs=[highlighted_text])
    return app


if __name__ == "__main__":
    log.init()
    logger.info("starting app")
    app = get_app()
    app.queue()
    app.launch()
    logger.info("done")