######################################
# INFERENCE MIXTURE OF RECURSIONS V1 #
######################################
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import json


# ----------------------
# MoR Model Components
# ----------------------
class ExpertChoiceRouter(nn.Module):
    """Expert Choice Routing: Experts select top-k tokens"""
    def __init__(self, dim, num_experts, k=2):
        super().__init__()
        self.num_experts = num_experts
        self.k = k
        self.gate = nn.Linear(dim, num_experts, bias=False)
        
    def forward(self, x):
        # x: (batch, seq_len, dim)
        scores = self.gate(x)  # (batch, seq_len, num_experts)
        expert_weights, expert_indices = torch.topk(scores, self.k, dim=-1)
        return expert_weights.softmax(dim=-1), expert_indices
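
# A minimal shape sketch of the router above; the dims below are
# illustrative, not the released model's config:
def _demo_expert_choice_router():
    router = ExpertChoiceRouter(dim=64, num_experts=4, k=2)
    weights, indices = router(torch.randn(2, 8, 64))  # (batch=2, seq=8, dim=64)
    assert weights.shape == (2, 8, 2)  # normalized top-k weights per token
    assert indices.shape == (2, 8, 2)  # chosen expert ids per token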

class Quantizer4Bit(nn.Module):
    """4-bit Quantization Utilities"""
    def __init__(self):
        super().__init__()
    
    @staticmethod
    def quantize(tensor):
        """Quantize tensor to 4-bit integers"""
        # Scale so the largest magnitude lands near the top of the signed 4-bit range [-8, 7]
        scale = tensor.abs().max() / 7.5
        scale = torch.clamp(scale, min=1e-8)  # guard against division by zero for all-zero tensors
        quantized = torch.clamp(torch.round(tensor / scale), -8, 7)
        return quantized.to(torch.int8), scale
    
    @staticmethod
    def dequantize(quantized, scale):
        """Dequantize 4-bit integers to float"""
        return quantized.float() * scale
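
# A quick round-trip sketch for the quantizer (illustrative values only):
def _demo_quantizer_roundtrip():
    t = torch.randn(4, 4)
    q, scale = Quantizer4Bit.quantize(t)        # int8 codes in [-8, 7] plus one scale
    t_hat = Quantizer4Bit.dequantize(q, scale)  # approximate reconstruction
    print(f"max abs reconstruction error: {(t - t_hat).abs().max().item():.4f}")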

class QuantizedRecursiveTransformerBlock(nn.Module):
    """Recursive Transformer Block with Quantization"""
    def __init__(self, dim, num_heads, ffn_expansion=4):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        
        # Attention layers
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.attn_out = nn.Linear(dim, dim)
        
        # FFN layers
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_expansion * dim),
            nn.GELU(),
            nn.Linear(ffn_expansion * dim, dim)
        )
        
        # Normalization
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        
    def forward(self, x):
        # x: (batch, seq_len, dim)
        residual = x
        x = self.norm1(x)
        
        # Projections
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        
        # Quantize K and V
        k_quant, k_scale = Quantizer4Bit.quantize(k)
        v_quant, v_scale = Quantizer4Bit.quantize(v)
        
        # Dequantize immediately: a simulated 4-bit round trip (attention itself still runs in float)
        k = Quantizer4Bit.dequantize(k_quant, k_scale)
        v = Quantizer4Bit.dequantize(v_quant, v_scale)
        
        # Scaled dot-product attention (full/bidirectional as written; no causal mask is applied)
        B, T, _ = q.shape
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn_out = (attn @ v).transpose(1, 2).contiguous().view(B, T, self.dim)
        attn_out = self.attn_out(attn_out)
        
        # Residual connection
        x = residual + attn_out
        
        # FFN
        x = x + self.ffn(self.norm2(x))
        return x
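
# A shape check for the block above (illustrative dims, not the real config):
def _demo_transformer_block():
    block = QuantizedRecursiveTransformerBlock(dim=64, num_heads=4)
    x = torch.randn(2, 8, 64)
    assert block(x).shape == x.shape  # the block is shape-preserving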

class RecursionDepthRouter(nn.Module):
    """Lightweight Router for Dynamic Recursion Depth"""
    def __init__(self, dim, max_depth=4):
        super().__init__()
        self.max_depth = max_depth
        self.router = nn.Sequential(
            nn.Linear(dim, 32),
            nn.ReLU(),
            nn.Linear(32, max_depth)
        )
        
    def forward(self, x):
        # x: (batch, seq_len, dim)
        router_logits = self.router(x.mean(dim=1))  # (batch, max_depth)
        return router_logits.softmax(dim=-1)
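
# A minimal sketch of the depth router; it emits one distribution per sequence:
def _demo_depth_router():
    router = RecursionDepthRouter(dim=64, max_depth=4)
    probs = router(torch.randn(2, 8, 64))
    assert probs.shape == (2, 4)
    assert torch.allclose(probs.sum(dim=-1), torch.ones(2))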

class QuantizedMoRModel(nn.Module):
    """Main MoR Architecture"""
    def __init__(self, vocab_size, dim, num_layers, num_heads, max_recursion, num_experts, max_position_embeddings):
        super().__init__()
        self.dim = dim
        self.max_recursion = max_recursion
        self.num_experts = num_experts
        self.max_position_embeddings = max_position_embeddings
        
        # Embedding layers
        self.embedding = nn.Embedding(vocab_size, dim)
        self.pos_embed = nn.Embedding(max_position_embeddings, dim)
        
        # Initial unique layers
        self.init_layers = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(2)
        ])
        
        # Middle-cycle shared layers
        self.cycle_depth = 3
        self.recursive_blocks = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(self.cycle_depth)
        ])
        
        # Recursion routers: one per middle layer (num_layers minus the 2+2 unique init/final layers)
        self.recursion_routers = nn.ModuleList([
            RecursionDepthRouter(dim, max_depth=max_recursion)
            for _ in range(num_layers - 4)
        ])
        
        # Expert choice routing
        self.expert_routers = nn.ModuleList([
            ExpertChoiceRouter(dim, num_experts)
            for _ in range(max_recursion)
        ])
        
        # Final unique layers
        self.final_layers = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(2)
        ])
        
        # Output head
        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)
        
    def forward(self, x):
        # Embedding
        pos = torch.arange(0, x.shape[1], device=x.device)
        x = self.embedding(x) + self.pos_embed(pos)
        
        # Initial unique layers
        for layer in self.init_layers:
            x = layer(x)
        
        # Middle-cycle with recursion
        all_x = [x]
        batch_size, seq_len, _ = x.shape
        
        for router in self.recursion_routers:
            # Get per-sequence recursion depth probabilities
            depth_probs = router(x)
            
            # Sample a recursion depth (note: as written, the sample is unused
            # and the loop below always unrolls all max_recursion steps)
            depth = torch.multinomial(depth_probs, 1).squeeze()
            
            # Process through recursive blocks
            for d in range(self.max_recursion):
                # Expert routing
                expert_weights, expert_indices = self.expert_routers[d](x)
                
                # Create full weight matrix
                full_weights = torch.zeros((batch_size, seq_len, self.num_experts), 
                                          device=x.device)
                full_weights.scatter_(2, expert_indices, expert_weights)
                
                # Process each expert
                expert_outputs = []
                for expert_idx in range(self.num_experts):
                    # Get expert mask
                    expert_mask = full_weights[:, :, expert_idx] > 0
                    
                    if expert_mask.any():
                        # Zero out tokens not routed to this expert before the shared block
                        expert_x = torch.zeros_like(x)
                        expert_x[expert_mask] = x[expert_mask]
                        
                        # Process through block
                        out = self.recursive_blocks[d % self.cycle_depth](expert_x)
                        expert_outputs.append(out * full_weights[:, :, expert_idx].unsqueeze(-1))
                    else:
                        expert_outputs.append(torch.zeros_like(x))
                
                # Combine expert outputs
                x = sum(expert_outputs)
            
            all_x.append(x)
        
        # Combine outputs
        x = torch.stack(all_x).mean(dim=0)
        
        # Final unique layers
        for layer in self.final_layers:
            x = layer(x)
        
        # Output
        x = self.ln_f(x)
        logits = self.head(x)
        return logits

    def generate(self, input_ids, max_length=100, temperature=0.8, top_k=50, eos_token_id=None):
        """Simple autoregressive sampling loop"""
        generated = input_ids.clone()
        
        with torch.no_grad():
            for _ in range(max_length):
                # Truncate the context to the position-embedding limit
                inputs = generated[:, -self.max_position_embeddings:] \
                    if generated.shape[1] > self.max_position_embeddings \
                    else generated
                
                # Forward pass
                logits = self(inputs)[:, -1, :]
                
                # Apply temperature
                logits = logits / temperature
                
                # Top-k filtering
                if top_k > 0:
                    top_values, _ = torch.topk(logits, top_k)
                    min_value = top_values[:, -1]
                    logits[logits < min_value.unsqueeze(-1)] = -float('Inf')
                
                # Sample next token
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1)
                
                # Append to sequence
                generated = torch.cat([generated, next_token], dim=-1)
                
                # Stop early on the end-of-sequence token, if one was provided
                if eos_token_id is not None and next_token.item() == eos_token_id:
                    break
        
        return generated
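
# A hedged usage sketch for generate(); assumes a loaded model and tokenizer:
def _demo_generate(model, tokenizer):
    ids = tokenizer("Hello", return_tensors="pt")["input_ids"]
    out = model.generate(ids, max_length=20, temperature=0.8, top_k=50,
                         eos_token_id=tokenizer.eos_token_id)
    print(tokenizer.decode(out[0], skip_special_tokens=True))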

# ----------------------
# Load Model from Hugging Face Hub
# ----------------------
def load_model_from_hub(repo_id="liminerity/MoR-v1"):
    # 1. Download config
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    with open(config_path, "r") as f:
        config = json.load(f)
    
    print("Model Config:", config)
    
    # 2. Initialize model from the downloaded config
    model = QuantizedMoRModel(
        vocab_size=config["vocab_size"],
        dim=config["dim"],
        num_layers=config["num_layers"],
        num_heads=config["num_heads"],
        max_recursion=config["max_recursion"],
        num_experts=config["num_experts"],
        max_position_embeddings=config["max_position_embeddings"]
    )
    
    # 3. Download and load weights
    weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    weights = load_file(weights_path)
    model.load_state_dict(weights)
    
    # 4. Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    
    return model, tokenizer
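
# A quick post-load sanity check (illustrative, not part of the original flow):
def _report_model_size(model):
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Loaded {n_params / 1e6:.1f}M parameters")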

# ----------------------
# Run Inference
# ----------------------
def run_inference(model, tokenizer, prompt, max_length=100):
    # Encode prompt
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = inputs["input_ids"]
    
    # Move to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()
    input_ids = input_ids.to(device)
    
    # Generate text
    output_ids = model.generate(input_ids, max_length=max_length,
                                eos_token_id=tokenizer.eos_token_id)
    
    # Decode and return
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

# ----------------------
# Main Execution
# ----------------------
if __name__ == "__main__":
    # Load model and tokenizer
    print("Loading model from Hugging Face Hub...")
    model, tokenizer = load_model_from_hub()
    
    # Run inference
    prompt = "The future of artificial intelligence"
    print(f"\nPrompt: {prompt}")
    
    generated_text = run_inference(model, tokenizer, prompt, max_length=100)
    print("\nGenerated Text:")
    print(generated_text)