# 📚 Install dependencies
# Make sure to run this in your environment if you haven't already
# !pip install openai anthropic google-generativeai gradio transformers torch gliner numpy pandas --quiet

# ⚙️ Imports
import openai
import anthropic
import google.generativeai as genai
import gradio as gr
from gliner import GLiNER
from collections import defaultdict
import numpy as np
import pandas as pd
import os
import tempfile

# 🧠 Supported models and their providers
MODEL_OPTIONS = {
    "OpenAI (GPT-4o)": "openai",
    "Anthropic (Claude 3 Opus)": "anthropic",
    "Google (Gemini 1.5 Pro)": "google"
}

# 🔧 GLiNER Model Configuration
GLINER_MODEL_NAME = "urchade/gliner_large-v2.1"

# --- Load the model only once at startup ---
try:
    print("Loading Extraction AI (GLiNER model)... This may take a moment.")
    gliner_model = GLiNER.from_pretrained(GLINER_MODEL_NAME)
    print("Extraction AI loaded successfully.")
except Exception as e:
    print(f"FATAL ERROR: Could not load GLiNER model. The app will not be able to find entities. Error: {e}")
    gliner_model = None

# 🧠 Prompt for the Conceptual AI to generate a research framework
FRAMEWORK_PROMPT_TEMPLATE = """
You are an expert research assistant specializing in history. For the provided topic: **"{topic}"**, your task is to generate a conceptual research framework.

**Instructions:**
1. Identify 4-6 high-level **Conceptual Categories** relevant to analyzing this historical topic (e.g., 'Key Figures', 'Core Ideologies', 'Significant Events').
2. For each category, list specific, searchable **Labels** that would appear in a primary or secondary source document.
3. **Crucial Rule for Labels:** Use concise, singular, and fundamental terms (e.g., use `Treaty` not `Diplomatic Treaties`). Use Title Case (e.g. `Working Class`).

**Output Format:**
Use Markdown. Each category must be a Level 3 Header (###), followed by a comma-separated list of its labels.

### Example Category: Political Actions
- Petition, Charter, Protest, Rally, Legislation

### Example Category: Social Groups
- Working Class, Aristocracy, Clergy
"""
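# Hedged note on the Extraction AI's output (based on GLiNER's documented API;
# exact keys may vary by version): predict_entities(text, labels, threshold=...)
# returns a list of dicts shaped roughly like
#   {"start": 0, "end": 14, "text": "Queen Victoria", "label": "Person", "score": 0.97}
# The analysis code below relies on the "start", "end", "label", and "score" keys.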
# 🧠 Generator Function (The "Conceptual AI")
def generate_from_prompt(prompt, provider, key_dict):
    provider_id = MODEL_OPTIONS.get(provider)
    api_key = key_dict.get(f"{provider_id}_key")
    if not api_key:
        raise ValueError(f"API key for {provider} not found.")

    if provider_id == "openai":
        client = openai.OpenAI(api_key=api_key)
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2
        )
        return response.choices[0].message.content.strip()
    elif provider_id == "anthropic":
        client = anthropic.Anthropic(api_key=api_key)
        response = client.messages.create(
            model="claude-3-opus-20240229",
            max_tokens=1024,
            messages=[{"role": "user", "content": prompt}]
        )
        return response.content[0].text.strip()
    elif provider_id == "google":
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-1.5-pro-latest')
        response = model.generate_content(prompt)
        return response.text.strip()
    return ""
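# Hedged usage sketch (hypothetical key value) showing how the Conceptual AI is
# called; the app does exactly this inside handle_generate below:
#   keys = {"openai_key": "sk-..."}
#   prompt = FRAMEWORK_PROMPT_TEMPLATE.format(topic="Chartism")
#   markdown = generate_from_prompt(prompt, "OpenAI (GPT-4o)", keys)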
# --- UI Definitions ---
STANDARD_LABELS = [
    "Person", "Organisation", "Location", "Country", "City", "State", "Nationality",
    "Group", "Date", "Event", "Law", "Legal Document", "Product", "Facility",
    "Work of Art", "Language", "Time", "Percentage", "Money", "Currency",
    "Quantity", "Ordinal Number", "Cardinal Number"
]
MAX_CATEGORIES = 8

with gr.Blocks(title="Historical Text Analyser", css=".prose { word-break: break-word; }") as demo:
    gr.Markdown("# Historical Text Analyser")
    gr.Markdown("""
First, a **Conceptual AI**, powered by a generative AI Large Language Model (LLM) such as OpenAI's GPT-4, Anthropic's Claude, or Google's Gemini, suggests labels based on your chosen historical topic. These labels are grouped into broader categories (e.g. Economic Policies, Significant Events) to help focus your research.

Second, an **Extraction AI**, powered by the GLiNER model, scans your source text to find and highlight matching entities (instances where those labels appear in the document) with high accuracy.

### Understanding Entities and Labels
In text analysis, this process is often called Named Entity Recognition (NER).
- An **Entity** is a specific piece of text in your document, such as a name, a place, or a date (e.g. Queen Victoria, 1848).
- A **Label** is the category that the entity belongs to (e.g. Person, Date, Location).

This tool helps you to define your labels and then finds the corresponding entities in your text.
""")

    gr.Markdown("--- \n## Step 1: Generate Labels")
    with gr.Row():
        topic = gr.Textbox(label="Enter a Historical Topic", placeholder="e.g. Britain during the Second World War")
        provider = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Choose AI Model")
    with gr.Row():
        openai_key = gr.Textbox(label="OpenAI API Key", type="password")
        anthropic_key = gr.Textbox(label="Anthropic API Key", type="password")
        google_key = gr.Textbox(label="Google API Key", type="password")
    generate_btn = gr.Button("Generate Labels", variant="primary")

    gr.Markdown("--- \n## Step 2: Confirm Labels and Analyse Source Text")
    gr.Markdown("#### 1. AI-Suggested Labels")
    dynamic_components = []
    # Holds the current choices of each suggested category so the per-category
    # "Select All" buttons can restore them at runtime.
    suggested_choices_state = gr.State([[] for _ in range(MAX_CATEGORIES)])
    with gr.Column():
        for i in range(MAX_CATEGORIES):
            with gr.Accordion(f"Suggested Category {i+1}", visible=False) as acc:
                cg = gr.CheckboxGroup(label="Labels in this category", interactive=True)
                with gr.Row():
                    select_btn = gr.Button("Select All", size="sm")
                    deselect_btn = gr.Button("Deselect All", size="sm")
            dynamic_components.append((acc, cg, select_btn, deselect_btn))

    gr.Markdown("#### 2. Standard Labels (Optional)")
    with gr.Group():
        standard_labels_checkbox = gr.CheckboxGroup(choices=STANDARD_LABELS, value=STANDARD_LABELS, label="Standard Entity Labels")
        with gr.Row():
            select_all_std_btn = gr.Button("Select All", size="sm")
            deselect_all_std_btn = gr.Button("Deselect All", size="sm")

    gr.Markdown("#### 3. Custom Labels (Optional)")
    with gr.Group():
        custom_labels_textbox = gr.Textbox(label="Enter Custom Labels (separate with commas)", placeholder="e.g., Technology, Weapon, Secret Society...")

    gr.Markdown("--- \n## Step 3: Run Analysis")
    threshold_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
    text_input = gr.Textbox(label="Paste Your Text Here for Analysis", lines=15)
    analyze_btn = gr.Button("Find Entities", variant="primary")
    analysis_status = gr.Markdown(visible=False)

    gr.Markdown("--- \n## Step 4: Review Results")
    gr.Markdown("✨ **Pro Tip:** In the **\"Highlighted Text\"** view, you can click and drag to highlight text and create your own labels!")
    with gr.Tabs():
        with gr.TabItem("Highlighted Text"):
            highlighted_text_output = gr.HighlightedText(label="Found Entities", interactive=True)
        with gr.TabItem("Detailed Results"):
            gr.Markdown("You can sort the table by clicking on column headers. The download link for the full table will appear here after analysis.")
            # The file download link sits here permanently; it becomes visible once a CSV is ready.
            csv_file_output = gr.File(label="Download Results as CSV", visible=False)
            detailed_results_output = gr.DataFrame(
                headers=["Label", "Text Found", "Instances", "Confidence Score"],
                datatype=["str", "str", "number", "number"],
                label="Aggregated List of Found Entities"
            )
        with gr.TabItem("Debug Log"):
            debug_output = gr.Textbox(label="Extraction Process Log", interactive=False, lines=8)

    # --- Backend Functions ---
    def handle_generate(topic, provider, openai_k, anthropic_k, google_k):
        yield {generate_btn: gr.update(value="Generating...", interactive=False)}
        try:
            key_dict = {
                "openai_key": os.environ.get("OPENAI_API_KEY", openai_k),
                "anthropic_key": os.environ.get("ANTHROPIC_API_KEY", anthropic_k),
                "google_key": os.environ.get("GOOGLE_API_KEY", google_k)
            }
            provider_id = MODEL_OPTIONS.get(provider)
            if not topic or not provider or not key_dict.get(f"{provider_id}_key"):
                raise gr.Error("A topic, provider, and valid API Key for that provider are required.")

            prompt = FRAMEWORK_PROMPT_TEMPLATE.format(topic=topic)
            raw_framework = generate_from_prompt(prompt, provider, key_dict)

            # Parse the AI's markdown into {category: [labels]}.
            framework = defaultdict(list)
            current_category = None
            for line in raw_framework.split('\n'):
                line = line.strip()
                if line.startswith("###"):
                    current_category = line.replace("###", "").strip()
                elif line.startswith("-") and current_category:
                    # Strip only the leading list marker so hyphenated labels survive intact.
                    entities_string = line.lstrip("-").strip()
                    framework[current_category].extend([e.strip() for e in entities_string.split(',') if e.strip()])

            if not framework:
                raise gr.Error("The AI failed to generate categories. Please try again or rephrase your topic.")

            updates = {}
            categories = list(framework.items())
            choices_per_category = [[] for _ in range(MAX_CATEGORIES)]
            for i in range(MAX_CATEGORIES):
                acc, cg, sel, desel = dynamic_components[i]
                if i < len(categories):
                    cat_name, entities = categories[i]
                    sorted_entities = sorted(set(entities))
                    choices_per_category[i] = sorted_entities
                    updates[acc] = gr.update(label=f"Category: {cat_name}", visible=True)
                    updates[cg] = gr.update(choices=sorted_entities, value=sorted_entities, label="Suggested Labels", visible=True)
                    updates[sel] = gr.update(visible=True)
                    updates[desel] = gr.update(visible=True)
                else:
                    updates[acc] = gr.update(visible=False)
                    updates[cg] = gr.update(choices=[], value=[], visible=False)
                    updates[sel] = gr.update(visible=False)
                    updates[desel] = gr.update(visible=False)
            updates[suggested_choices_state] = choices_per_category
            updates[generate_btn] = gr.update(value="Generate Labels", interactive=True)
            yield updates
        except Exception as e:
            yield {generate_btn: gr.update(value="Generate Labels", interactive=True)}
            raise gr.Error(str(e))
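    # Illustrative example (hypothetical values) of what handle_generate parses
    # out of the Conceptual AI's markdown before populating the accordions:
    #   framework = {
    #       "Political Actions": ["Charter", "Legislation", "Petition"],
    #       "Social Groups": ["Aristocracy", "Clergy", "Working Class"],
    #   }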
    def analyze_text(text, standard_labels, custom_label_text, threshold, *suggested_labels_from_groups):
        yield {
            analyze_btn: gr.update(value="Finding Entities...", interactive=False),
            analysis_status: gr.update(value="The Extraction AI is scanning your text...", visible=True),
            highlighted_text_output: None,
            detailed_results_output: None,
            debug_output: "Starting analysis...",
            csv_file_output: gr.update(visible=False, value=None)
        }
        if gliner_model is None:
            raise gr.Error("Extraction AI (GLiNER model) is not loaded.")

        # Combine suggested, standard, and custom labels into one de-duplicated set.
        labels_to_use = set()
        for group in suggested_labels_from_groups:
            if group:
                labels_to_use.update(group)
        if standard_labels:
            labels_to_use.update(standard_labels)
        custom = {l.strip() for l in custom_label_text.split(',') if l.strip()}
        if custom:
            labels_to_use.update(custom)
        final_labels = sorted(labels_to_use)

        if not text or not final_labels:
            yield {
                analyze_btn: gr.update(value="Find Entities", interactive=True),
                analysis_status: gr.update(visible=False),
                highlighted_text_output: {"text": text, "entities": []},
                detailed_results_output: None,
                debug_output: "Analysis stopped: No text or no labels provided.",
                csv_file_output: gr.update(visible=False, value=None)
            }
            return

        # Scan the text in overlapping character chunks so entities that straddle
        # a chunk boundary are still caught; offsets are shifted back to the full text.
        all_entities = []
        chunk_size, overlap = 1024, 100
        for i in range(0, len(text), chunk_size - overlap):
            chunk = text[i : i + chunk_size]
            chunk_entities = gliner_model.predict_entities(chunk, final_labels, threshold=threshold)
            for ent in chunk_entities:
                ent['start'] += i
                ent['end'] += i
                all_entities.append(ent)

        # The overlap means the same entity can be found twice; de-duplicate
        # identical entity dicts before display and aggregation.
        unique_entities = [dict(t) for t in {tuple(d.items()) for d in all_entities}]
        highlighted_output_data = {
            "text": text,
            "entities": [{"start": ent["start"], "end": ent["end"], "entity": ent["label"]} for ent in unique_entities]
        }

        # Aggregate matches case-insensitively per (label, text) pair, keeping
        # the first-seen casing for display.
        aggregated_matches = defaultdict(lambda: {'count': 0, 'scores': [], 'original_casing': ''})
        for ent in unique_entities:
            match_text = text[ent['start']:ent['end']]
            key = (ent['label'], match_text.lower())
            aggregated_matches[key]['count'] += 1
            aggregated_matches[key]['scores'].append(ent['score'])
            if not aggregated_matches[key]['original_casing']:
                aggregated_matches[key]['original_casing'] = match_text

        table_rows = []
        for (label, _), data in aggregated_matches.items():
            avg_score = np.mean(data['scores'])
            table_rows.append({
                "Label": label,
                "Text Found": data['original_casing'],
                "Instances": data['count'],
                "Confidence Score": round(avg_score, 2)
            })
        results_df = pd.DataFrame(table_rows)
        if not results_df.empty:
            results_df = results_df.sort_values(by=["Label", "Instances"], ascending=[True, False])

        # Write the results to a temporary CSV and hand the path to the download component.
        csv_file_path = None
        if not results_df.empty:
            with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.csv', encoding='utf-8') as tmpfile:
                results_df.to_csv(tmpfile.name, index=False)
                csv_file_path = tmpfile.name

        yield {
            analyze_btn: gr.update(value="Find Entities", interactive=True),
            analysis_status: gr.update(visible=False),
            highlighted_text_output: highlighted_output_data,
            detailed_results_output: results_df,
            debug_output: "Analysis complete.",
            csv_file_output: gr.update(value=csv_file_path, visible=bool(csv_file_path))
        }
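    # Note: handle_generate and analyze_text are generator functions. Gradio
    # applies each yielded {component: gr.update(...)} dict as a progressive UI
    # update, which is how the buttons switch to "Generating..."/"Finding
    # Entities..." while the work is still running.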
    # --- Wire up UI events ---
    generate_btn.click(
        fn=handle_generate,
        inputs=[topic, provider, openai_key, anthropic_key, google_key],
        outputs=[generate_btn, suggested_choices_state] + [c for p in dynamic_components for c in p]
    )

    def deselect_all():
        return gr.update(value=[])

    deselect_all_std_btn.click(fn=deselect_all, inputs=None, outputs=[standard_labels_checkbox])
    select_all_std_btn.click(lambda: gr.update(value=STANDARD_LABELS), inputs=None, outputs=[standard_labels_checkbox])

    # "Select All" reads the stored choices for its category from state; reading
    # cg.choices directly would only see the (empty) choices set at build time,
    # not the choices filled in by handle_generate.
    for i, (_, cg, sel_btn, desel_btn) in enumerate(dynamic_components):
        sel_btn.click(fn=lambda stored, i=i: gr.update(value=stored[i]), inputs=[suggested_choices_state], outputs=[cg])
        desel_btn.click(fn=deselect_all, inputs=None, outputs=[cg])

    analyze_btn.click(
        fn=analyze_text,
        inputs=[text_input, standard_labels_checkbox, custom_labels_textbox, threshold_slider] + [cg for acc, cg, sel, desel in dynamic_components],
        outputs=[analyze_btn, analysis_status, highlighted_text_output, detailed_results_output, debug_output, csv_file_output]
    )

demo.launch(share=True, debug=True)
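# Usage note: handle_generate checks os.environ first, so API keys can also be
# supplied via environment variables instead of the textboxes, e.g.
#   export OPENAI_API_KEY=sk-...   # hypothetical key value
# launch(share=True) additionally creates a temporary public Gradio link.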