import copy
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors import safe_open
from safetensors.torch import save_file
from torch.nn.parameter import Parameter

from depth_anything_v2_metric.depth_anything_v2.dpt import DepthAnythingV2


class _LoRA_qkv(nn.Module):
    """LoRA adapter for the fused qkv projection of a ViT attention block.

    In SAM (and likewise in the DINOv2 encoder used here) the attention block
    computes the fused projection as:

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

    so the first `dim` output channels hold q and the last `dim` channels hold v.
    """

    def __init__(
        self,
        qkv: nn.Module,
        linear_a_q: nn.Module,
        linear_b_q: nn.Module,
        linear_a_v: nn.Module,
        linear_b_v: nn.Module,
    ):
        super().__init__()
        self.qkv = qkv
        self.linear_a_q = linear_a_q
        self.linear_b_q = linear_b_q
        self.linear_a_v = linear_a_v
        self.linear_b_v = linear_b_v
        self.dim = qkv.in_features
        self.w_identity = torch.eye(qkv.in_features)

    def forward(self, x):
        # Frozen fused projection: (B, N, 3 * dim).
        qkv = self.qkv(x)
        # Low-rank updates B(A(x)) for the q and v projections only.
        new_q = self.linear_b_q(self.linear_a_q(x))
        new_v = self.linear_b_v(self.linear_a_v(x))
        # q occupies the first `dim` channels, v the last `dim`; k is left untouched.
        qkv[:, :, : self.dim] += new_q
        qkv[:, :, -self.dim :] += new_v
        return qkv
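

# A minimal, self-contained sanity sketch (not part of the original module): it
# verifies that adding linear_b(linear_a(x)) to the q slice of the fused qkv
# output is equivalent to projecting with a merged weight W_q + W_b @ W_a, i.e.
# the standard LoRA formulation. The helper name and dimensions are illustrative.
def _check_lora_qkv_equivalence(dim: int = 8, r: int = 2) -> bool:
    base_qkv = nn.Linear(dim, dim * 3, bias=False)
    a_q, b_q = nn.Linear(dim, r, bias=False), nn.Linear(r, dim, bias=False)
    a_v, b_v = nn.Linear(dim, r, bias=False), nn.Linear(r, dim, bias=False)
    wrapped = _LoRA_qkv(base_qkv, a_q, b_q, a_v, b_v)
    x = torch.randn(1, 4, dim)
    with torch.no_grad():
        out = wrapped(x)
        merged_w_q = base_qkv.weight[:dim] + b_q.weight @ a_q.weight
        return torch.allclose(out[:, :, :dim], x @ merged_w_q.T, atol=1e-6)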


class LoRA(nn.Module):
    """Base class holding the LoRA A/B matrices and their (de)serialization helpers.

    The fc helpers assume the wrapped model exposes a linear classifier head at
    ``self.lora_vit.head`` (as a plain ViT does).
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def save_fc_parameters(self, filename: str) -> None:
        r"""Save the classifier-head weight. Only safetensors is supported for now.

        Run ``pip install safetensors`` if you do not have it installed yet.
        """
        assert filename.endswith(".safetensors")
        _in = self.lora_vit.head.in_features
        _out = self.lora_vit.head.out_features
        fc_tensors = {f"fc_{_in}in_{_out}out": self.lora_vit.head.weight}
        save_file(fc_tensors, filename)

    def load_fc_parameters(self, filename: str) -> None:
        r"""Load the classifier-head weight. Only safetensors is supported for now.

        Run ``pip install safetensors`` if you do not have it installed yet.
        """
        assert filename.endswith(".safetensors")
        _in = self.lora_vit.head.in_features
        _out = self.lora_vit.head.out_features
        with safe_open(filename, framework="pt") as f:
            saved_key = f"fc_{_in}in_{_out}out"
            try:
                saved_tensor = f.get_tensor(saved_key)
                self.lora_vit.head.weight = Parameter(saved_tensor)
            except ValueError:
                print("this fc weight is not for this model")

    def save_lora_parameters(self, filename: str) -> None:
        r"""Save both the LoRA and the fc parameters. Only safetensors is supported for now.

        Run ``pip install safetensors`` if you do not have it installed yet.
        """
        assert filename.endswith(".safetensors")

        num_layer = len(self.w_As)
        a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)}
        b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)}

        _in = self.lora_vit.head.in_features
        _out = self.lora_vit.head.out_features
        fc_tensors = {f"fc_{_in}in_{_out}out": self.lora_vit.head.weight}

        merged_dict = {**a_tensors, **b_tensors, **fc_tensors}
        save_file(merged_dict, filename)

    def load_lora_parameters(self, filename: str) -> None:
        r"""Load both the LoRA and the fc parameters. Only safetensors is supported for now.

        Run ``pip install safetensors`` if you do not have it installed yet.
        """
        assert filename.endswith(".safetensors")

        with safe_open(filename, framework="pt") as f:
            for i, w_A_linear in enumerate(self.w_As):
                saved_key = f"w_a_{i:03d}"
                saved_tensor = f.get_tensor(saved_key)
                w_A_linear.weight = Parameter(saved_tensor)

            for i, w_B_linear in enumerate(self.w_Bs):
                saved_key = f"w_b_{i:03d}"
                saved_tensor = f.get_tensor(saved_key)
                w_B_linear.weight = Parameter(saved_tensor)

            _in = self.lora_vit.head.in_features
            _out = self.lora_vit.head.out_features
            saved_key = f"fc_{_in}in_{_out}out"
            try:
                saved_tensor = f.get_tensor(saved_key)
                self.lora_vit.head.weight = Parameter(saved_tensor)
            except ValueError:
                print("this fc weight is not for this model")

    def reset_parameters(self) -> None:
        # Standard LoRA init: A gets a Kaiming-uniform init, B starts at zero so
        # the adapted model matches the frozen model exactly before training.
        for w_A in self.w_As:
            nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
        for w_B in self.w_Bs:
            nn.init.zeros_(w_B.weight)
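

# A short usage sketch (an assumption, not part of the original module): it saves
# and reloads only the LoRA A/B matrices, using the same "w_a_###"/"w_b_###"
# naming as save_lora_parameters/load_lora_parameters. It skips the fc entries of
# those helpers, which assume the wrapped model exposes a ViT-style `head` linear
# layer that a depth model may not have. The function name and path are illustrative.
def _save_lora_ab_only(lora_model: LoRA, filename: str = "lora_ab.safetensors") -> None:
    tensors = {f"w_a_{i:03d}": w.weight for i, w in enumerate(lora_model.w_As)}
    tensors.update({f"w_b_{i:03d}": w.weight for i, w in enumerate(lora_model.w_Bs)})
    save_file(tensors, filename)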


class LoRA_Depth_Anything_v2(LoRA):
    """Applies low-rank adaptation (LoRA) to the image encoder of a Depth Anything V2 model.

    Args:
        da_model: a DepthAnythingV2 model whose DINOv2 encoder (``da_model.pretrained``) is adapted.
        r: rank of the LoRA update matrices.
        lora_layer: indices of the transformer blocks to apply LoRA to; defaults to all blocks.

    Examples::
        >>> lora_model = LoRA_Depth_Anything_v2(da_model, r=4)
        >>> depth = lora_model.lora_vit(img)
    """

    def __init__(self, da_model: DepthAnythingV2, r: int, lora_layer=None):
        super(LoRA_Depth_Anything_v2, self).__init__()

        assert r > 0

        if lora_layer:
            self.lora_layer = lora_layer
        else:
            # By default, apply LoRA to every transformer block of the encoder.
            self.lora_layer = list(range(len(da_model.pretrained.blocks)))

        # Storage for the trainable LoRA projection layers.
        self.w_As = []
        self.w_Bs = []

        # Freeze the whole image encoder; only the LoRA layers added below
        # (and the rest of the depth model) remain trainable.
        for param in da_model.pretrained.parameters():
            param.requires_grad = False

        for t_layer_i, blk in enumerate(da_model.pretrained.blocks):
            # Only adapt the selected layers.
            if t_layer_i not in self.lora_layer:
                continue
            w_qkv_linear = blk.attn.qkv
            self.dim = w_qkv_linear.in_features
            w_a_linear_q = nn.Linear(self.dim, r, bias=False)
            w_b_linear_q = nn.Linear(r, self.dim, bias=False)
            w_a_linear_v = nn.Linear(self.dim, r, bias=False)
            w_b_linear_v = nn.Linear(r, self.dim, bias=False)
            self.w_As.append(w_a_linear_q)
            self.w_Bs.append(w_b_linear_q)
            self.w_As.append(w_a_linear_v)
            self.w_Bs.append(w_b_linear_v)
            # Replace the block's fused qkv projection with the LoRA-wrapped version.
            blk.attn.qkv = _LoRA_qkv(
                w_qkv_linear,
                w_a_linear_q,
                w_b_linear_q,
                w_a_linear_v,
                w_b_linear_v,
            )
        self.reset_parameters()

        self.lora_vit = da_model
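

# A hedged end-to-end usage sketch (not part of the original module). The
# DepthAnythingV2 constructor arguments below follow the 'vitl' configuration
# published in the Depth-Anything-V2 repository and are an assumption here;
# adjust them (and any metric-depth options such as max_depth) to match your checkpoint.
if __name__ == "__main__":
    da_model = DepthAnythingV2(
        encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024]
    )
    lora_model = LoRA_Depth_Anything_v2(da_model, r=4)

    # After wrapping, only the LoRA A/B matrices and the non-encoder parts of the
    # model (e.g. the depth head) should remain trainable.
    trainable = sum(p.numel() for p in lora_model.lora_vit.parameters() if p.requires_grad)
    total = sum(p.numel() for p in lora_model.lora_vit.parameters())
    print(f"trainable params: {trainable} / {total}")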