# Introspective Prisma-VL-8B Architecture

## Overview

Prisma-VL-8B includes an introspective feedback mechanism that adds fine-grained, self-monitored uncertainty awareness to the model's predictions.

## Core Innovation

The model tracks its own prediction uncertainty and uses it as a feedback signal for subsequent predictions. This creates a temporal awareness loop:

```
Token t-1: "What's next?" → Prediction + Uncertainty measurement
Token t:   [Previous uncertainty signal] + "What's next?" → Better calibrated prediction
```

## Architecture Changes

### 1. Uncertainty Embeddings (PrismaVLModel)

Added to `PrismaVLModel.__init__()`:

```python
# 65,536-level uncertainty embedding table
self.n_bits = 16                    # 16-bit quantization
self.n_uncertainty_levels = 65536   # 2^16

# Learned embeddings: one vector per uncertainty level
self.uncertainty_embeddings = nn.Embedding(65536, hidden_dim)

# Cache for the uncertainty codes from the previous step
self.prev_uncertainty_code = None   # [batch_size, seq_len], values in [0, 65535]
```

**Parameter cost**: 65,536 × 4,096 = 268,435,456 parameters (3.35% overhead)

### 2. Uncertainty Injection (PrismaVLModel.forward)

During the forward pass, after creating the input embeddings:

```python
# Look up the uncertainty embeddings from the previous step
uncertainty_embeds = self.uncertainty_embeddings(prev_uncertainty_code)

# Shift right: position i gets the uncertainty from position i-1
# (F is torch.nn.functional)
uncertainty_shifted = F.pad(uncertainty_embeds[:, :-1, :], (0, 0, 1, 0))

# Inject into the input
inputs_embeds = inputs_embeds + uncertainty_shifted
```

Now the model sees: **[Token embedding] + [How uncertain was I last time?]**

### 3. Uncertainty Computation (PrismaVLForConditionalGeneration.forward)

After computing logits, during training:

```python
# Compute the entropy (uncertainty) of the predictions
probs = logits.softmax(-1)
entropy = -(probs * probs.clamp_min(1e-9).log()).sum(-1)  # clamp avoids log(0)

# Normalize to [0, 1]
entropy_norm = entropy / math.log(vocab_size)

# Quantize to 16 bits (0-65535)
uncertainty_code = (entropy_norm * 65535).long()

# Store for the next step
self.model.prev_uncertainty_code = uncertainty_code
```

## How It Works (Step by Step)

### Inference/Generation

1. **Token 0**: No previous uncertainty → use the neutral code (32768)
2. **Token 1**: Predict → measure confidence → encode as a value in 0-65535
3. **Token 2**: Inject the uncertainty signal from Token 1 → predict (now calibrated)
4. **Token 3**: Inject the uncertainty from Token 2 → predict
5. ... and so on

### Training

The model learns the uncertainty embeddings through backpropagation:

- Embeddings #0-16383: "I was very confident" → the model learns to stay confident
- Embeddings #16384-32767: "I had medium confidence" → the model learns moderate caution
- Embeddings #32768-49151: "I was uncertain" → the model learns to hedge
- Embeddings #49152-65535: "I was very uncertain" → the model learns to be conservative

## Key Properties

### 1. Moderate Overhead

- **Parameters**: 268M additional (3.35% of the 8B base)
- **Memory**: 2 bytes per token (uncertainty code)
- **Compute**: negligible (one embedding lookup per token)

### 2. Temporal Awareness

- The model builds a "confidence history" across a generation
- It can detect when it is moving into unfamiliar territory
- It can recover calibration after uncertain predictions

### 3. Self-Calibration

- No external signals needed
- The model learns its own uncertainty language
- Improves through standard supervised training

### 4. Architecture-Agnostic

- Works with any transformer-based model
- Does not modify attention, FFN, or other core components
- Clean separation between the uncertainty mechanism and the base model
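Putting the pieces above together, the following is a minimal, self-contained sketch of the quantize → shift → inject cycle with toy tensor sizes, so the loop can be run and inspected in isolation. The function names (`entropy_to_code`, `inject`) and sizes are illustrative only and are not the actual `modeling.py` API.

```python
import math

import torch
import torch.nn.functional as F

# Toy sizes; the real model uses hidden_dim = 4096 and the full text vocabulary.
hidden_dim, vocab_size, n_levels = 8, 100, 65536

uncertainty_embeddings = torch.nn.Embedding(n_levels, hidden_dim)

def entropy_to_code(logits: torch.Tensor) -> torch.Tensor:
    """Map per-token logits to a 16-bit uncertainty code in [0, 65535]."""
    log_probs = logits.log_softmax(-1)
    entropy = -(log_probs.exp() * log_probs).sum(-1)        # [batch, seq]
    entropy_norm = entropy / math.log(vocab_size)           # normalize to [0, 1]
    return (entropy_norm.clamp(0.0, 1.0) * (n_levels - 1)).long()

def inject(inputs_embeds: torch.Tensor, prev_code: torch.Tensor) -> torch.Tensor:
    """Add the previous step's uncertainty embedding, shifted right by one position."""
    unc = uncertainty_embeddings(prev_code)                 # [batch, seq, hidden]
    unc_shifted = F.pad(unc[:, :-1, :], (0, 0, 1, 0))       # position i sees the code from i-1
    return inputs_embeds + unc_shifted

# One step of the feedback loop on random tensors
logits = torch.randn(2, 5, vocab_size)
code = entropy_to_code(logits)              # measure uncertainty at step t-1
embeds = torch.randn(2, 5, hidden_dim)
embeds = inject(embeds, code)               # condition step t on that measurement
```

Clamping the normalized entropy before quantization keeps every code inside the 16-bit range even if numerical noise pushes it slightly outside [0, 1].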
## Usage

### Standard Inference

```python
import torch
from transformers import AutoProcessor

from modeling import PrismaVLForConditionalGeneration

# Load the model (the introspective mechanism is built in)
model = PrismaVLForConditionalGeneration.from_pretrained(
    ".",
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(".", trust_remote_code=True)

# Use normally - uncertainty tracking happens automatically
messages = [{"role": "user", "content": [{"type": "image", "image": img},
                                         {"type": "text", "text": prompt}]}]
inputs = processor.apply_chat_template(messages, ...)
outputs = model.generate(**inputs)
```

### Training

```python
# Train normally - the uncertainty mechanism learns automatically
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

for batch in dataloader:
    optimizer.zero_grad()
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()

# The uncertainty embeddings learn to represent
# "how to adjust predictions based on previous confidence"
```

### Resetting Uncertainty (Between Sequences)

```python
# Reset the uncertainty cache between independent generations
model.model.reset_uncertainty()

# Generate
outputs = model.generate(...)
```

## What Gets Learned

The 65,536 uncertainty embedding vectors learn to encode:

1. **Confidence Continuation**: "The last token was confident" → maintain confidence (if appropriate)
2. **Uncertainty Propagation**: "The last token was uncertain" → be more conservative
3. **Domain Shifts**: a run of low uncertainty followed by sudden high uncertainty → a domain boundary has been crossed
4. **Recovery Patterns**: high uncertainty followed by a gradual return to confidence → the model is finding its footing

## Benefits

1. **Better Calibration**: the model knows when it doesn't know
2. **Hallucination Awareness**: uncertain predictions are less likely to compound
3. **Adaptive Confidence**: the model can adjust based on recent performance
4. **Interpretability**: uncertainty codes provide insight into the model's state
5. **Low Inference Cost**: per token, the only extra work is one embedding lookup plus an entropy computation

## Implementation Details

### Files Modified

- `modeling.py`:
  - `PrismaVLModel.__init__()`: add the uncertainty embeddings
  - `PrismaVLModel.forward()`: inject the uncertainty signal
  - `PrismaVLForConditionalGeneration.forward()`: compute the uncertainty
  - Added the `reset_uncertainty()` method

### Initialization

- Uncertainty embeddings are initialized with `std = config.text_config.initializer_range` (typically 0.02)
- Start neutral: the first token uses code 32768 (the middle of the range)

### Compatibility

- Fully backward compatible: the model can load existing checkpoints
- The new uncertainty embeddings initialize randomly (and are then trained)
- No changes to the base model weights or architecture
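To make the implementation notes above concrete, here is a hypothetical sketch of the initialization and reset logic; the actual code lives in `modeling.py` and may differ in structure and naming (the `UncertaintyMixin` class and `_init_uncertainty` helper below are illustrative only).

```python
import torch.nn as nn


class UncertaintyMixin:
    """Hypothetical sketch of the uncertainty state added to PrismaVLModel."""

    def _init_uncertainty(self, hidden_dim: int, initializer_range: float = 0.02):
        # 2^16 learned vectors, initialized like the other embedding tables
        self.uncertainty_embeddings = nn.Embedding(65536, hidden_dim)
        nn.init.normal_(self.uncertainty_embeddings.weight, mean=0.0, std=initializer_range)
        # Cached codes from the previous step: [batch, seq] int64 values in [0, 65535]
        self.prev_uncertainty_code = None

    def reset_uncertainty(self):
        """Drop cached codes so the next sequence starts from the neutral code (32768)."""
        self.prev_uncertainty_code = None
```

This matches the compatibility notes above: the base model's weights are untouched, and only the new embedding table needs to be trained.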
## Comparison to the Original Llama 3.2 Example

### Similarities

- Entropy-based uncertainty measurement
- Temporal feedback loop
- Embedding-based uncertainty representation

### Differences

- **Quantization**: 16-bit (65,536 levels) vs. 8-bit (256 levels)
- **Resolution**: fine-grained vs. coarse-grained uncertainty
- **Overhead**: 3.35% parameter overhead vs. ~0.04%
- **Applied to**: a vision-language model (Prisma-VL) vs. a pure language model (Llama)
- **Integration**: built into the core architecture vs. a wrapper class
- **Scope**: uncertainty is tracked only for text generation (not vision encoding)

## Future Enhancements

Potential extensions:

1. **Multi-resolution Uncertainty**: track uncertainty at the token, word, and sentence levels
2. **Uncertainty-aware Generation**: sample less aggressively when uncertain (lower temperature)
3. **Visual Uncertainty**: extend the mechanism to the vision encoder
4. **Cross-modal Uncertainty**: track alignment confidence between vision and text
5. **Explicit Uncertainty Tokens**: add special tokens so the model can express uncertainty in its output

## Citation

Inspired by temporal feedback loop patterns, enhanced with 16-bit high-resolution quantization for fine-grained uncertainty representation.

---

**Model**: Prisma-VL-8B
**Date**: 2025
**Architecture**: Integrated 16-bit temporal uncertainty feedback mechanism
**Parameter Overhead**: 268M (3.35%)
**Memory Overhead**: 2 bytes/token