######################################
# INFERENCE MIXTURE OF RECURSIONS V1 #
######################################
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import json
# ----------------------
# MoR Model Components
# ----------------------
class ExpertChoiceRouter(nn.Module):
    """Routing gate: scores each token against the experts and keeps the top-k experts per token"""
    def __init__(self, dim, num_experts, k=2):
        super().__init__()
        self.num_experts = num_experts
        self.k = k
        self.gate = nn.Linear(dim, num_experts, bias=False)

    def forward(self, x):
        # x: (batch, seq_len, dim)
        scores = self.gate(x)  # (batch, seq_len, num_experts)
        # Top-k experts per token; the kept scores are normalized with a softmax
        expert_weights, expert_indices = torch.topk(scores, self.k, dim=-1)
        return expert_weights.softmax(dim=-1), expert_indices
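# Illustrative only (not part of the original script): a quick shape check for the
# router. With dim=8, num_experts=4, k=2, each token receives two expert indices and
# two softmax-normalized routing weights.
def _router_shape_demo():
    router = ExpertChoiceRouter(dim=8, num_experts=4, k=2)
    weights, indices = router(torch.randn(1, 5, 8))
    assert weights.shape == (1, 5, 2) and indices.shape == (1, 5, 2)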
class Quantizer4Bit(nn.Module):
    """4-bit Quantization Utilities"""
    def __init__(self):
        super().__init__()

    @staticmethod
    def quantize(tensor):
        """Quantize a float tensor to 4-bit integer levels (stored in int8)"""
        # Scale so the largest magnitude maps to ~7.5, keeping rounded values in [-8, 7]
        scale = tensor.abs().max() / 7.5
        scale = torch.clamp(scale, min=1e-8)
        quantized = torch.clamp(torch.round(tensor / scale), -8, 7)
        return quantized.to(torch.int8), scale

    @staticmethod
    def dequantize(quantized, scale):
        """Dequantize 4-bit integers back to float"""
        return quantized.float() * scale
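# Illustrative only (assumed usage, not part of the original script): round-trip through
# the 4-bit quantizer. The largest magnitude is scaled to ~7.5, values are rounded into
# [-8, 7], then rescaled, so the reconstruction error is at most about half a quantization step.
def _quantizer_demo():
    t = torch.randn(2, 3)
    q, scale = Quantizer4Bit.quantize(t)
    t_approx = Quantizer4Bit.dequantize(q, scale)
    print("max reconstruction error:", (t - t_approx).abs().max().item())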
class QuantizedRecursiveTransformerBlock(nn.Module):
    """Recursive Transformer Block with Quantization"""
    def __init__(self, dim, num_heads, ffn_expansion=4):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Attention layers
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.attn_out = nn.Linear(dim, dim)
        # FFN layers
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_expansion * dim),
            nn.GELU(),
            nn.Linear(ffn_expansion * dim, dim)
        )
        # Normalization
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # x: (batch, seq_len, dim)
        residual = x
        x = self.norm1(x)
        # Projections
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        # Quantize K and V
        k_quant, k_scale = Quantizer4Bit.quantize(k)
        v_quant, v_scale = Quantizer4Bit.quantize(v)
        # Dequantize for computation
        k = Quantizer4Bit.dequantize(k_quant, k_scale)
        v = Quantizer4Bit.dequantize(v_quant, v_scale)
        # Attention
        B, T, _ = q.shape
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn_out = (attn @ v).transpose(1, 2).contiguous().view(B, T, self.dim)
        attn_out = self.attn_out(attn_out)
        # Residual connection
        x = residual + attn_out
        # FFN
        x = x + self.ffn(self.norm2(x))
        return x
class RecursionDepthRouter(nn.Module):
    """Lightweight Router for Dynamic Recursion Depth"""
    def __init__(self, dim, max_depth=4):
        super().__init__()
        self.max_depth = max_depth
        self.router = nn.Sequential(
            nn.Linear(dim, 32),
            nn.ReLU(),
            nn.Linear(32, max_depth)
        )

    def forward(self, x):
        # x: (batch, seq_len, dim)
        router_logits = self.router(x.mean(dim=1))  # (batch, max_depth)
        return router_logits.softmax(dim=-1)
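# Illustrative only (assumed usage): the depth router mean-pools over the sequence and
# returns one probability distribution over recursion depths per batch element.
def _depth_router_demo():
    router = RecursionDepthRouter(dim=8, max_depth=4)
    probs = router(torch.randn(2, 5, 8))  # shape: (batch=2, max_depth=4)
    assert probs.shape == (2, 4) and torch.allclose(probs.sum(dim=-1), torch.ones(2))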
class QuantizedMoRModel(nn.Module):
    """Main MoR Architecture"""
    def __init__(self, vocab_size, dim, num_layers, num_heads, max_recursion, num_experts, max_position_embeddings):
        super().__init__()
        self.dim = dim
        self.max_recursion = max_recursion
        self.num_experts = num_experts
        self.max_position_embeddings = max_position_embeddings
        # Embedding layers
        self.embedding = nn.Embedding(vocab_size, dim)
        self.pos_embed = nn.Embedding(max_position_embeddings, dim)
        # Initial unique layers
        self.init_layers = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(2)
        ])
        # Middle-cycle shared layers
        self.cycle_depth = 3
        self.recursive_blocks = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(self.cycle_depth)
        ])
        # Recursion routers
        self.recursion_routers = nn.ModuleList([
            RecursionDepthRouter(dim, max_depth=max_recursion)
            for _ in range(num_layers - 4)
        ])
        # Expert choice routing
        self.expert_routers = nn.ModuleList([
            ExpertChoiceRouter(dim, num_experts)
            for _ in range(max_recursion)
        ])
        # Final unique layers
        self.final_layers = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(2)
        ])
        # Output head
        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)
    def forward(self, x):
        # Embedding
        pos = torch.arange(0, x.shape[1], device=x.device)
        x = self.embedding(x) + self.pos_embed(pos)
        # Initial unique layers
        for layer in self.init_layers:
            x = layer(x)
        # Middle-cycle with recursion
        all_x = [x]
        batch_size, seq_len, _ = x.shape
        for router in self.recursion_routers:
            # Get recursion depth probabilities
            depth_probs = router(x)
            # Sample a recursion depth (computed here but not used to shorten the loop below)
            depth = torch.multinomial(depth_probs, 1).squeeze()
            # Process through recursive blocks
            for d in range(self.max_recursion):
                # Expert routing
                expert_weights, expert_indices = self.expert_routers[d](x)
                # Create full weight matrix
                full_weights = torch.zeros((batch_size, seq_len, self.num_experts),
                                           device=x.device)
                full_weights.scatter_(2, expert_indices, expert_weights)
                # Process each expert
                expert_outputs = []
                for expert_idx in range(self.num_experts):
                    # Get expert mask
                    expert_mask = full_weights[:, :, expert_idx] > 0
                    if expert_mask.any():
                        # Create expert input
                        expert_x = torch.zeros_like(x)
                        expert_x[expert_mask] = x[expert_mask]
                        # Process through the shared block for this recursion step
                        out = self.recursive_blocks[d % self.cycle_depth](expert_x)
                        expert_outputs.append(out * full_weights[:, :, expert_idx].unsqueeze(-1))
                    else:
                        expert_outputs.append(torch.zeros_like(x))
                # Combine expert outputs
                x = sum(expert_outputs)
            all_x.append(x)
        # Combine outputs
        x = torch.stack(all_x).mean(dim=0)
        # Final unique layers
        for layer in self.final_layers:
            x = layer(x)
        # Output
        x = self.ln_f(x)
        logits = self.head(x)
        return logits
    def generate(self, input_ids, max_length=100, temperature=0.8, top_k=50, eos_token_id=None):
        """Simple autoregressive text generation"""
        device = next(self.parameters()).device
        generated = input_ids.clone().to(device)
        with torch.no_grad():
            for _ in range(max_length):
                # Truncate the context to max_position_embeddings
                inputs = generated[:, -self.max_position_embeddings:] \
                    if generated.shape[1] > self.max_position_embeddings \
                    else generated
                # Forward pass
                logits = self(inputs)[:, -1, :]
                # Apply temperature
                logits = logits / temperature
                # Top-k filtering
                if top_k > 0:
                    top_values, _ = torch.topk(logits, top_k)
                    min_value = top_values[:, -1]
                    logits[logits < min_value.unsqueeze(-1)] = -float('Inf')
                # Sample next token
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1)
                # Append to sequence
                generated = torch.cat([generated, next_token], dim=-1)
                # Stop at the end-of-sequence token, if one was provided
                if eos_token_id is not None and next_token.item() == eos_token_id:
                    break
        return generated
# ----------------------
# Load Model from Hugging Face Hub
# ----------------------
def load_model_from_hub(repo_id="liminerity/MoR-v1"):
    # 1. Download config
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    with open(config_path, "r") as f:
        config = json.load(f)
    print("Model Config:", config)
    # 2. Initialize model with config (including max_position_embeddings)
    model = QuantizedMoRModel(
        vocab_size=config["vocab_size"],
        dim=config["dim"],
        num_layers=config["num_layers"],
        num_heads=config["num_heads"],
        max_recursion=config["max_recursion"],
        num_experts=config["num_experts"],
        max_position_embeddings=config["max_position_embeddings"]
    )
    # 3. Download and load weights
    weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    weights = load_file(weights_path)
    model.load_state_dict(weights)
    # 4. Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    return model, tokenizer
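# Illustrative only (assumed usage): load the model once, move it to the best available
# device, and keep it in eval mode so it can be reused across multiple prompts.
def _load_for_inference(repo_id="liminerity/MoR-v1"):
    model, tokenizer = load_model_from_hub(repo_id)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(device).eval(), tokenizer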
# ----------------------
# Run Inference
# ----------------------
def run_inference(model, tokenizer, prompt, max_length=100):
    # Encode prompt
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = inputs["input_ids"]
    # Move to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()
    input_ids = input_ids.to(device)
    # Generate text
    output_ids = model.generate(input_ids, max_length=max_length,
                                eos_token_id=tokenizer.eos_token_id)
    # Decode and return
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# ----------------------
# Main Execution
# ----------------------
if __name__ == "__main__":
    # Load model and tokenizer
    print("Loading model from Hugging Face Hub...")
    model, tokenizer = load_model_from_hub()
    # Run inference
    prompt = "The future of artificial intelligence"
    print(f"\nPrompt: {prompt}")
    generated_text = run_inference(model, tokenizer, prompt, max_length=100)
    print("\nGenerated Text:")
    print(generated_text)