import copy

import numpy as np
import torch
from scipy.stats import pearsonr

from t_cube import evaluate_model


def evaluate_slerp(clip_pt, sd_pt, sd_ft, dataloader, args, alpha=0.5):
    """
    SLERP (spherical linear interpolation) between pretrained (pt) and
    fine-tuned (ft) weights. alpha=0 -> pt only; alpha=1 -> ft only.
    """
    model = copy.deepcopy(clip_pt)
    merged_sd = {}
    # Flatten each tensor and SLERP it independently, key by key.
    for k in sd_pt.keys():
        w1 = sd_pt[k].flatten().float()
        w2 = sd_ft[k].flatten().float()
        # Angle between the two weight vectors, from their cosine similarity.
        cos_val = torch.dot(w1, w2) / (w1.norm() * w2.norm() + 1e-8)
        omega = torch.acos(torch.clamp(cos_val, -1 + 1e-6, 1 - 1e-6))
        sin_omega = torch.sin(omega)
        if sin_omega < 1e-6:
            # Nearly collinear vectors: fall back to linear interpolation.
            w_interp = (1 - alpha) * w1 + alpha * w2
        else:
            w_interp = (torch.sin((1 - alpha) * omega) / sin_omega) * w1 + \
                       (torch.sin(alpha * omega) / sin_omega) * w2
        merged_sd[k] = w_interp.view_as(sd_pt[k])
    model.load_state_dict(merged_sd)
    return evaluate_model(model, dataloader, args)


def evaluate_m3(clip_pt, sd_pt, sd_ft, dataloader, args):
    """
    M^3 (Mixup Model Merge): sample lambda ~ Uniform(0, 1) and linearly
    interpolate between the two state dicts with that coefficient.
    """
    model = copy.deepcopy(clip_pt)
    lam = np.random.rand()
    merged_sd = {k: lam * sd_ft[k] + (1 - lam) * sd_pt[k] for k in sd_pt.keys()}
    model.load_state_dict(merged_sd)
    return evaluate_model(model, dataloader, args)


def evaluate_task_arithmetic(clip_pt, sd_pt, sd_ft, dataloader, args):
    """
    Task Arithmetic: extrapolate along the task vector (ft - pt) with a
    fixed scaling of 2, i.e. pt + 2*(ft - pt) = 2*ft - pt.
    """
    model = copy.deepcopy(clip_pt)
    merged_sd = {k: 2 * sd_ft[k] - sd_pt[k] for k in sd_pt.keys()}
    model.load_state_dict(merged_sd)
    return evaluate_model(model, dataloader, args)
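

# ---------------------------------------------------------------------------
# Sanity-check sketch (illustrative; not part of the original module).
# `_slerp` is a hypothetical helper that re-states the per-key SLERP math
# from `evaluate_slerp` so its endpoint behavior can be checked on toy
# tensors, and the toy nn.Linear merge below checks the task-arithmetic
# identity 2*ft - pt == pt + 2*(ft - pt). Neither touches a real CLIP
# checkpoint or `evaluate_model`.
# ---------------------------------------------------------------------------
def _slerp(w1, w2, alpha):
    """Spherical interpolation of two flat weight vectors (same math as above)."""
    cos_val = torch.dot(w1, w2) / (w1.norm() * w2.norm() + 1e-8)
    omega = torch.acos(torch.clamp(cos_val, -1 + 1e-6, 1 - 1e-6))
    sin_omega = torch.sin(omega)
    if sin_omega < 1e-6:
        return (1 - alpha) * w1 + alpha * w2
    return (torch.sin((1 - alpha) * omega) / sin_omega) * w1 + \
           (torch.sin(alpha * omega) / sin_omega) * w2


if __name__ == "__main__":
    import torch.nn as nn

    torch.manual_seed(0)

    # SLERP endpoints: alpha=0 recovers the pt vector, alpha=1 the ft vector.
    w1, w2 = torch.randn(16), torch.randn(16)
    assert torch.allclose(_slerp(w1, w2, 0.0), w1, atol=1e-4)
    assert torch.allclose(_slerp(w1, w2, 1.0), w2, atol=1e-4)

    # Task arithmetic on a toy model: 2*ft - pt equals pt + 2*(ft - pt).
    sd_a = nn.Linear(4, 2).state_dict()
    sd_b = nn.Linear(4, 2).state_dict()
    merged = {k: 2 * sd_b[k] - sd_a[k] for k in sd_a}
    assert torch.allclose(
        merged["weight"],
        sd_a["weight"] + 2 * (sd_b["weight"] - sd_a["weight"]),
    )
    print("sanity checks passed")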