Spaces (paused)

jiang.2880 committed
Commit 30ac978 · 1 parent: f524781

5tasks

Files changed:
- app.py: +51 -17
- ckpt/epoch=71-step=228600.ckpt: +3 -0
- requirements.txt: +1 -2
app.py CHANGED

@@ -12,8 +12,50 @@ from net.mask_spatial_mixer_channel_easyshallow_harddeep_twomore_prompt import S
 from einops import rearrange
 import spaces
 
-from huggingface_hub import PyTorchModelHubMixin
 
+class SpatialMixerChannelEasyShallowHardDeepTwoMorePromptIRModel(pl.LightningModule):
+    def __init__(self):
+        super().__init__()
+
+        self.net = SpatialMixerChannelEasyShallowHardDeepTwoMorePrompt(
+            inp_channels=3,
+            out_channels=3,
+            dim = 48,
+            num_blocks = [2,4,4,4],
+            num_refinement_blocks = 4,
+            channel_heads = [1,1,1,1],
+            spatial_heads = [1,2,4,8],
+            overlap_ratio = 0.5,
+            dim_head = 16,
+            ratio = 0.5,
+            window_size = 8,
+            bias = False,
+            ffn_expansion_factor = 2.66,
+            LayerNorm_type = 'WithBias',  ## Other option 'BiasFree'
+            dual_pixel_task = False,  ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
+            scale = 1,
+            prompt = True,
+            hard_ratio = 0.5
+        )
+        self.loss_fn = nn.L1Loss()
+
+        # Initialize EMA
+        self.ema_decay = 0.999
+        self.ema_model = self.create_ema_model()
+
+    def create_ema_model(self):
+        ema_model = copy.deepcopy(self.net)
+        for param in ema_model.parameters():
+            param.detach_()
+        return ema_model
+
+    def update_ema_model(self):
+        with torch.no_grad():
+            for ema_param, model_param in zip(self.ema_model.parameters(), self.net.parameters()):
+                ema_param.data = self.ema_decay * ema_param.data + (1 - self.ema_decay) * model_param.data
+
+    def forward(self, x, training=False):
+        return self.net(x, training=training)
 
 
 # crop an image to the multiple of base
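The wrapper added above maintains an exponential moving average (EMA) of the network weights: update_ema_model() blends each parameter as ema = decay * ema + (1 - decay) * param, and the detached EMA copy is never trained directly. A minimal standalone sketch of the same scheme (TinyNet and the call site are illustrative stand-ins, not part of this commit):

import copy
import torch
import torch.nn as nn

class TinyNet(nn.Module):  # illustrative stand-in for the restoration network
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

net = TinyNet()
ema_net = copy.deepcopy(net)
for p in ema_net.parameters():
    p.detach_()  # the EMA copy receives no gradients

ema_decay = 0.999

@torch.no_grad()
def update_ema():
    # ema <- decay * ema + (1 - decay) * current weights
    for ema_p, p in zip(ema_net.parameters(), net.parameters()):
        ema_p.data.mul_(ema_decay).add_(p.data, alpha=1 - ema_decay)

update_ema()  # called once per optimizer step during training

At inference time it is the averaged weights (the checkpoint's ema_model_state_dict, loaded further down) that get restored, which is the usual reason for keeping an EMA copy.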
@@ -23,20 +65,6 @@ def crop_img(image, base=64):
     crop_h = h % base
     crop_w = w % base
     return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :]
-
-class CatAIR5tasksModel(
-    nn.Module,
-    PyTorchModelHubMixin,
-    repo_url="https://huggingface.co/jiachenjiang/Cat-AIR-5tasks",
-    pipeline_tag="image-to-image",
-    license="mit",
-):
-    def __init__(self, **kwargs):
-        super().__init__()
-        self.net = SpatialMixerChannelEasyShallowHardDeepTwoMorePrompt(**kwargs)
-
-    def forward(self, x):
-        return self.net(x)
 
 def np_to_pil(img_np):
     """
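For reference, crop_img (kept as context above) center-crops both spatial dimensions down to the nearest multiple of base. A quick self-contained check, assuming the elided opening of the function reads h, w = image.shape[:2]:

import numpy as np

def crop_img(image, base=64):
    # center-crop height and width down to the nearest multiple of `base`
    h, w = image.shape[:2]  # assumed elided line; not visible in the hunk
    crop_h = h % base
    crop_w = w % base
    return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :]

img = np.zeros((71, 130, 3), dtype=np.uint8)
print(crop_img(img, base=16).shape)  # (64, 128, 3): 71 -> 64, 130 -> 128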
@@ -73,8 +101,14 @@ def restore_image(input_img):
     np.random.seed(0)
     torch.manual_seed(0)
 
-
-
+
+    ckpt_path = "/local/storage1/jiang.2880/Cat-AIR/ckpt/epoch=71-step=228600.ckpt"
+    print("CKPT name : {}".format(ckpt_path))
+
+    checkpoint = torch.load(ckpt_path, map_location='cuda', weights_only=True)
+    model = SpatialMixerChannelEasyShallowHardDeepTwoMorePromptIRModel()
+    model.net.load_state_dict(checkpoint['ema_model_state_dict'])
+    net = model.net.cuda()
 
 
     degraded_img = crop_img(np.array(input_img.convert('RGB')), base=16)
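Note that the committed ckpt_path is an absolute /local/storage1/... path, while this same commit ships the checkpoint at the repo-relative ckpt/epoch=71-step=228600.ckpt. A hedged sketch of loading the EMA weights portably, assuming the wrapper class from the first hunk is in scope (the repo-relative path and CPU fallback are additions, not what the commit does):

import os
import torch

# resolve the checkpoint relative to app.py instead of a hard-coded machine path
ckpt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                         "ckpt", "epoch=71-step=228600.ckpt")
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)

model = SpatialMixerChannelEasyShallowHardDeepTwoMorePromptIRModel()
model.net.load_state_dict(checkpoint['ema_model_state_dict'])  # EMA weights saved during training
net = model.net.to(device).eval()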
ckpt/epoch=71-step=228600.ckpt ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1586ac517148dc24dd1588af6aa11205f8628253c4b7719211bd7af6568fb325
+size 535261296
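The added file is a Git LFS pointer, not the ~535 MB checkpoint itself: only the oid and size are committed, and the Hub substitutes the real blob when the file is downloaded. A small sketch that reads those fields back from the pointer (field layout per the git-lfs spec URL in the file):

def parse_lfs_pointer(path):
    # each pointer line is "key value", e.g. "size 535261296"
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

ptr = parse_lfs_pointer("ckpt/epoch=71-step=228600.ckpt")
print(ptr["oid"], int(ptr["size"]))  # sha256:1586ac... 535261296

If the repo is cloned without git-lfs installed, torch.load on this path fails, because the file on disk is just this three-line pointer text.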
requirements.txt CHANGED

@@ -7,5 +7,4 @@ pytorch-lightning
 pydantic
 pillow
 einops
-numpy<2
-huggingface_hub
+numpy<2