jiang.2880 commited on
Commit
30ac978
·
1 Parent(s): f524781
Files changed (3) hide show
  1. app.py +51 -17
  2. ckpt/epoch=71-step=228600.ckpt +3 -0
  3. requirements.txt +1 -2
app.py CHANGED
@@ -12,8 +12,50 @@ from net.mask_spatial_mixer_channel_easyshallow_harddeep_twomore_prompt import S
12
  from einops import rearrange
13
  import spaces
14
 
15
- from huggingface_hub import PyTorchModelHubMixin
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  # crop an image to the multiple of base
@@ -23,20 +65,6 @@ def crop_img(image, base=64):
23
  crop_h = h % base
24
  crop_w = w % base
25
  return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :]
26
-
27
- class CatAIR5tasksModel(
28
- nn.Module,
29
- PyTorchModelHubMixin,
30
- repo_url="https://huggingface.co/jiachenjiang/Cat-AIR-5tasks",
31
- pipeline_tag="image-to-image",
32
- license="mit",
33
- ):
34
- def __init__(self, **kwargs):
35
- super().__init__()
36
- self.net = SpatialMixerChannelEasyShallowHardDeepTwoMorePrompt(**kwargs)
37
-
38
- def forward(self, x):
39
- return self.net(x)
40
 
41
  def np_to_pil(img_np):
42
  """
@@ -73,8 +101,14 @@ def restore_image(input_img):
73
  np.random.seed(0)
74
  torch.manual_seed(0)
75
 
76
- model = CatAIR5tasksModel.from_pretrained("jiachenjiang/Cat-AIR-5tasks").cuda()
77
- net = model.net.eval()
 
 
 
 
 
 
78
 
79
 
80
  degraded_img = crop_img(np.array(input_img.convert('RGB')), base=16)
 
12
  from einops import rearrange
13
  import spaces
14
 
 
15
 
16
+ class SpatialMixerChannelEasyShallowHardDeepTwoMorePromptIRModel(pl.LightningModule):
17
+ def __init__(self):
18
+ super().__init__()
19
+
20
+ self.net = SpatialMixerChannelEasyShallowHardDeepTwoMorePrompt(
21
+ inp_channels=3,
22
+ out_channels=3,
23
+ dim = 48,
24
+ num_blocks = [2,4,4,4],
25
+ num_refinement_blocks = 4,
26
+ channel_heads = [1,1,1,1],
27
+ spatial_heads = [1,2,4,8],
28
+ overlap_ratio = 0.5,
29
+ dim_head = 16,
30
+ ratio = 0.5,
31
+ window_size = 8,
32
+ bias = False,
33
+ ffn_expansion_factor = 2.66,
34
+ LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
35
+ dual_pixel_task = False, ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
36
+ scale = 1,
37
+ prompt = True,
38
+ hard_ratio = 0.5
39
+ )
40
+ self.loss_fn = nn.L1Loss()
41
+
42
+ # Initialize EMA
43
+ self.ema_decay = 0.999
44
+ self.ema_model = self.create_ema_model()
45
+
46
+ def create_ema_model(self):
47
+ ema_model = copy.deepcopy(self.net)
48
+ for param in ema_model.parameters():
49
+ param.detach_()
50
+ return ema_model
51
+
52
+ def update_ema_model(self):
53
+ with torch.no_grad():
54
+ for ema_param, model_param in zip(self.ema_model.parameters(), self.net.parameters()):
55
+ ema_param.data = self.ema_decay * ema_param.data + (1 - self.ema_decay) * model_param.data
56
+
57
+ def forward(self,x, training = False):
58
+ return self.net(x, training = training)
59
 
60
 
61
  # crop an image to the multiple of base
 
65
  crop_h = h % base
66
  crop_w = w % base
67
  return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  def np_to_pil(img_np):
70
  """
 
101
  np.random.seed(0)
102
  torch.manual_seed(0)
103
 
104
+
105
+ ckpt_path = "/local/storage1/jiang.2880/Cat-AIR/ckpt/epoch=71-step=228600.ckpt"
106
+ print("CKPT name : {}".format(ckpt_path))
107
+
108
+ checkpoint = torch.load(ckpt_path, map_location='cuda', weights_only=True)
109
+ model = SpatialMixerChannelEasyShallowHardDeepTwoMorePromptIRModel()
110
+ model.net.load_state_dict(checkpoint['ema_model_state_dict'])
111
+ net = model.net.cuda()
112
 
113
 
114
  degraded_img = crop_img(np.array(input_img.convert('RGB')), base=16)
ckpt/epoch=71-step=228600.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1586ac517148dc24dd1588af6aa11205f8628253c4b7719211bd7af6568fb325
3
+ size 535261296
requirements.txt CHANGED
@@ -7,5 +7,4 @@ pytorch-lightning
7
  pydantic
8
  pillow
9
  einops
10
- numpy<2
11
- huggingface_hub
 
7
  pydantic
8
  pillow
9
  einops
10
+ numpy<2