Upload utils.py with huggingface_hub
utils.py
ADDED
@@ -0,0 +1,119 @@
import torch
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import numpy as np
import cv2
import os

def get_clothes_mask(old_label):
    # Binary mask of the clothes region (parse label 3).
    clothes = torch.FloatTensor((old_label.cpu().numpy() == 3).astype(int))
    return clothes

def changearm(old_label):
    # Relabel both arm classes (5 and 6) as clothes (label 3).
    label = old_label
    arm1 = torch.FloatTensor((old_label.cpu().numpy() == 5).astype(int))
    arm2 = torch.FloatTensor((old_label.cpu().numpy() == 6).astype(int))
    label = label * (1 - arm1) + arm1 * 3
    label = label * (1 - arm2) + arm2 * 3
    return label

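# Illustrative usage sketch for the two helpers above (toy parse map; the shape is
# an assumption, not taken from this file):
#   parse = torch.randint(0, 7, (1, 1, 256, 192)).float()
#   clothes_mask = get_clothes_mask(parse)   # 1.0 where parse == 3
#   merged = changearm(parse)                # arm labels 5 and 6 merged into 3
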
def gen_noise(shape):
    noise = np.zeros(shape, dtype=np.uint8)
    # Fill with Gaussian noise, rescale by 255; the uint8 cast leaves a 0/1 map.
    noise = cv2.randn(noise, 0, 255)
    noise = np.asarray(noise / 255, dtype=np.uint8)
    noise = torch.tensor(noise, dtype=torch.float32)
    return noise

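# Illustrative usage sketch (the shape is an assumption for demonstration only):
#   noise = gen_noise((1, 256, 192))   # float32 tensor of 0/1 values
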
def cross_entropy2d(input, target, weight=None, size_average=True):
    n, c, h, w = input.size()
    nt, ht, wt = target.size()

    # Handle inconsistent size between input and target
    if h != ht or w != wt:
        input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)

    # Flatten to (N*H*W, C) logits and (N*H*W,) class indices.
    input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    target = target.view(-1)
    # `size_average` is deprecated in torch.nn.functional, so map it onto `reduction`.
    loss = F.cross_entropy(
        input, target, weight=weight,
        reduction="mean" if size_average else "sum",
        ignore_index=250,
    )
    return loss

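# Illustrative usage sketch (shapes are assumptions for demonstration only):
#   logits = torch.randn(4, 13, 256, 192)         # (N, C, H, W) raw scores
#   labels = torch.randint(0, 13, (4, 256, 192))  # (N, H, W) class indices
#   loss = cross_entropy2d(logits, labels)
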
def ndim_tensor2im(image_tensor, imtype=np.uint8, batch=0):
    # Collapse one (C, H, W) score map to an (H, W) label map via per-pixel argmax.
    image_numpy = image_tensor[batch].cpu().float().numpy()
    result = np.argmax(image_numpy, axis=0)
    return result.astype(imtype)

def visualize_segmap(input, multi_channel=True, tensor_out=True, batch=0):
    # Color a segmentation map with a fixed 20-entry palette.
    palette = [
        0, 0, 0, 128, 0, 0, 254, 0, 0, 0, 85, 0, 169, 0, 51,
        254, 85, 0, 0, 0, 85, 0, 119, 220, 85, 85, 0, 0, 85, 85,
        85, 51, 0, 52, 86, 128, 0, 128, 0, 0, 0, 254, 51, 169, 220,
        0, 254, 254, 85, 254, 169, 169, 254, 85, 254, 254, 0, 254, 169, 0
    ]
    input = input.detach()
    if multi_channel:
        # Multi-channel scores: reduce to a label map first.
        input = ndim_tensor2im(input, batch=batch)
    else:
        # Already a single-channel label map.
        input = input[batch][0].cpu()
        input = np.asarray(input)
        input = input.astype(np.uint8)
    input = Image.fromarray(input, 'P')
    input.putpalette(palette)

    if tensor_out:
        trans = transforms.ToTensor()
        return trans(input.convert('RGB'))

    return input

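# Illustrative usage sketch (assumes a 13-channel parse tensor, as used elsewhere in this file):
#   parse = torch.randn(1, 13, 256, 192)
#   rgb = visualize_segmap(parse, batch=0)   # (3, H, W) tensor with palette colors
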
def pred_to_onehot(prediction):
    # Convert (N, 13, H, W) scores to a one-hot map via per-pixel argmax.
    size = prediction.shape
    prediction_max = torch.argmax(prediction, dim=1)
    oneHot_size = (size[0], 13, size[2], size[3])
    pred_onehot = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
    pred_onehot = pred_onehot.scatter_(1, prediction_max.unsqueeze(1).data.long(), 1.0)
    return pred_onehot

def cal_miou(prediction, target):
    # IoU over classes 1-8, with intersections and unions pooled across the batch.
    size = prediction.shape
    target = target.cpu()
    prediction = pred_to_onehot(prediction.detach().cpu())
    classes = [1, 2, 3, 4, 5, 6, 7, 8]
    union = 0
    intersection = 0
    for b in range(size[0]):
        for c in classes:
            intersection += torch.logical_and(target[b, c], prediction[b, c]).sum()
            union += torch.logical_or(target[b, c], prediction[b, c]).sum()
    return intersection.item() / union.item()

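# Illustrative usage sketch (one-hot targets assumed, e.g. produced with pred_to_onehot):
#   pred = torch.randn(2, 13, 256, 192)
#   gt_onehot = pred_to_onehot(torch.randn(2, 13, 256, 192))
#   miou = cal_miou(pred, gt_onehot)   # pooled IoU over classes 1-8
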
def save_images(img_tensors, img_names, save_dir):
    for img_tensor, img_name in zip(img_tensors, img_names):
        # Map from [-1, 1] back to [0, 255].
        tensor = (img_tensor.clone() + 1) * 0.5 * 255
        tensor = tensor.cpu().clamp(0, 255)

        try:
            array = tensor.numpy().astype('uint8')
        except RuntimeError:
            # Tensors that still track gradients need an explicit detach.
            array = tensor.detach().numpy().astype('uint8')

        if array.shape[0] == 1:
            array = array.squeeze(0)
        elif array.shape[0] == 3:
            # CHW -> HWC for PIL.
            array = array.swapaxes(0, 1).swapaxes(1, 2)

        im = Image.fromarray(array)
        im.save(os.path.join(save_dir, img_name), format='PNG')

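# Illustrative usage sketch (assumes [-1, 1] image tensors and an existing directory):
#   imgs = torch.rand(2, 3, 256, 192) * 2 - 1
#   save_images(imgs, ['a.png', 'b.png'], './output')
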
def create_network(cls, opt):
    # Instantiate a network class, move it to the GPU if requested, and initialize weights.
    net = cls(opt)
    net.print_network()
    if len(opt.gpu_ids) > 0:
        assert torch.cuda.is_available()
        net.cuda()
    net.init_weights(opt.init_type, opt.init_variance)
    return net
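
# Illustrative usage sketch: `create_network` expects a class exposing print_network()
# and init_weights(), plus an options object carrying gpu_ids, init_type and
# init_variance. The names below are placeholders, not definitions from this file:
#   from types import SimpleNamespace
#   opt = SimpleNamespace(gpu_ids=[], init_type='xavier', init_variance=0.02)
#   net = create_network(SomeGenerator, opt)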