Ubaida10 commited on
Commit
8c1552d
·
verified ·
1 Parent(s): 0f0b64c

Upload utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. utils.py +119 -0
utils.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import cv2
7
+ import os
8
+
9
+ def get_clothes_mask(old_label) :
10
+ clothes = torch.FloatTensor((old_label.cpu().numpy() == 3).astype(np.int))
11
+ return clothes
12
+
13
+ def changearm(old_label):
14
+ label=old_label
15
+ arm1=torch.FloatTensor((old_label.cpu().numpy()==5).astype(np.int))
16
+ arm2=torch.FloatTensor((old_label.cpu().numpy()==6).astype(np.int))
17
+ label=label*(1-arm1)+arm1*3
18
+ label=label*(1-arm2)+arm2*3
19
+ return label
20
+
21
+ def gen_noise(shape):
22
+ noise = np.zeros(shape, dtype=np.uint8)
23
+ ### noise
24
+ noise = cv2.randn(noise, 0, 255)
25
+ noise = np.asarray(noise / 255, dtype=np.uint8)
26
+ noise = torch.tensor(noise, dtype=torch.float32)
27
+ return noise
28
+
29
+ def cross_entropy2d(input, target, weight=None, size_average=True):
30
+ n, c, h, w = input.size()
31
+ nt, ht, wt = target.size()
32
+
33
+ # Handle inconsistent size between input and target
34
+ if h != ht or w != wt:
35
+ input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)
36
+
37
+ input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
38
+ target = target.view(-1)
39
+ loss = F.cross_entropy(
40
+ input, target, weight=weight, size_average=size_average, ignore_index=250
41
+ )
42
+ return loss
43
+
44
+ def ndim_tensor2im(image_tensor, imtype=np.uint8, batch=0):
45
+ image_numpy = image_tensor[batch].cpu().float().numpy()
46
+ result = np.argmax(image_numpy, axis=0)
47
+ return result.astype(imtype)
48
+
49
+ def visualize_segmap(input, multi_channel=True, tensor_out=True, batch=0) :
50
+ palette = [
51
+ 0, 0, 0, 128, 0, 0, 254, 0, 0, 0, 85, 0, 169, 0, 51,
52
+ 254, 85, 0, 0, 0, 85, 0, 119, 220, 85, 85, 0, 0, 85, 85,
53
+ 85, 51, 0, 52, 86, 128, 0, 128, 0, 0, 0, 254, 51, 169, 220,
54
+ 0, 254, 254, 85, 254, 169, 169, 254, 85, 254, 254, 0, 254, 169, 0
55
+ ]
56
+ input = input.detach()
57
+ if multi_channel :
58
+ input = ndim_tensor2im(input,batch=batch)
59
+ else :
60
+ input = input[batch][0].cpu()
61
+ input = np.asarray(input)
62
+ input = input.astype(np.uint8)
63
+ input = Image.fromarray(input, 'P')
64
+ input.putpalette(palette)
65
+
66
+ if tensor_out :
67
+ trans = transforms.ToTensor()
68
+ return trans(input.convert('RGB'))
69
+
70
+ return input
71
+
72
+ def pred_to_onehot(prediction) :
73
+ size = prediction.shape
74
+ prediction_max = torch.argmax(prediction, dim=1)
75
+ oneHot_size = (size[0], 13, size[2], size[3])
76
+ pred_onehot = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
77
+ pred_onehot = pred_onehot.scatter_(1, prediction_max.unsqueeze(1).data.long(), 1.0)
78
+ return pred_onehot
79
+
80
+ def cal_miou(prediction, target) :
81
+ size = prediction.shape
82
+ target = target.cpu()
83
+ prediction = pred_to_onehot(prediction.detach().cpu())
84
+ list = [1,2,3,4,5,6,7,8]
85
+ union = 0
86
+ intersection = 0
87
+ for b in range(size[0]) :
88
+ for c in list :
89
+ intersection += torch.logical_and(target[b,c], prediction[b,c]).sum()
90
+ union += torch.logical_or(target[b,c], prediction[b,c]).sum()
91
+ return intersection.item()/union.item()
92
+
93
+ def save_images(img_tensors, img_names, save_dir):
94
+ for img_tensor, img_name in zip(img_tensors, img_names):
95
+ tensor = (img_tensor.clone() + 1) * 0.5 * 255
96
+ tensor = tensor.cpu().clamp(0, 255)
97
+
98
+ try:
99
+ array = tensor.numpy().astype('uint8')
100
+ except:
101
+ array = tensor.detach().numpy().astype('uint8')
102
+
103
+ if array.shape[0] == 1:
104
+ array = array.squeeze(0)
105
+ elif array.shape[0] == 3:
106
+ array = array.swapaxes(0, 1).swapaxes(1, 2)
107
+
108
+ im = Image.fromarray(array)
109
+ im.save(os.path.join(save_dir, img_name), format='PNG')
110
+
111
+
112
+ def create_network(cls, opt):
113
+ net = cls(opt)
114
+ net.print_network()
115
+ if len(opt.gpu_ids) > 0:
116
+ assert(torch.cuda.is_available())
117
+ net.cuda()
118
+ net.init_weights(opt.init_type, opt.init_variance)
119
+ return net