"""
Mortality Graph Construction for THRML Integration
==================================================

This module converts Morbid AI's MortalityRecord data structure into
THRML-compatible probabilistic graphical models.
"""

import math

import jax
import jax.numpy as jnp
import networkx as nx
from typing import Any, List, Dict, Tuple, Optional
import pandas as pd
from dataclasses import dataclass

from thrml.pgm import CategoricalNode, SpinNode
from thrml.block_management import Block
from thrml.factor import AbstractFactor


@dataclass
class MortalityRecord:
    """Morbid AI mortality record structure"""

    country: str
    year: int
    sex: int  # 1=male, 2=female, 3=both
    age: int
    deathRate: float  # m(x)
    deathProbability: float  # q(x)
    survivors: float  # l(x)
    deaths: float  # d(x)
    lifeExpectancy: float  # e(x)


class MortalityGraphBuilder:
    """
    Builds THRML-compatible probabilistic graphical models from mortality data.

    This class creates heterogeneous graphs that capture complex interactions
    between demographic factors (age, country, sex, year) and mortality outcomes.
    """

    def __init__(self, mortality_data: List[MortalityRecord]):
        """
        Initialize with mortality data.

        Args:
            mortality_data: List of MortalityRecord objects. May be empty; the
                builder then has no demographic nodes but still exposes the
                fixed life-expectancy / death-probability outcome bins.
        """
        self.mortality_data = mortality_data
        # The explicit column list keeps the schema when ``mortality_data`` is
        # empty: ``pd.DataFrame([])`` has no columns, so the ``unique()`` calls
        # below would otherwise raise KeyError.
        self.df = pd.DataFrame(
            [
                {
                    'country': record.country,
                    'year': record.year,
                    'sex': record.sex,
                    'age': record.age,
                    'death_rate': record.deathRate,
                    'death_probability': record.deathProbability,
                    'survivors': record.survivors,
                    'deaths': record.deaths,
                    'life_expectancy': record.lifeExpectancy,
                }
                for record in mortality_data
            ],
            columns=[
                'country', 'year', 'sex', 'age', 'death_rate',
                'death_probability', 'survivors', 'deaths', 'life_expectancy',
            ],
        )

        # Extract unique values for graph construction
        self.countries = sorted(self.df['country'].unique())
        self.years = sorted(self.df['year'].unique())
        self.sexes = sorted(self.df['sex'].unique())
        self.ages = sorted(self.df['age'].unique())

        # Create node mappings
        self._create_node_mappings()

    def _create_node_mappings(self):
        """Create one THRML node per distinct data value and per outcome bin.

        Country, year and age values map to ``CategoricalNode``s; sex values map
        to ``SpinNode``s. Life expectancy and death probability are discretized
        into fixed bins, with one ``CategoricalNode`` per bin keyed by bin index.
        """
        self.country_nodes = {country: CategoricalNode() for country in self.countries}
        self.year_nodes = {year: CategoricalNode() for year in self.years}
        self.sex_nodes = {sex: SpinNode() for sex in self.sexes}  # Binary-like representation
        self.age_nodes = {age: CategoricalNode() for age in self.ages}

        # Bin edges: 21 edges over [0, 100] -> 20 life-expectancy bins;
        # 11 edges over [0, 1] -> 10 death-probability bins.
        self.life_exp_bins = jnp.linspace(0, 100, 21)
        self.death_prob_bins = jnp.linspace(0, 1, 11)

        # Outcome nodes, keyed by bin index 0..n_bins-1
        self.life_expectancy_nodes = {
            i: CategoricalNode() for i in range(len(self.life_exp_bins) - 1)
        }
        self.death_probability_nodes = {
            i: CategoricalNode() for i in range(len(self.death_prob_bins) - 1)
        }

    def build_mortality_graph(self) -> nx.Graph:
        """
        Build NetworkX graph representing mortality factor interactions.

        Node names are ``"<type>_<value>"`` (e.g. ``"age_40"``) for demographic
        nodes and ``"life_exp_<bin>"`` / ``"death_prob_<bin>"`` for outcome
        nodes; each node carries its THRML node under the ``thrml_node``
        attribute.

        Returns:
            NetworkX graph with nodes representing demographic factors
            and edges representing interactions.
        """
        G = nx.Graph()

        # Add nodes with attributes
        for country, node in self.country_nodes.items():
            G.add_node(f"country_{country}", type="country", value=country, thrml_node=node)

        for year, node in self.year_nodes.items():
            G.add_node(f"year_{year}", type="year", value=year, thrml_node=node)

        for sex, node in self.sex_nodes.items():
            G.add_node(f"sex_{sex}", type="sex", value=sex, thrml_node=node)

        for age, node in self.age_nodes.items():
            G.add_node(f"age_{age}", type="age", value=age, thrml_node=node)

        # Add outcome nodes
        for bin_idx, node in self.life_expectancy_nodes.items():
            G.add_node(f"life_exp_{bin_idx}", type="life_expectancy",
                       bin_idx=bin_idx, thrml_node=node)

        for bin_idx, node in self.death_probability_nodes.items():
            G.add_node(f"death_prob_{bin_idx}", type="death_probability",
                       bin_idx=bin_idx, thrml_node=node)

        # Add edges representing factor interactions
        self._add_demographic_interactions(G)
        self._add_outcome_interactions(G)

        return G

    def _add_demographic_interactions(self, G: nx.Graph):
        """Add edges between demographic factor nodes (mutates ``G`` in place)."""
        # Age-Sex interactions (biological mortality differences)
        for age in self.ages:
            for sex in self.sexes:
                G.add_edge(f"age_{age}", f"sex_{sex}", interaction_type="age_sex")

        # Country-Year interactions (temporal mortality trends by country)
        for country in self.countries:
            for year in self.years:
                G.add_edge(f"country_{country}", f"year_{year}",
                           interaction_type="country_year")

        # Age-Country interactions (demographic mortality patterns)
        for age in self.ages[::5]:  # Sample every 5th age to reduce complexity
            for country in self.countries:
                G.add_edge(f"age_{age}", f"country_{country}",
                           interaction_type="age_country")

    def _add_outcome_interactions(self, G: nx.Graph):
        """Add edges between demographic factors and mortality outcomes (in place)."""
        # Connect age groups to life expectancy bins
        for age in self.ages[::10]:  # Sample to reduce complexity
            for le_bin in self.life_expectancy_nodes:
                G.add_edge(f"age_{age}", f"life_exp_{le_bin}",
                           interaction_type="age_life_expectancy")

        # Connect demographic factors to death probability
        for country in self.countries:
            for dp_bin in self.death_probability_nodes:
                G.add_edge(f"country_{country}", f"death_prob_{dp_bin}",
                           interaction_type="country_death_probability")

    def create_sampling_blocks(self, strategy: str = "two_color") -> List[Block]:
        """
        Create sampling blocks for THRML block Gibbs sampling.

        Every strategy keeps categorical and spin nodes in separate blocks:

        * ``"two_color"``: one block of all categorical nodes plus one block of
          all spin nodes; with no spin nodes, the categorical nodes are split
          alternately into two blocks.
        * ``"demographic"``: categorical demographic, spin demographic, outcome.
        * ``"outcome"``: life expectancy, death probability, categorical
          demographic, spin demographic.

        NOTE(review): blocks are grouped by node type only; whether nodes
        within one block are conditionally independent given the rest (as block
        Gibbs typically assumes) is not checked here.

        Args:
            strategy: Blocking strategy ("two_color", "demographic", "outcome")

        Returns:
            List of Block objects for THRML sampling (empty groups are omitted
            for the "demographic" and "outcome" strategies).

        Raises:
            ValueError: if ``strategy`` is not one of the names above.
        """
        categorical_demographic = (list(self.country_nodes.values()) +
                                   list(self.year_nodes.values()) +
                                   list(self.age_nodes.values()))
        spin_demographic = list(self.sex_nodes.values())
        life_exp_nodes = list(self.life_expectancy_nodes.values())
        death_prob_nodes = list(self.death_probability_nodes.values())

        if strategy == "two_color":
            categorical_nodes = categorical_demographic + life_exp_nodes + death_prob_nodes
            # Create separate blocks for different node types
            if categorical_nodes and spin_demographic:
                return [Block(categorical_nodes), Block(spin_demographic)]
            elif categorical_nodes:
                return [Block(categorical_nodes[::2]), Block(categorical_nodes[1::2])]
            else:
                return [Block(spin_demographic)]

        if strategy == "demographic":
            groups = [categorical_demographic, spin_demographic,
                      life_exp_nodes + death_prob_nodes]
        elif strategy == "outcome":
            groups = [life_exp_nodes, death_prob_nodes,
                      categorical_demographic, spin_demographic]
        else:
            raise ValueError(f"Unknown blocking strategy: {strategy}")

        return [Block(group) for group in groups if group]

    def create_interaction_factors(self) -> List[Dict]:
        """
        Create interaction factors for the energy-based model.

        Returns:
            List of simplified factor dicts with keys ``'nodes'`` (list of two
            THRML nodes), ``'strength'`` (2x2 ``jnp`` array) and ``'type'``
            (``'age_sex'`` or ``'country_year'``).
        """
        factors = []

        # Age-Sex interaction factors
        for age in self.ages[::10]:  # Sample to manage complexity
            for sex in self.sexes:
                # For now, create a simplified factor representation.
                # In a full implementation, would create proper THRML factors.
                factors.append({
                    'nodes': [self.age_nodes[age], self.sex_nodes[sex]],
                    'strength': self._compute_age_sex_interaction(age, sex),
                    'type': 'age_sex'
                })

        # Country-Year interaction factors
        for country in self.countries:
            for year in self.years[::2]:  # Sample years
                factors.append({
                    'nodes': [self.country_nodes[country], self.year_nodes[year]],
                    'strength': self._compute_country_year_interaction(country, year),
                    'type': 'country_year'
                })

        return factors

    @staticmethod
    def _weak_interaction() -> jnp.ndarray:
        """Default weak 2x2 interaction used when data is missing or undefined."""
        return jnp.array([[0.1, 0.0], [0.0, 0.1]])

    def _compute_age_sex_interaction(self, age: int, sex: int) -> jnp.ndarray:
        """Compute a 2x2 interaction matrix for an age/sex pair from the data.

        Uses the mean death rate of matching rows, scaled by 10 and capped at 1,
        as the interaction strength. Falls back to the default weak interaction
        when there are no matching rows or the mean is NaN.
        """
        subset = self.df[(self.df['age'] == age) & (self.df['sex'] == sex)]

        if len(subset) == 0:
            # Default weak interaction if no data
            return self._weak_interaction()

        avg_death_rate = subset['death_rate'].mean()
        # ``mean`` is NaN when every death rate is NaN; min(NaN, 1.0) would
        # otherwise return NaN and poison the factor matrix.
        if pd.isna(avg_death_rate):
            return self._weak_interaction()

        # Higher death rates = stronger interaction
        strength = min(float(avg_death_rate) * 10, 1.0)  # Cap at 1.0

        return jnp.array([[strength, -strength / 2],
                          [-strength / 2, strength]])

    def _compute_country_year_interaction(self, country: str, year: int) -> jnp.ndarray:
        """Compute a 2x2 interaction matrix for a country/year pair from the data.

        Uses the sample variance of life expectancy over matching rows, divided
        by 100 and capped at 1, as the interaction strength. Falls back to the
        default weak interaction when there are no matching rows or the
        variance is undefined.
        """
        subset = self.df[(self.df['country'] == country) & (self.df['year'] == year)]

        if len(subset) == 0:
            return self._weak_interaction()

        # pandas ``var`` uses ddof=1, so it is NaN for a single row (or all-NaN
        # values); min(NaN, 1.0) would otherwise propagate NaN into the matrix.
        life_exp_var = subset['life_expectancy'].var()
        if pd.isna(life_exp_var):
            return self._weak_interaction()

        strength = min(float(life_exp_var) / 100, 1.0)  # Normalize and cap

        return jnp.array([[strength, -strength / 3],
                          [-strength / 3, strength]])

    def get_mortality_prediction_nodes(self, age: int, country: str,
                                       sex: int) -> Dict[str, Any]:
        """
        Get the relevant nodes for mortality prediction given demographics.

        Args:
            age: Age value
            country: Country name
            sex: Sex value (1=male, 2=female, 3=both)

        Returns:
            Dictionary mapping node types to THRML nodes. ``'age_node'``,
            ``'country_node'`` and ``'sex_node'`` are ``None`` when the value
            was not present in the data; the outcome entries are the full
            bin-index -> node dicts.
        """
        return {
            'age_node': self.age_nodes.get(age),
            'country_node': self.country_nodes.get(country),
            'sex_node': self.sex_nodes.get(sex),
            'life_expectancy_nodes': self.life_expectancy_nodes,
            'death_probability_nodes': self.death_probability_nodes
        }

    @staticmethod
    def _bin_index(value: float, edges: jnp.ndarray) -> int:
        """Map ``value`` to a valid bin index in ``[0, len(edges) - 2]``.

        Values below the first edge clamp to bin 0; values at or above the last
        edge (e.g. q(x) == 1.0 for a final open age interval) clamp to the last
        bin, so the result is always a key of the matching outcome-node dict.

        Raises:
            ValueError: if ``value`` is NaN.
        """
        value = float(value)
        if math.isnan(value):
            raise ValueError("cannot discretize NaN value")
        n_bins = len(edges) - 1
        # digitize returns i with edges[i-1] <= value < edges[i]
        idx = int(jnp.digitize(value, edges)) - 1
        return min(max(idx, 0), n_bins - 1)

    def discretize_life_expectancy(self, life_exp: float) -> int:
        """Convert continuous life expectancy to a (clamped) discrete bin index."""
        return self._bin_index(life_exp, self.life_exp_bins)

    def discretize_death_probability(self, death_prob: float) -> int:
        """Convert continuous death probability to a (clamped) discrete bin index."""
        return self._bin_index(death_prob, self.death_prob_bins)

    def continuous_from_bin(self, bin_idx: int, bin_type: str) -> float:
        """Convert bin index back to continuous value (bin center).

        Returns 0.0 for an unknown ``bin_type`` or an out-of-range ``bin_idx``.
        """
        if bin_type == "life_expectancy":
            edges = self.life_exp_bins
        elif bin_type == "death_probability":
            edges = self.death_prob_bins
        else:
            return 0.0

        if 0 <= bin_idx < len(edges) - 1:
            return float((edges[bin_idx] + edges[bin_idx + 1]) / 2)

        return 0.0