HichTala commited on
Commit
0916c25
·
1 Parent(s): e6a6ef1

Upload code

Browse files
head.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ from dataclasses import astuple
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn.modules.transformer import _get_activation_fn
8
+ from torchvision.ops import RoIAlign
9
+
10
+ _DEFAULT_SCALE_CLAMP = math.log(1000.0 / 16)
11
+
12
+ def convert_boxes_to_pooler_format(bboxes):
13
+ bs, num_proposals = bboxes.shape[:2]
14
+ sizes = torch.full((bs,), num_proposals).to(bboxes.device)
15
+ aggregated_bboxes = bboxes.view(bs * num_proposals, -1)
16
+ indices = torch.repeat_interleave(
17
+ torch.arange(len(sizes), dtype=aggregated_bboxes.dtype, device=aggregated_bboxes.device), sizes
18
+ )
19
+ return torch.cat([indices[:, None], aggregated_bboxes], dim=1)
20
+
21
+
22
+ def assign_boxes_to_levels(
23
+ bboxes,
24
+ min_level,
25
+ max_level,
26
+ canonical_box_size,
27
+ canonical_level,
28
+ ):
29
+ aggregated_bboxes = bboxes.view(bboxes.shape[0] * bboxes.shape[1], -1)
30
+ area = (aggregated_bboxes[:, 2] - aggregated_bboxes[:, 0]) * (aggregated_bboxes[:, 3] - aggregated_bboxes[:, 1])
31
+ box_sizes = torch.sqrt(area)
32
+ # Eqn.(1) in FPN paper
33
+ level_assignments = torch.floor(canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8))
34
+ # clamp level to (min, max), in case the box size is too large or too small
35
+ # for the available feature maps
36
+ level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level)
37
+ return level_assignments.to(torch.int64) - min_level
38
+
39
+
40
+ class SinusoidalPositionEmbeddings(nn.Module):
41
+ def __init__(self, dim):
42
+ super().__init__()
43
+ self.dim = dim
44
+
45
+ def forward(self, time):
46
+ device = time.device
47
+ half_dim = self.dim // 2
48
+ embeddings = math.log(10000) / (half_dim - 1)
49
+ embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
50
+ embeddings = time[:, None] * embeddings[None, :]
51
+ embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
52
+ return embeddings
53
+
54
+
55
+ class HeadDynamicK(nn.Module):
56
+ def __init__(self, config, roi_input_shape):
57
+ super().__init__()
58
+ num_classes = config.num_labels
59
+
60
+ ddet_head = DiffusionDetHead(config, roi_input_shape, num_classes)
61
+ self.num_head = config.num_heads
62
+ self.head_series = nn.ModuleList([copy.deepcopy(ddet_head) for _ in range(self.num_head)])
63
+ self.return_intermediate = config.deep_supervision
64
+
65
+ # Gaussian random feature embedding layer for time
66
+ self.hidden_dim = config.hidden_dim
67
+ time_dim = self.hidden_dim * 4
68
+ self.time_mlp = nn.Sequential(
69
+ SinusoidalPositionEmbeddings(self.hidden_dim),
70
+ nn.Linear(self.hidden_dim, time_dim),
71
+ nn.GELU(),
72
+ nn.Linear(time_dim, time_dim),
73
+ )
74
+
75
+ # Init parameters.
76
+ self.use_focal = config.use_focal
77
+ self.use_fed_loss = config.use_fed_loss
78
+ self.num_classes = num_classes
79
+ if self.use_focal or self.use_fed_loss:
80
+ prior_prob = config.prior_prob
81
+ self.bias_value = -math.log((1 - prior_prob) / prior_prob)
82
+ self._reset_parameters()
83
+
84
+ def _reset_parameters(self):
85
+ # init all parameters.
86
+ for p in self.parameters():
87
+ if p.dim() > 1:
88
+ nn.init.xavier_uniform_(p)
89
+
90
+ # initialize the bias for focal loss and fed loss.
91
+ if self.use_focal or self.use_fed_loss:
92
+ if p.shape[-1] == self.num_classes or p.shape[-1] == self.num_classes + 1:
93
+ nn.init.constant_(p, self.bias_value)
94
+
95
+
96
+ def forward(self, features, bboxes, t):
97
+ # assert t shape (batch_size)
98
+ time = self.time_mlp(t)
99
+
100
+ inter_class_logits = []
101
+ inter_pred_bboxes = []
102
+
103
+ bs = len(features[0])
104
+
105
+ class_logits, pred_bboxes = None, None
106
+ for head_idx, ddet_head in enumerate(self.head_series):
107
+ class_logits, pred_bboxes, proposal_features = ddet_head(features, bboxes, time)
108
+ if self.return_intermediate:
109
+ inter_class_logits.append(class_logits)
110
+ inter_pred_bboxes.append(pred_bboxes)
111
+ bboxes = pred_bboxes.detach()
112
+
113
+ if self.return_intermediate:
114
+ return torch.stack(inter_class_logits), torch.stack(inter_pred_bboxes)
115
+
116
+ return class_logits[None], pred_bboxes[None]
117
+
118
+
119
+ class DynamicConv(nn.Module):
120
+ def __init__(self, config):
121
+ super().__init__()
122
+
123
+ self.hidden_dim = config.hidden_dim
124
+ self.dim_dynamic = config.dim_dynamic
125
+ self.num_dynamic = config.num_dynamic
126
+ self.num_params = self.hidden_dim * self.dim_dynamic
127
+ self.dynamic_layer = nn.Linear(self.hidden_dim, self.num_dynamic * self.num_params)
128
+
129
+ self.norm1 = nn.LayerNorm(self.dim_dynamic)
130
+ self.norm2 = nn.LayerNorm(self.hidden_dim)
131
+
132
+ self.activation = nn.ReLU(inplace=True)
133
+
134
+ pooler_resolution = config.pooler_resolution
135
+ num_output = self.hidden_dim * pooler_resolution ** 2
136
+ self.out_layer = nn.Linear(num_output, self.hidden_dim)
137
+ self.norm3 = nn.LayerNorm(self.hidden_dim)
138
+
139
+
140
+ def forward(self, pro_features, roi_features):
141
+ features = roi_features.permute(1, 0, 2)
142
+ parameters = self.dynamic_layer(pro_features).permute(1, 0, 2)
143
+
144
+ param1 = parameters[:, :, :self.num_params].view(-1, self.hidden_dim, self.dim_dynamic)
145
+ param2 = parameters[:, :, self.num_params:].view(-1, self.dim_dynamic, self.hidden_dim)
146
+
147
+ features = torch.bmm(features, param1)
148
+ features = self.norm1(features)
149
+ features = self.activation(features)
150
+
151
+ features = torch.bmm(features, param2)
152
+ features = self.norm2(features)
153
+ features = self.activation(features)
154
+
155
+ features = features.flatten(1)
156
+ features = self.out_layer(features)
157
+ features = self.norm3(features)
158
+ features = self.activation(features)
159
+
160
+ return features
161
+
162
+
163
+ class DiffusionDetHead(nn.Module):
164
+ def __init__(self, config, roi_input_shape, num_classes):
165
+ super().__init__()
166
+
167
+ dim_feedforward = config.dim_feedforward
168
+ nhead = config.num_attn_heads
169
+ dropout = config.dropout
170
+ activation = config.activation
171
+ in_features = config.roi_head_in_features
172
+ pooler_resolution = config.pooler_resolution
173
+ pooler_scales = tuple(1.0 / roi_input_shape[k]['stride'] for k in in_features)
174
+ sampling_ratio = config.sampling_ratio
175
+
176
+ self.hidden_dim = config.hidden_dim
177
+
178
+ self.pooler = ROIPooler(
179
+ output_size=pooler_resolution,
180
+ scales=pooler_scales,
181
+ sampling_ratio=sampling_ratio,
182
+ )
183
+
184
+ # dynamic.
185
+ self.self_attn = nn.MultiheadAttention(self.hidden_dim, nhead, dropout=dropout)
186
+ self.inst_interact = DynamicConv(config)
187
+
188
+ self.linear1 = nn.Linear(self.hidden_dim, dim_feedforward)
189
+ self.dropout = nn.Dropout(dropout)
190
+ self.linear2 = nn.Linear(dim_feedforward, self.hidden_dim)
191
+
192
+ self.norm1 = nn.LayerNorm(self.hidden_dim)
193
+ self.norm2 = nn.LayerNorm(self.hidden_dim)
194
+ self.norm3 = nn.LayerNorm(self.hidden_dim)
195
+ self.dropout1 = nn.Dropout(dropout)
196
+ self.dropout2 = nn.Dropout(dropout)
197
+ self.dropout3 = nn.Dropout(dropout)
198
+
199
+ self.activation = _get_activation_fn(activation)
200
+
201
+ # block time mlp
202
+ self.block_time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(self.hidden_dim * 4, self.hidden_dim * 2))
203
+
204
+ # cls.
205
+ num_cls = config.num_cls
206
+ cls_module = list()
207
+ for _ in range(num_cls):
208
+ cls_module.append(nn.Linear(self.hidden_dim, self.hidden_dim, False))
209
+ cls_module.append(nn.LayerNorm(self.hidden_dim))
210
+ cls_module.append(nn.ReLU(inplace=True))
211
+ self.cls_module = nn.ModuleList(cls_module)
212
+
213
+ # reg.
214
+ num_reg = config.num_reg
215
+ reg_module = list()
216
+ for _ in range(num_reg):
217
+ reg_module.append(nn.Linear(self.hidden_dim, self.hidden_dim, False))
218
+ reg_module.append(nn.LayerNorm(self.hidden_dim))
219
+ reg_module.append(nn.ReLU(inplace=True))
220
+ self.reg_module = nn.ModuleList(reg_module)
221
+
222
+ # pred.
223
+ self.use_focal = config.use_focal
224
+ self.use_fed_loss = config.use_fed_loss
225
+ if self.use_focal or self.use_fed_loss:
226
+ self.class_logits = nn.Linear(self.hidden_dim, num_classes)
227
+ else:
228
+ self.class_logits = nn.Linear(self.hidden_dim, num_classes + 1)
229
+ self.bboxes_delta = nn.Linear(self.hidden_dim, 4)
230
+ self.scale_clamp = _DEFAULT_SCALE_CLAMP
231
+ self.bbox_weights = (2.0, 2.0, 1.0, 1.0)
232
+
233
+ def forward(self, features, bboxes, time_emb):
234
+ bs, num_proposals = bboxes.shape[:2]
235
+
236
+ # roi_feature.
237
+ roi_features = self.pooler(features, bboxes)
238
+
239
+ pro_features = roi_features.view(bs, num_proposals, self.hidden_dim, -1).mean(-1)
240
+
241
+ roi_features = roi_features.view(bs * num_proposals, self.hidden_dim, -1).permute(2, 0, 1)
242
+
243
+ # self_att.
244
+ pro_features = pro_features.view(bs, num_proposals, self.hidden_dim).permute(1, 0, 2)
245
+ pro_features2 = self.self_attn(pro_features, pro_features, value=pro_features)[0]
246
+ pro_features = pro_features + self.dropout1(pro_features2)
247
+ pro_features = self.norm1(pro_features)
248
+
249
+ # inst_interact.
250
+ pro_features = pro_features.view(num_proposals, bs, self.hidden_dim).permute(1, 0, 2).reshape(1, bs * num_proposals,
251
+ self.hidden_dim)
252
+ pro_features2 = self.inst_interact(pro_features, roi_features)
253
+ pro_features = pro_features + self.dropout2(pro_features2)
254
+ obj_features = self.norm2(pro_features)
255
+
256
+ # obj_feature.
257
+ obj_features2 = self.linear2(self.dropout(self.activation(self.linear1(obj_features))))
258
+ obj_features = obj_features + self.dropout3(obj_features2)
259
+ obj_features = self.norm3(obj_features)
260
+
261
+ fc_feature = obj_features.transpose(0, 1).reshape(bs * num_proposals, -1)
262
+
263
+ scale_shift = self.block_time_mlp(time_emb)
264
+ scale_shift = torch.repeat_interleave(scale_shift, num_proposals, dim=0)
265
+ scale, shift = scale_shift.chunk(2, dim=1)
266
+ fc_feature = fc_feature * (scale + 1) + shift
267
+
268
+ cls_feature = fc_feature.clone()
269
+ reg_feature = fc_feature.clone()
270
+ for cls_layer in self.cls_module:
271
+ cls_feature = cls_layer(cls_feature)
272
+ for reg_layer in self.reg_module:
273
+ reg_feature = reg_layer(reg_feature)
274
+ class_logits = self.class_logits(cls_feature)
275
+ bboxes_deltas = self.bboxes_delta(reg_feature)
276
+ pred_bboxes = self.apply_deltas(bboxes_deltas, bboxes.view(-1, 4))
277
+
278
+ return class_logits.view(bs, num_proposals, -1), pred_bboxes.view(bs, num_proposals, -1), obj_features
279
+
280
+ def apply_deltas(self, deltas, boxes):
281
+ """
282
+ Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.
283
+
284
+ Args:
285
+ deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
286
+ deltas[i] represents k potentially different class-specific
287
+ box transformations for the single box boxes[i].
288
+ boxes (Tensor): boxes to transform, of shape (N, 4)
289
+ """
290
+ boxes = boxes.to(deltas.dtype)
291
+
292
+ widths = boxes[:, 2] - boxes[:, 0]
293
+ heights = boxes[:, 3] - boxes[:, 1]
294
+ ctr_x = boxes[:, 0] + 0.5 * widths
295
+ ctr_y = boxes[:, 1] + 0.5 * heights
296
+
297
+ wx, wy, ww, wh = self.bbox_weights
298
+ dx = deltas[:, 0::4] / wx
299
+ dy = deltas[:, 1::4] / wy
300
+ dw = deltas[:, 2::4] / ww
301
+ dh = deltas[:, 3::4] / wh
302
+
303
+ # Prevent sending too large values into torch.exp()
304
+ dw = torch.clamp(dw, max=self.scale_clamp)
305
+ dh = torch.clamp(dh, max=self.scale_clamp)
306
+
307
+ pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
308
+ pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
309
+ pred_w = torch.exp(dw) * widths[:, None]
310
+ pred_h = torch.exp(dh) * heights[:, None]
311
+
312
+ pred_boxes = torch.zeros_like(deltas)
313
+ pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w # x1
314
+ pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h # y1
315
+ pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w # x2
316
+ pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h # y2
317
+
318
+ return pred_boxes
319
+
320
+
321
+ class ROIPooler(nn.Module):
322
+ """
323
+ Region of interest feature map pooler that supports pooling from one or more
324
+ feature maps.
325
+ """
326
+
327
+ def __init__(
328
+ self,
329
+ output_size,
330
+ scales,
331
+ sampling_ratio,
332
+ canonical_box_size=224,
333
+ canonical_level=4,
334
+ ):
335
+ super().__init__()
336
+
337
+ min_level = -(math.log2(scales[0]))
338
+ max_level = -(math.log2(scales[-1]))
339
+
340
+ if isinstance(output_size, int):
341
+ output_size = (output_size, output_size)
342
+ assert len(output_size) == 2 and isinstance(output_size[0], int) and isinstance(output_size[1], int)
343
+ assert math.isclose(min_level, int(min_level)) and math.isclose(max_level, int(max_level))
344
+ assert (len(scales) == max_level - min_level + 1)
345
+ assert 0 <= min_level <= max_level
346
+ assert canonical_box_size > 0
347
+
348
+ self.output_size = output_size
349
+ self.min_level = int(min_level)
350
+ self.max_level = int(max_level)
351
+ self.canonical_level = canonical_level
352
+ self.canonical_box_size = canonical_box_size
353
+ self.level_poolers = nn.ModuleList(
354
+ RoIAlign(
355
+ output_size, spatial_scale=scale, sampling_ratio=sampling_ratio, aligned=True
356
+ )
357
+ for scale in scales
358
+ )
359
+
360
+ def forward(self, x, bboxes):
361
+ num_level_assignments = len(self.level_poolers)
362
+ assert len(x) == num_level_assignments and len(bboxes) == x[0].size(0)
363
+
364
+ pooler_fmt_boxes = convert_boxes_to_pooler_format(bboxes)
365
+
366
+ if num_level_assignments == 1:
367
+ return self.level_poolers[0](x[0], pooler_fmt_boxes)
368
+
369
+ level_assignments = assign_boxes_to_levels(
370
+ bboxes, self.min_level, self.max_level, self.canonical_box_size, self.canonical_level
371
+ )
372
+
373
+ batches = pooler_fmt_boxes.shape[0]
374
+ channels = x[0].shape[1]
375
+ output_size = self.output_size[0]
376
+ sizes = (batches, channels, output_size, output_size)
377
+
378
+ output = torch.zeros(sizes, dtype=x[0].dtype, device=x[0].device)
379
+
380
+ for level, (x_level, pooler) in enumerate(zip(x, self.level_poolers)):
381
+ inds = (level_assignments == level).nonzero(as_tuple=True)[0]
382
+ pooler_fmt_boxes_level = pooler_fmt_boxes[inds]
383
+ # Use index_put_ instead of advance indexing, to avoid pytorch/issues/49852
384
+ output.index_put_((inds,), pooler(x_level, pooler_fmt_boxes_level))
385
+
386
+ return output
image_processing_diffusiondet.py ADDED
@@ -0,0 +1,1624 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for DiffusionDet."""
16
+
17
+ import io
18
+ import pathlib
19
+ from collections import defaultdict
20
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
21
+
22
+ import numpy as np
23
+ from transformers.feature_extraction_utils import BatchFeature
24
+ from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
25
+ from transformers.image_transforms import (
26
+ PaddingMode,
27
+ center_to_corners_format,
28
+ corners_to_center_format,
29
+ id_to_rgb,
30
+ pad,
31
+ rescale,
32
+ resize,
33
+ rgb_to_id,
34
+ to_channel_dimension_format,
35
+ )
36
+
37
+ from transformers.image_utils import (
38
+ IMAGENET_DEFAULT_MEAN,
39
+ IMAGENET_DEFAULT_STD,
40
+ AnnotationFormat,
41
+ AnnotationType,
42
+ ChannelDimension,
43
+ ImageInput,
44
+ PILImageResampling,
45
+ get_image_size,
46
+ infer_channel_dimension_format,
47
+ is_scaled_image,
48
+ make_list_of_images,
49
+ to_numpy_array,
50
+ valid_images,
51
+ validate_annotations,
52
+ validate_kwargs,
53
+ validate_preprocess_arguments
54
+ )
55
+
56
+ from transformers.utils import (
57
+ TensorType,
58
+ is_flax_available,
59
+ is_jax_tensor,
60
+ is_tf_available,
61
+ is_tf_tensor,
62
+ is_torch_tensor,
63
+ is_vision_available
64
+ )
65
+ from transformers.utils import (
66
+ is_torch_available,
67
+ is_scipy_available,
68
+ logging
69
+ )
70
+
71
+
72
+ if is_torch_available():
73
+ import torch
74
+ from torch import nn
75
+
76
+ if is_vision_available():
77
+ import PIL
78
+
79
+ if is_scipy_available():
80
+ import scipy.special
81
+ import scipy.stats
82
+
83
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
84
+
85
+ SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)
86
+
87
+
88
+ # Copied from transformers.models.detr.image_processing_detr.get_size_with_aspect_ratio
89
+ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, int]:
90
+ """
91
+ Computes the output image size given the input image size and the desired output size.
92
+
93
+ Args:
94
+ image_size (`Tuple[int, int]`):
95
+ The input image size.
96
+ size (`int`):
97
+ The desired output size.
98
+ max_size (`int`, *optional*):
99
+ The maximum allowed output size.
100
+ """
101
+ height, width = image_size
102
+ raw_size = None
103
+ if max_size is not None:
104
+ min_original_size = float(min((height, width)))
105
+ max_original_size = float(max((height, width)))
106
+ if max_original_size / min_original_size * size > max_size:
107
+ raw_size = max_size * min_original_size / max_original_size
108
+ size = int(round(raw_size))
109
+
110
+ if (height <= width and height == size) or (width <= height and width == size):
111
+ oh, ow = height, width
112
+ elif width < height:
113
+ ow = size
114
+ if max_size is not None and raw_size is not None:
115
+ oh = int(raw_size * height / width)
116
+ else:
117
+ oh = int(size * height / width)
118
+ else:
119
+ oh = size
120
+ if max_size is not None and raw_size is not None:
121
+ ow = int(raw_size * width / height)
122
+ else:
123
+ ow = int(size * width / height)
124
+
125
+ return (oh, ow)
126
+
127
+
128
+ # Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size
129
+ def get_resize_output_image_size(
130
+ input_image: np.ndarray,
131
+ size: Union[int, Tuple[int, int], List[int]],
132
+ max_size: Optional[int] = None,
133
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
134
+ ) -> Tuple[int, int]:
135
+ """
136
+ Computes the output image size given the input image size and the desired output size. If the desired output size
137
+ is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output
138
+ image size is computed by keeping the aspect ratio of the input image size.
139
+
140
+ Args:
141
+ input_image (`np.ndarray`):
142
+ The image to resize.
143
+ size (`int` or `Tuple[int, int]` or `List[int]`):
144
+ The desired output size.
145
+ max_size (`int`, *optional*):
146
+ The maximum allowed output size.
147
+ input_data_format (`ChannelDimension` or `str`, *optional*):
148
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
149
+ """
150
+ image_size = get_image_size(input_image, input_data_format)
151
+ if isinstance(size, (list, tuple)):
152
+ return size
153
+
154
+ return get_size_with_aspect_ratio(image_size, size, max_size)
155
+
156
+
157
+ # Copied from transformers.models.detr.image_processing_detr.get_image_size_for_max_height_width
158
+ def get_image_size_for_max_height_width(
159
+ input_image: np.ndarray,
160
+ max_height: int,
161
+ max_width: int,
162
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
163
+ ) -> Tuple[int, int]:
164
+ """
165
+ Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio.
166
+ Important, even if image_height < max_height and image_width < max_width, the image will be resized
167
+ to at least one of the edges be equal to max_height or max_width.
168
+
169
+ For example:
170
+ - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
171
+ - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
172
+
173
+ Args:
174
+ input_image (`np.ndarray`):
175
+ The image to resize.
176
+ max_height (`int`):
177
+ The maximum allowed height.
178
+ max_width (`int`):
179
+ The maximum allowed width.
180
+ input_data_format (`ChannelDimension` or `str`, *optional*):
181
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
182
+ """
183
+ image_size = get_image_size(input_image, input_data_format)
184
+ height, width = image_size
185
+ height_scale = max_height / height
186
+ width_scale = max_width / width
187
+ min_scale = min(height_scale, width_scale)
188
+ new_height = int(height * min_scale)
189
+ new_width = int(width * min_scale)
190
+ return new_height, new_width
191
+
192
+
193
+ # Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn
194
+ def get_numpy_to_framework_fn(arr) -> Callable:
195
+ """
196
+ Returns a function that converts a numpy array to the framework of the input array.
197
+
198
+ Args:
199
+ arr (`np.ndarray`): The array to convert.
200
+ """
201
+ if isinstance(arr, np.ndarray):
202
+ return np.array
203
+ if is_torch_available() and is_torch_tensor(arr):
204
+ import torch
205
+
206
+ return torch.tensor
207
+ raise ValueError(f"Cannot convert arrays of type {type(arr)}")
208
+
209
+
210
+ # Copied from transformers.models.detr.image_processing_detr.safe_squeeze
211
+ def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
212
+ """
213
+ Squeezes an array, but only if the axis specified has dim 1.
214
+ """
215
+ if axis is None:
216
+ return arr.squeeze()
217
+
218
+ try:
219
+ return arr.squeeze(axis=axis)
220
+ except ValueError:
221
+ return arr
222
+
223
+
224
+ # Copied from transformers.models.detr.image_processing_detr.normalize_annotation
225
+ def normalize_annotation(annotation: Dict, image_size: Tuple[int, int]) -> Dict:
226
+ image_height, image_width = image_size
227
+ norm_annotation = {}
228
+ for key, value in annotation.items():
229
+ if key == "boxes":
230
+ boxes = value
231
+ boxes = corners_to_center_format(boxes)
232
+ boxes /= np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32)
233
+ norm_annotation[key] = boxes
234
+ else:
235
+ norm_annotation[key] = value
236
+ return norm_annotation
237
+
238
+
239
+ # Copied from transformers.models.detr.image_processing_detr.max_across_indices
240
+ def max_across_indices(values: Iterable[Any]) -> List[Any]:
241
+ """
242
+ Return the maximum value across all indices of an iterable of values.
243
+ """
244
+ return [max(values_i) for values_i in zip(*values)]
245
+
246
+
247
+ # Copied from transformers.models.detr.image_processing_detr.get_max_height_width
248
+ def get_max_height_width(
249
+ images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
250
+ ) -> List[int]:
251
+ """
252
+ Get the maximum height and width across all images in a batch.
253
+ """
254
+ if input_data_format is None:
255
+ input_data_format = infer_channel_dimension_format(images[0])
256
+
257
+ if input_data_format == ChannelDimension.FIRST:
258
+ _, max_height, max_width = max_across_indices([img.shape for img in images])
259
+ elif input_data_format == ChannelDimension.LAST:
260
+ max_height, max_width, _ = max_across_indices([img.shape for img in images])
261
+ else:
262
+ raise ValueError(f"Invalid channel dimension format: {input_data_format}")
263
+ return (max_height, max_width)
264
+
265
+
266
+ # Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
267
+ def make_pixel_mask(
268
+ image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
269
+ ) -> np.ndarray:
270
+ """
271
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
272
+
273
+ Args:
274
+ image (`np.ndarray`):
275
+ Image to make the pixel mask for.
276
+ output_size (`Tuple[int, int]`):
277
+ Output size of the mask.
278
+ """
279
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
280
+ mask = np.zeros(output_size, dtype=np.int64)
281
+ mask[:input_height, :input_width] = 1
282
+ return mask
283
+
284
+
285
+ # Copied from transformers.models.detr.image_processing_detr.convert_coco_poly_to_mask
286
+ def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndarray:
287
+ """
288
+ Convert a COCO polygon annotation to a mask.
289
+
290
+ Args:
291
+ segmentations (`List[List[float]]`):
292
+ List of polygons, each polygon represented by a list of x-y coordinates.
293
+ height (`int`):
294
+ Height of the mask.
295
+ width (`int`):
296
+ Width of the mask.
297
+ """
298
+ try:
299
+ from pycocotools import mask as coco_mask
300
+ except ImportError:
301
+ raise ImportError("Pycocotools is not installed in your environment.")
302
+
303
+ masks = []
304
+ for polygons in segmentations:
305
+ rles = coco_mask.frPyObjects(polygons, height, width)
306
+ mask = coco_mask.decode(rles)
307
+ if len(mask.shape) < 3:
308
+ mask = mask[..., None]
309
+ mask = np.asarray(mask, dtype=np.uint8)
310
+ mask = np.any(mask, axis=2)
311
+ masks.append(mask)
312
+ if masks:
313
+ masks = np.stack(masks, axis=0)
314
+ else:
315
+ masks = np.zeros((0, height, width), dtype=np.uint8)
316
+
317
+ return masks
318
+
319
+
320
+ # Copied from transformers.models.detr.image_processing_detr.prepare_coco_detection_annotation with DETR->DeformableDetr
321
+ def prepare_coco_detection_annotation(
322
+ image,
323
+ target,
324
+ return_segmentation_masks: bool = False,
325
+ input_data_format: Optional[Union[ChannelDimension, str]] = None,
326
+ ):
327
+ """
328
+ Convert the target in COCO format into the format expected by DeformableDetr.
329
+ """
330
+ image_height, image_width = get_image_size(image, channel_dim=input_data_format)
331
+
332
+ image_id = target["image_id"]
333
+ image_id = np.asarray([image_id], dtype=np.int64)
334
+
335
+ # Get all COCO annotations for the given image.
336
+ annotations = target["annotations"]
337
+ annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0]
338
+
339
+ classes = [obj["category_id"] for obj in annotations]
340
+ classes = np.asarray(classes, dtype=np.int64)
341
+
342
+ # for conversion to coco api
343
+ area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32)
344
+ iscrowd = np.asarray([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in annotations], dtype=np.int64)
345
+
346
+ boxes = [obj["bbox"] for obj in annotations]
347
+ # guard against no boxes via resizing
348
+ boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
349
+ boxes[:, 2:] += boxes[:, :2]
350
+ boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
351
+ boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
352
+
353
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
354
+
355
+ new_target = {}
356
+ new_target["image_id"] = image_id
357
+ new_target["class_labels"] = classes[keep]
358
+ new_target["boxes"] = boxes[keep]
359
+ new_target["area"] = area[keep]
360
+ new_target["iscrowd"] = iscrowd[keep]
361
+ new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64)
362
+
363
+ if annotations and "keypoints" in annotations[0]:
364
+ keypoints = [obj["keypoints"] for obj in annotations]
365
+ # Converting the filtered keypoints list to a numpy array
366
+ keypoints = np.asarray(keypoints, dtype=np.float32)
367
+ # Apply the keep mask here to filter the relevant annotations
368
+ keypoints = keypoints[keep]
369
+ num_keypoints = keypoints.shape[0]
370
+ keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
371
+ new_target["keypoints"] = keypoints
372
+
373
+ if return_segmentation_masks:
374
+ segmentation_masks = [obj["segmentation"] for obj in annotations]
375
+ masks = convert_coco_poly_to_mask(segmentation_masks, image_height, image_width)
376
+ new_target["masks"] = masks[keep]
377
+
378
+ return new_target
379
+
380
+
381
+ # Copied from transformers.models.detr.image_processing_detr.masks_to_boxes
382
+ def masks_to_boxes(masks: np.ndarray) -> np.ndarray:
383
+ """
384
+ Compute the bounding boxes around the provided panoptic segmentation masks.
385
+
386
+ Args:
387
+ masks: masks in format `[number_masks, height, width]` where N is the number of masks
388
+
389
+ Returns:
390
+ boxes: bounding boxes in format `[number_masks, 4]` in xyxy format
391
+ """
392
+ if masks.size == 0:
393
+ return np.zeros((0, 4))
394
+
395
+ h, w = masks.shape[-2:]
396
+ y = np.arange(0, h, dtype=np.float32)
397
+ x = np.arange(0, w, dtype=np.float32)
398
+ # see https://github.com/pytorch/pytorch/issues/50276
399
+ y, x = np.meshgrid(y, x, indexing="ij")
400
+
401
+ x_mask = masks * np.expand_dims(x, axis=0)
402
+ x_max = x_mask.reshape(x_mask.shape[0], -1).max(-1)
403
+ x = np.ma.array(x_mask, mask=~(np.array(masks, dtype=bool)))
404
+ x_min = x.filled(fill_value=1e8)
405
+ x_min = x_min.reshape(x_min.shape[0], -1).min(-1)
406
+
407
+ y_mask = masks * np.expand_dims(y, axis=0)
408
+ y_max = y_mask.reshape(x_mask.shape[0], -1).max(-1)
409
+ y = np.ma.array(y_mask, mask=~(np.array(masks, dtype=bool)))
410
+ y_min = y.filled(fill_value=1e8)
411
+ y_min = y_min.reshape(y_min.shape[0], -1).min(-1)
412
+
413
+ return np.stack([x_min, y_min, x_max, y_max], 1)
414
+
415
+
416
+ # Copied from transformers.models.detr.image_processing_detr.prepare_coco_panoptic_annotation with DETR->DeformableDetr
417
+ def prepare_coco_panoptic_annotation(
418
+ image: np.ndarray,
419
+ target: Dict,
420
+ masks_path: Union[str, pathlib.Path],
421
+ return_masks: bool = True,
422
+ input_data_format: Union[ChannelDimension, str] = None,
423
+ ) -> Dict:
424
+ """
425
+ Prepare a coco panoptic annotation for DeformableDetr.
426
+ """
427
+ image_height, image_width = get_image_size(image, channel_dim=input_data_format)
428
+ annotation_path = pathlib.Path(masks_path) / target["file_name"]
429
+
430
+ new_target = {}
431
+ new_target["image_id"] = np.asarray([target["image_id"] if "image_id" in target else target["id"]], dtype=np.int64)
432
+ new_target["size"] = np.asarray([image_height, image_width], dtype=np.int64)
433
+ new_target["orig_size"] = np.asarray([image_height, image_width], dtype=np.int64)
434
+
435
+ if "segments_info" in target:
436
+ masks = np.asarray(PIL.Image.open(annotation_path), dtype=np.uint32)
437
+ masks = rgb_to_id(masks)
438
+
439
+ ids = np.array([segment_info["id"] for segment_info in target["segments_info"]])
440
+ masks = masks == ids[:, None, None]
441
+ masks = masks.astype(np.uint8)
442
+ if return_masks:
443
+ new_target["masks"] = masks
444
+ new_target["boxes"] = masks_to_boxes(masks)
445
+ new_target["class_labels"] = np.array(
446
+ [segment_info["category_id"] for segment_info in target["segments_info"]], dtype=np.int64
447
+ )
448
+ new_target["iscrowd"] = np.asarray(
449
+ [segment_info["iscrowd"] for segment_info in target["segments_info"]], dtype=np.int64
450
+ )
451
+ new_target["area"] = np.asarray(
452
+ [segment_info["area"] for segment_info in target["segments_info"]], dtype=np.float32
453
+ )
454
+
455
+ return new_target
456
+
457
+
458
+ # Copied from transformers.models.detr.image_processing_detr.get_segmentation_image
459
+ def get_segmentation_image(
460
+ masks: np.ndarray, input_size: Tuple, target_size: Tuple, stuff_equiv_classes, deduplicate=False
461
+ ):
462
+ h, w = input_size
463
+ final_h, final_w = target_size
464
+
465
+ m_id = scipy.special.softmax(masks.transpose(0, 1), -1)
466
+
467
+ if m_id.shape[-1] == 0:
468
+ # We didn't detect any mask :(
469
+ m_id = np.zeros((h, w), dtype=np.int64)
470
+ else:
471
+ m_id = m_id.argmax(-1).reshape(h, w)
472
+
473
+ if deduplicate:
474
+ # Merge the masks corresponding to the same stuff class
475
+ for equiv in stuff_equiv_classes.values():
476
+ for eq_id in equiv:
477
+ m_id[m_id == eq_id] = equiv[0]
478
+
479
+ seg_img = id_to_rgb(m_id)
480
+ seg_img = resize(seg_img, (final_w, final_h), resample=PILImageResampling.NEAREST)
481
+ return seg_img
482
+
483
+
484
+ # Copied from transformers.models.detr.image_processing_detr.get_mask_area
485
+ def get_mask_area(seg_img: np.ndarray, target_size: Tuple[int, int], n_classes: int) -> np.ndarray:
486
+ final_h, final_w = target_size
487
+ np_seg_img = seg_img.astype(np.uint8)
488
+ np_seg_img = np_seg_img.reshape(final_h, final_w, 3)
489
+ m_id = rgb_to_id(np_seg_img)
490
+ area = [(m_id == i).sum() for i in range(n_classes)]
491
+ return area
492
+
493
+
494
+ # Copied from transformers.models.detr.image_processing_detr.score_labels_from_class_probabilities
495
+ def score_labels_from_class_probabilities(logits: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
496
+ probs = scipy.special.softmax(logits, axis=-1)
497
+ labels = probs.argmax(-1, keepdims=True)
498
+ scores = np.take_along_axis(probs, labels, axis=-1)
499
+ scores, labels = scores.squeeze(-1), labels.squeeze(-1)
500
+ return scores, labels
501
+
502
+
503
+ # Copied from transformers.models.detr.image_processing_detr.post_process_panoptic_sample
504
+ def post_process_panoptic_sample(
505
+ out_logits: np.ndarray,
506
+ masks: np.ndarray,
507
+ boxes: np.ndarray,
508
+ processed_size: Tuple[int, int],
509
+ target_size: Tuple[int, int],
510
+ is_thing_map: Dict,
511
+ threshold=0.85,
512
+ ) -> Dict:
513
+ """
514
+ Converts the output of [`DetrForSegmentation`] into panoptic segmentation predictions for a single sample.
515
+
516
+ Args:
517
+ out_logits (`torch.Tensor`):
518
+ The logits for this sample.
519
+ masks (`torch.Tensor`):
520
+ The predicted segmentation masks for this sample.
521
+ boxes (`torch.Tensor`):
522
+ The prediced bounding boxes for this sample. The boxes are in the normalized format `(center_x, center_y,
523
+ width, height)` and values between `[0, 1]`, relative to the size the image (disregarding padding).
524
+ processed_size (`Tuple[int, int]`):
525
+ The processed size of the image `(height, width)`, as returned by the preprocessing step i.e. the size
526
+ after data augmentation but before batching.
527
+ target_size (`Tuple[int, int]`):
528
+ The target size of the image, `(height, width)` corresponding to the requested final size of the
529
+ prediction.
530
+ is_thing_map (`Dict`):
531
+ A dictionary mapping class indices to a boolean value indicating whether the class is a thing or not.
532
+ threshold (`float`, *optional*, defaults to 0.85):
533
+ The threshold used to binarize the segmentation masks.
534
+ """
535
+ # we filter empty queries and detection below threshold
536
+ scores, labels = score_labels_from_class_probabilities(out_logits)
537
+ keep = (labels != out_logits.shape[-1] - 1) & (scores > threshold)
538
+
539
+ cur_scores = scores[keep]
540
+ cur_classes = labels[keep]
541
+ cur_boxes = center_to_corners_format(boxes[keep])
542
+
543
+ if len(cur_boxes) != len(cur_classes):
544
+ raise ValueError("Not as many boxes as there are classes")
545
+
546
+ cur_masks = masks[keep]
547
+ cur_masks = resize(cur_masks[:, None], processed_size, resample=PILImageResampling.BILINEAR)
548
+ cur_masks = safe_squeeze(cur_masks, 1)
549
+ b, h, w = cur_masks.shape
550
+
551
+ # It may be that we have several predicted masks for the same stuff class.
552
+ # In the following, we track the list of masks ids for each stuff class (they are merged later on)
553
+ cur_masks = cur_masks.reshape(b, -1)
554
+ stuff_equiv_classes = defaultdict(list)
555
+ for k, label in enumerate(cur_classes):
556
+ if not is_thing_map[label]:
557
+ stuff_equiv_classes[label].append(k)
558
+
559
+ seg_img = get_segmentation_image(cur_masks, processed_size, target_size, stuff_equiv_classes, deduplicate=True)
560
+ area = get_mask_area(cur_masks, processed_size, n_classes=len(cur_scores))
561
+
562
+ # We filter out any mask that is too small
563
+ if cur_classes.size() > 0:
564
+ # We know filter empty masks as long as we find some
565
+ filtered_small = np.array([a <= 4 for a in area], dtype=bool)
566
+ while filtered_small.any():
567
+ cur_masks = cur_masks[~filtered_small]
568
+ cur_scores = cur_scores[~filtered_small]
569
+ cur_classes = cur_classes[~filtered_small]
570
+ seg_img = get_segmentation_image(cur_masks, (h, w), target_size, stuff_equiv_classes, deduplicate=True)
571
+ area = get_mask_area(seg_img, target_size, n_classes=len(cur_scores))
572
+ filtered_small = np.array([a <= 4 for a in area], dtype=bool)
573
+ else:
574
+ cur_classes = np.ones((1, 1), dtype=np.int64)
575
+
576
+ segments_info = [
577
+ {"id": i, "isthing": is_thing_map[cat], "category_id": int(cat), "area": a}
578
+ for i, (cat, a) in enumerate(zip(cur_classes, area))
579
+ ]
580
+ del cur_classes
581
+
582
+ with io.BytesIO() as out:
583
+ PIL.Image.fromarray(seg_img).save(out, format="PNG")
584
+ predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
585
+
586
+ return predictions
587
+
588
+
589
+ # Copied from transformers.models.detr.image_processing_detr.resize_annotation
590
+ def resize_annotation(
591
+ annotation: Dict[str, Any],
592
+ orig_size: Tuple[int, int],
593
+ target_size: Tuple[int, int],
594
+ threshold: float = 0.5,
595
+ resample: PILImageResampling = PILImageResampling.NEAREST,
596
+ ):
597
+ """
598
+ Resizes an annotation to a target size.
599
+
600
+ Args:
601
+ annotation (`Dict[str, Any]`):
602
+ The annotation dictionary.
603
+ orig_size (`Tuple[int, int]`):
604
+ The original size of the input image.
605
+ target_size (`Tuple[int, int]`):
606
+ The target size of the image, as returned by the preprocessing `resize` step.
607
+ threshold (`float`, *optional*, defaults to 0.5):
608
+ The threshold used to binarize the segmentation masks.
609
+ resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`):
610
+ The resampling filter to use when resizing the masks.
611
+ """
612
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(target_size, orig_size))
613
+ ratio_height, ratio_width = ratios
614
+
615
+ new_annotation = {}
616
+ new_annotation["size"] = target_size
617
+
618
+ for key, value in annotation.items():
619
+ if key == "boxes":
620
+ boxes = value
621
+ scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32)
622
+ new_annotation["boxes"] = scaled_boxes
623
+ elif key == "area":
624
+ area = value
625
+ scaled_area = area * (ratio_width * ratio_height)
626
+ new_annotation["area"] = scaled_area
627
+ elif key == "masks":
628
+ masks = value[:, None]
629
+ masks = np.array([resize(mask, target_size, resample=resample) for mask in masks])
630
+ masks = masks.astype(np.float32)
631
+ masks = masks[:, 0] > threshold
632
+ new_annotation["masks"] = masks
633
+ elif key == "size":
634
+ new_annotation["size"] = target_size
635
+ else:
636
+ new_annotation[key] = value
637
+
638
+ return new_annotation
639
+
640
+
641
+ # Copied from transformers.models.detr.image_processing_detr.binary_mask_to_rle
642
+ def binary_mask_to_rle(mask):
643
+ """
644
+ Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format.
645
+
646
+ Args:
647
+ mask (`torch.Tensor` or `numpy.array`):
648
+ A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target
649
+ segment_id or class_id.
650
+ Returns:
651
+ `List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE
652
+ format.
653
+ """
654
+ if is_torch_tensor(mask):
655
+ mask = mask.numpy()
656
+
657
+ pixels = mask.flatten()
658
+ pixels = np.concatenate([[0], pixels, [0]])
659
+ runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
660
+ runs[1::2] -= runs[::2]
661
+ return list(runs)
662
+
663
+
664
+ # Copied from transformers.models.detr.image_processing_detr.convert_segmentation_to_rle
665
+ def convert_segmentation_to_rle(segmentation):
666
+ """
667
+ Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format.
668
+
669
+ Args:
670
+ segmentation (`torch.Tensor` or `numpy.array`):
671
+ A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
672
+ Returns:
673
+ `List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id.
674
+ """
675
+ segment_ids = torch.unique(segmentation)
676
+
677
+ run_length_encodings = []
678
+ for idx in segment_ids:
679
+ mask = torch.where(segmentation == idx, 1, 0)
680
+ rle = binary_mask_to_rle(mask)
681
+ run_length_encodings.append(rle)
682
+
683
+ return run_length_encodings
684
+
685
+
686
+ # Copied from transformers.models.detr.image_processing_detr.remove_low_and_no_objects
687
+ def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
688
+ """
689
+ Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and
690
+ `labels`.
691
+
692
+ Args:
693
+ masks (`torch.Tensor`):
694
+ A tensor of shape `(num_queries, height, width)`.
695
+ scores (`torch.Tensor`):
696
+ A tensor of shape `(num_queries)`.
697
+ labels (`torch.Tensor`):
698
+ A tensor of shape `(num_queries)`.
699
+ object_mask_threshold (`float`):
700
+ A number between 0 and 1 used to binarize the masks.
701
+ Raises:
702
+ `ValueError`: Raised when the first dimension doesn't match in all input tensors.
703
+ Returns:
704
+ `Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region
705
+ < `object_mask_threshold`.
706
+ """
707
+ if not (masks.shape[0] == scores.shape[0] == labels.shape[0]):
708
+ raise ValueError("mask, scores and labels must have the same shape!")
709
+
710
+ to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)
711
+
712
+ return masks[to_keep], scores[to_keep], labels[to_keep]
713
+
714
+
715
+ # Copied from transformers.models.detr.image_processing_detr.check_segment_validity
716
+ def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8):
717
+ # Get the mask associated with the k class
718
+ mask_k = mask_labels == k
719
+ mask_k_area = mask_k.sum()
720
+
721
+ # Compute the area of all the stuff in query k
722
+ original_area = (mask_probs[k] >= mask_threshold).sum()
723
+ mask_exists = mask_k_area > 0 and original_area > 0
724
+
725
+ # Eliminate disconnected tiny segments
726
+ if mask_exists:
727
+ area_ratio = mask_k_area / original_area
728
+ if not area_ratio.item() > overlap_mask_area_threshold:
729
+ mask_exists = False
730
+
731
+ return mask_exists, mask_k
732
+
733
+
734
+ # Copied from transformers.models.detr.image_processing_detr.compute_segments
735
+ def compute_segments(
736
+ mask_probs,
737
+ pred_scores,
738
+ pred_labels,
739
+ mask_threshold: float = 0.5,
740
+ overlap_mask_area_threshold: float = 0.8,
741
+ label_ids_to_fuse: Optional[Set[int]] = None,
742
+ target_size: Tuple[int, int] = None,
743
+ ):
744
+ height = mask_probs.shape[1] if target_size is None else target_size[0]
745
+ width = mask_probs.shape[2] if target_size is None else target_size[1]
746
+
747
+ segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
748
+ segments: List[Dict] = []
749
+
750
+ if target_size is not None:
751
+ mask_probs = nn.functional.interpolate(
752
+ mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
753
+ )[0]
754
+
755
+ current_segment_id = 0
756
+
757
+ # Weigh each mask by its prediction score
758
+ mask_probs *= pred_scores.view(-1, 1, 1)
759
+ mask_labels = mask_probs.argmax(0) # [height, width]
760
+
761
+ # Keep track of instances of each class
762
+ stuff_memory_list: Dict[str, int] = {}
763
+ for k in range(pred_labels.shape[0]):
764
+ pred_class = pred_labels[k].item()
765
+ should_fuse = pred_class in label_ids_to_fuse
766
+
767
+ # Check if mask exists and large enough to be a segment
768
+ mask_exists, mask_k = check_segment_validity(
769
+ mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
770
+ )
771
+
772
+ if mask_exists:
773
+ if pred_class in stuff_memory_list:
774
+ current_segment_id = stuff_memory_list[pred_class]
775
+ else:
776
+ current_segment_id += 1
777
+
778
+ # Add current object segment to final segmentation map
779
+ segmentation[mask_k] = current_segment_id
780
+ segment_score = round(pred_scores[k].item(), 6)
781
+ segments.append(
782
+ {
783
+ "id": current_segment_id,
784
+ "label_id": pred_class,
785
+ "was_fused": should_fuse,
786
+ "score": segment_score,
787
+ }
788
+ )
789
+ if should_fuse:
790
+ stuff_memory_list[pred_class] = current_segment_id
791
+
792
+ return segmentation, segments
793
+
794
+
795
+ class DiffusionDetImageProcessor(BaseImageProcessor):
796
+ r"""
797
+ Constructs a DiffusionDet image processor.
798
+
799
+ Args:
800
+ format (`str`, *optional*, defaults to `"coco_detection"`):
801
+ Data format of the annotations. One of "coco_detection" or "coco_panoptic".
802
+ do_resize (`bool`, *optional*, defaults to `True`):
803
+ Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be
804
+ overridden by the `do_resize` parameter in the `preprocess` method.
805
+ size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 800, "longest_edge": 1333}`):
806
+ Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
807
+ in the `preprocess` method. Available options are:
808
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
809
+ Do NOT keep the aspect ratio.
810
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
811
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
812
+ less or equal to `longest_edge`.
813
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
814
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
815
+ `max_width`.
816
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
817
+ Resampling filter to use if resizing the image.
818
+ do_rescale (`bool`, *optional*, defaults to `True`):
819
+ Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
820
+ `do_rescale` parameter in the `preprocess` method.
821
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
822
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
823
+ `preprocess` method.
824
+ do_normalize:
825
+ Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
826
+ `preprocess` method.
827
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
828
+ Mean values to use when normalizing the image. Can be a single value or a list of values, one for each
829
+ channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
830
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
831
+ Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one
832
+ for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
833
+ do_convert_annotations (`bool`, *optional*, defaults to `True`):
834
+ Controls whether to convert the annotations to the format expected by the DETR model. Converts the
835
+ bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
836
+ Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
837
+ do_pad (`bool`, *optional*, defaults to `True`):
838
+ Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
839
+ method. If `True`, padding will be applied to the bottom and right of the image with zeros.
840
+ If `pad_size` is provided, the image will be padded to the specified dimensions.
841
+ Otherwise, the image will be padded to the maximum height and width of the batch.
842
+ pad_size (`Dict[str, int]`, *optional*):
843
+ The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
844
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
845
+ height and width in the batch.
846
+ """
847
+
848
+ model_input_names = ["pixel_values", "pixel_mask"]
849
+
850
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.__init__
851
+ def __init__(
852
+ self,
853
+ format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION,
854
+ do_resize: bool = True,
855
+ size: Dict[str, int] = None,
856
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
857
+ do_rescale: bool = True,
858
+ rescale_factor: Union[int, float] = 1 / 255,
859
+ do_normalize: bool = True,
860
+ image_mean: Union[float, List[float]] = None,
861
+ image_std: Union[float, List[float]] = None,
862
+ do_convert_annotations: Optional[bool] = None,
863
+ do_pad: bool = True,
864
+ pad_size: Optional[Dict[str, int]] = None,
865
+ **kwargs,
866
+ ) -> None:
867
+ if "pad_and_return_pixel_mask" in kwargs:
868
+ do_pad = kwargs.pop("pad_and_return_pixel_mask")
869
+
870
+ if "max_size" in kwargs:
871
+ logger.warning_once(
872
+ "The `max_size` parameter is deprecated and will be removed in v4.26. "
873
+ "Please specify in `size['longest_edge'] instead`.",
874
+ )
875
+ max_size = kwargs.pop("max_size")
876
+ else:
877
+ max_size = None if size is None else 1333
878
+
879
+ size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333}
880
+ size = get_size_dict(size, max_size=max_size, default_to_square=False)
881
+
882
+ # Backwards compatibility
883
+ if do_convert_annotations is None:
884
+ do_convert_annotations = do_normalize
885
+
886
+ super().__init__(**kwargs)
887
+ self.format = format
888
+ self.do_resize = do_resize
889
+ self.size = size
890
+ self.resample = resample
891
+ self.do_rescale = do_rescale
892
+ self.rescale_factor = rescale_factor
893
+ self.do_normalize = do_normalize
894
+ self.do_convert_annotations = do_convert_annotations
895
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
896
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
897
+ self.do_pad = do_pad
898
+ self.pad_size = pad_size
899
+ self._valid_processor_keys = [
900
+ "images",
901
+ "annotations",
902
+ "return_segmentation_masks",
903
+ "masks_path",
904
+ "do_resize",
905
+ "size",
906
+ "resample",
907
+ "do_rescale",
908
+ "rescale_factor",
909
+ "do_normalize",
910
+ "do_convert_annotations",
911
+ "image_mean",
912
+ "image_std",
913
+ "do_pad",
914
+ "pad_size",
915
+ "format",
916
+ "return_tensors",
917
+ "data_format",
918
+ "input_data_format",
919
+ ]
920
+
921
+ @classmethod
922
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.from_dict with Detr->DeformableDetr
923
+ def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
924
+ """
925
+ Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
926
+ created using from_dict and kwargs e.g. `DeformableDetrImageProcessor.from_pretrained(checkpoint, size=600,
927
+ max_size=800)`
928
+ """
929
+ image_processor_dict = image_processor_dict.copy()
930
+ if "max_size" in kwargs:
931
+ image_processor_dict["max_size"] = kwargs.pop("max_size")
932
+ if "pad_and_return_pixel_mask" in kwargs:
933
+ image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
934
+ return super().from_dict(image_processor_dict, **kwargs)
935
+
936
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation with DETR->DeformableDetr
937
+ def prepare_annotation(
938
+ self,
939
+ image: np.ndarray,
940
+ target: Dict,
941
+ format: Optional[AnnotationFormat] = None,
942
+ return_segmentation_masks: bool = None,
943
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
944
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
945
+ ) -> Dict:
946
+ """
947
+ Prepare an annotation for feeding into the DiffusionDet model.
948
+ """
949
+ format = format if format is not None else self.format
950
+
951
+ if format == AnnotationFormat.COCO_DETECTION:
952
+ return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
953
+ target = prepare_coco_detection_annotation(
954
+ image, target, return_segmentation_masks, input_data_format=input_data_format
955
+ )
956
+ elif format == AnnotationFormat.COCO_PANOPTIC:
957
+ return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
958
+ target = prepare_coco_panoptic_annotation(
959
+ image,
960
+ target,
961
+ masks_path=masks_path,
962
+ return_masks=return_segmentation_masks,
963
+ input_data_format=input_data_format,
964
+ )
965
+ else:
966
+ raise ValueError(f"Format {format} is not supported.")
967
+ return target
968
+
969
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize
970
+ def resize(
971
+ self,
972
+ image: np.ndarray,
973
+ size: Dict[str, int],
974
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
975
+ data_format: Optional[ChannelDimension] = None,
976
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
977
+ **kwargs,
978
+ ) -> np.ndarray:
979
+ """
980
+ Resize the image to the given size. `size` should be a dictionary in one of the formats described
981
+ in the `Args` section below; the aspect ratio is preserved unless an exact `(height, width)` is requested.
982
+
983
+ Args:
984
+ image (`np.ndarray`):
985
+ Image to resize.
986
+ size (`Dict[str, int]`):
987
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
988
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
989
+ Do NOT keep the aspect ratio.
990
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
991
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
992
+ less or equal to `longest_edge`.
993
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
994
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
995
+ `max_width`.
996
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
997
+ Resampling filter to use if resizing the image.
998
+ data_format (`str` or `ChannelDimension`, *optional*):
999
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
1000
+ image is used.
1001
+ input_data_format (`ChannelDimension` or `str`, *optional*):
1002
+ The channel dimension format of the input image. If not provided, it will be inferred.
1003
+ """
1004
+ if "max_size" in kwargs:
1005
+ logger.warning_once(
1006
+ "The `max_size` parameter is deprecated and will be removed in v4.26. "
1007
+ "Please specify in `size['longest_edge']` instead.",
1008
+ )
1009
+ max_size = kwargs.pop("max_size")
1010
+ else:
1011
+ max_size = None
1012
+ size = get_size_dict(size, max_size=max_size, default_to_square=False)
1013
+ if "shortest_edge" in size and "longest_edge" in size:
1014
+ new_size = get_resize_output_image_size(
1015
+ image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
1016
+ )
1017
+ elif "max_height" in size and "max_width" in size:
1018
+ new_size = get_image_size_for_max_height_width(
1019
+ image, size["max_height"], size["max_width"], input_data_format=input_data_format
1020
+ )
1021
+ elif "height" in size and "width" in size:
1022
+ new_size = (size["height"], size["width"])
1023
+ else:
1024
+ raise ValueError(
1025
+ "Size must contain 'height' and 'width', 'shortest_edge' and 'longest_edge', or 'max_height' and 'max_width' keys. Got"
1026
+ f" {size.keys()}."
1027
+ )
1028
+ image = resize(
1029
+ image,
1030
+ size=new_size,
1031
+ resample=resample,
1032
+ data_format=data_format,
1033
+ input_data_format=input_data_format,
1034
+ **kwargs,
1035
+ )
1036
+ return image
1037
+
1038
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation
1039
+ def resize_annotation(
1040
+ self,
1041
+ annotation,
1042
+ orig_size,
1043
+ size,
1044
+ resample: PILImageResampling = PILImageResampling.NEAREST,
1045
+ ) -> Dict:
1046
+ """
1047
+ Resize the annotation to match the resized image. If size is an int, smaller edge of the mask will be matched
1048
+ to this number.
1049
+ """
1050
+ return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample)
1051
+
1052
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
1053
+ def rescale(
1054
+ self,
1055
+ image: np.ndarray,
1056
+ rescale_factor: float,
1057
+ data_format: Optional[Union[str, ChannelDimension]] = None,
1058
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
1059
+ ) -> np.ndarray:
1060
+ """
1061
+ Rescale the image by the given factor. image = image * rescale_factor.
1062
+
1063
+ Args:
1064
+ image (`np.ndarray`):
1065
+ Image to rescale.
1066
+ rescale_factor (`float`):
1067
+ The value to use for rescaling.
1068
+ data_format (`str` or `ChannelDimension`, *optional*):
1069
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
1070
+ image is used. Can be one of:
1071
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
1072
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
1073
+ input_data_format (`str` or `ChannelDimension`, *optional*):
1074
+ The channel dimension format for the input image. If unset, is inferred from the input image. Can be
1075
+ one of:
1076
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
1077
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
1078
+ """
1079
+ return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
1080
+
1081
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation
1082
+ def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict:
1083
+ """
1084
+ Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to
1085
+ `[center_x, center_y, width, height]` format and from absolute to relative pixel values.
1086
+ """
1087
+ return normalize_annotation(annotation, image_size=image_size)
1088
+
1089
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._update_annotation_for_padded_image
1090
+ def _update_annotation_for_padded_image(
1091
+ self,
1092
+ annotation: Dict,
1093
+ input_image_size: Tuple[int, int],
1094
+ output_image_size: Tuple[int, int],
1095
+ padding,
1096
+ update_bboxes,
1097
+ ) -> Dict:
1098
+ """
1099
+ Update the annotation for a padded image.
1100
+ """
1101
+ new_annotation = {}
1102
+ new_annotation["size"] = output_image_size
1103
+
1104
+ for key, value in annotation.items():
1105
+ if key == "masks":
1106
+ masks = value
1107
+ masks = pad(
1108
+ masks,
1109
+ padding,
1110
+ mode=PaddingMode.CONSTANT,
1111
+ constant_values=0,
1112
+ input_data_format=ChannelDimension.FIRST,
1113
+ )
1114
+ masks = safe_squeeze(masks, 1)
1115
+ new_annotation["masks"] = masks
1116
+ elif key == "boxes" and update_bboxes:
1117
+ boxes = value
1118
+ boxes *= np.asarray(
1119
+ [
1120
+ input_image_size[1] / output_image_size[1],
1121
+ input_image_size[0] / output_image_size[0],
1122
+ input_image_size[1] / output_image_size[1],
1123
+ input_image_size[0] / output_image_size[0],
1124
+ ]
1125
+ )
1126
+ new_annotation["boxes"] = boxes
1127
+ elif key == "size":
1128
+ new_annotation["size"] = output_image_size
1129
+ else:
1130
+ new_annotation[key] = value
1131
+ return new_annotation
1132
+
1133
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image
1134
+ def _pad_image(
1135
+ self,
1136
+ image: np.ndarray,
1137
+ output_size: Tuple[int, int],
1138
+ annotation: Optional[Dict[str, Any]] = None,
1139
+ constant_values: Union[float, Iterable[float]] = 0,
1140
+ data_format: Optional[ChannelDimension] = None,
1141
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
1142
+ update_bboxes: bool = True,
1143
+ ) -> np.ndarray:
1144
+ """
1145
+ Pad an image with zeros to the given size.
1146
+ """
1147
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
1148
+ output_height, output_width = output_size
1149
+
1150
+ pad_bottom = output_height - input_height
1151
+ pad_right = output_width - input_width
1152
+ padding = ((0, pad_bottom), (0, pad_right))
1153
+ padded_image = pad(
1154
+ image,
1155
+ padding,
1156
+ mode=PaddingMode.CONSTANT,
1157
+ constant_values=constant_values,
1158
+ data_format=data_format,
1159
+ input_data_format=input_data_format,
1160
+ )
1161
+ if annotation is not None:
1162
+ annotation = self._update_annotation_for_padded_image(
1163
+ annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes
1164
+ )
1165
+ return padded_image, annotation
1166
+
1167
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad
1168
+ def pad(
1169
+ self,
1170
+ images: List[np.ndarray],
1171
+ annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
1172
+ constant_values: Union[float, Iterable[float]] = 0,
1173
+ return_pixel_mask: bool = True,
1174
+ return_tensors: Optional[Union[str, TensorType]] = None,
1175
+ data_format: Optional[ChannelDimension] = None,
1176
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
1177
+ update_bboxes: bool = True,
1178
+ pad_size: Optional[Dict[str, int]] = None,
1179
+ ) -> BatchFeature:
1180
+ """
1181
+ Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
1182
+ in the batch and optionally returns their corresponding pixel mask.
1183
+
1184
+ Args:
1185
+ images (List[`np.ndarray`]):
1186
+ Images to pad.
1187
+ annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
1188
+ Annotations to transform according to the padding that is applied to the images.
1189
+ constant_values (`float` or `Iterable[float]`, *optional*):
1190
+ The value to use for the padding if `mode` is `"constant"`.
1191
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
1192
+ Whether to return a pixel mask.
1193
+ return_tensors (`str` or `TensorType`, *optional*):
1194
+ The type of tensors to return. Can be one of:
1195
+ - Unset: Return a list of `np.ndarray`.
1196
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
1197
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
1198
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
1199
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
1200
+ data_format (`str` or `ChannelDimension`, *optional*):
1201
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
1202
+ input_data_format (`ChannelDimension` or `str`, *optional*):
1203
+ The channel dimension format of the input image. If not provided, it will be inferred.
1204
+ update_bboxes (`bool`, *optional*, defaults to `True`):
1205
+ Whether to update the bounding boxes in the annotations to match the padded images. If the
1206
+ bounding boxes have not been converted to relative coordinates and `(center_x, center_y, width, height)`
1207
+ format, the bounding boxes will not be updated.
1208
+ pad_size (`Dict[str, int]`, *optional*):
1209
+ The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
1210
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
1211
+ height and width in the batch.
1212
+ """
1213
+ pad_size = pad_size if pad_size is not None else self.pad_size
1214
+ if pad_size is not None:
1215
+ padded_size = (pad_size["height"], pad_size["width"])
1216
+ else:
1217
+ padded_size = get_max_height_width(images, input_data_format=input_data_format)
1218
+
1219
+ annotation_list = annotations if annotations is not None else [None] * len(images)
1220
+ padded_images = []
1221
+ padded_annotations = []
1222
+ for image, annotation in zip(images, annotation_list):
1223
+ padded_image, padded_annotation = self._pad_image(
1224
+ image,
1225
+ padded_size,
1226
+ annotation,
1227
+ constant_values=constant_values,
1228
+ data_format=data_format,
1229
+ input_data_format=input_data_format,
1230
+ update_bboxes=update_bboxes,
1231
+ )
1232
+ padded_images.append(padded_image)
1233
+ padded_annotations.append(padded_annotation)
1234
+
1235
+ data = {"pixel_values": padded_images}
1236
+
1237
+ if return_pixel_mask:
1238
+ masks = [
1239
+ make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format)
1240
+ for image in images
1241
+ ]
1242
+ data["pixel_mask"] = masks
1243
+
1244
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
1245
+
1246
+ if annotations is not None:
1247
+ encoded_inputs["labels"] = [
1248
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations
1249
+ ]
1250
+
1251
+ return encoded_inputs
1252
+
1253
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.preprocess
1254
+ def preprocess(
1255
+ self,
1256
+ images: ImageInput,
1257
+ annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
1258
+ return_segmentation_masks: bool = None,
1259
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
1260
+ do_resize: Optional[bool] = None,
1261
+ size: Optional[Dict[str, int]] = None,
1262
+ resample=None, # PILImageResampling
1263
+ do_rescale: Optional[bool] = None,
1264
+ rescale_factor: Optional[Union[int, float]] = None,
1265
+ do_normalize: Optional[bool] = None,
1266
+ do_convert_annotations: Optional[bool] = None,
1267
+ image_mean: Optional[Union[float, List[float]]] = None,
1268
+ image_std: Optional[Union[float, List[float]]] = None,
1269
+ do_pad: Optional[bool] = None,
1270
+ format: Optional[Union[str, AnnotationFormat]] = None,
1271
+ return_tensors: Optional[Union[TensorType, str]] = None,
1272
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
1273
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
1274
+ pad_size: Optional[Dict[str, int]] = None,
1275
+ **kwargs,
1276
+ ) -> BatchFeature:
1277
+ """
1278
+ Preprocess an image or a batch of images so that it can be used by the model.
1279
+
1280
+ Args:
1281
+ images (`ImageInput`):
1282
+ Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging
1283
+ from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
1284
+ annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
1285
+ List of annotations associated with the image or batch of images. If annotation is for object
1286
+ detection, the annotations should be a dictionary with the following keys:
1287
+ - "image_id" (`int`): The image id.
1288
+ - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
1289
+ dictionary. An image can have no annotations, in which case the list should be empty.
1290
+ If annotation is for segmentation, the annotations should be a dictionary with the following keys:
1291
+ - "image_id" (`int`): The image id.
1292
+ - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
1293
+ An image can have no segments, in which case the list should be empty.
1294
+ - "file_name" (`str`): The file name of the image.
1295
+ return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks):
1296
+ Whether to return segmentation masks.
1297
+ masks_path (`str` or `pathlib.Path`, *optional*):
1298
+ Path to the directory containing the segmentation masks.
1299
+ do_resize (`bool`, *optional*, defaults to self.do_resize):
1300
+ Whether to resize the image.
1301
+ size (`Dict[str, int]`, *optional*, defaults to self.size):
1302
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
1303
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
1304
+ Do NOT keep the aspect ratio.
1305
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
1306
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
1307
+ less or equal to `longest_edge`.
1308
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
1309
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
1310
+ `max_width`.
1311
+ resample (`PILImageResampling`, *optional*, defaults to self.resample):
1312
+ Resampling filter to use when resizing the image.
1313
+ do_rescale (`bool`, *optional*, defaults to self.do_rescale):
1314
+ Whether to rescale the image.
1315
+ rescale_factor (`float`, *optional*, defaults to self.rescale_factor):
1316
+ Rescale factor to use when rescaling the image.
1317
+ do_normalize (`bool`, *optional*, defaults to self.do_normalize):
1318
+ Whether to normalize the image.
1319
+ do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations):
1320
+ Whether to convert the annotations to the format expected by the model. Converts the bounding
1321
+ boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)`
1322
+ and in relative coordinates.
1323
+ image_mean (`float` or `List[float]`, *optional*, defaults to self.image_mean):
1324
+ Mean to use when normalizing the image.
1325
+ image_std (`float` or `List[float]`, *optional*, defaults to self.image_std):
1326
+ Standard deviation to use when normalizing the image.
1327
+ do_pad (`bool`, *optional*, defaults to self.do_pad):
1328
+ Whether to pad the image. If `True`, padding will be applied to the bottom and right of
1329
+ the image with zeros. If `pad_size` is provided, the image will be padded to the specified
1330
+ dimensions. Otherwise, the image will be padded to the maximum height and width of the batch.
1331
+ format (`str` or `AnnotationFormat`, *optional*, defaults to self.format):
1332
+ Format of the annotations.
1333
+ return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
1334
+ Type of tensors to return. If `None`, will return the list of images.
1335
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
1336
+ The channel dimension format for the output image. Can be one of:
1337
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
1338
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
1339
+ - Unset: Use the channel dimension format of the input image.
1340
+ input_data_format (`ChannelDimension` or `str`, *optional*):
1341
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
1342
+ from the input image. Can be one of:
1343
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
1344
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
1345
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
1346
+ pad_size (`Dict[str, int]`, *optional*):
1347
+ The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
1348
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
1349
+ height and width in the batch.
1350
+ """
1351
+ if "pad_and_return_pixel_mask" in kwargs:
1352
+ logger.warning_once(
1353
+ "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, "
1354
+ "use `do_pad` instead."
1355
+ )
1356
+ do_pad = kwargs.pop("pad_and_return_pixel_mask")
1357
+
1358
+ if "max_size" in kwargs:
1359
+ logger.warning_once(
1360
+ "The `max_size` argument is deprecated and will be removed in a future version, use"
1361
+ " `size['longest_edge']` instead."
1362
+ )
1363
+ size = kwargs.pop("max_size")
1364
+
1365
+ do_resize = self.do_resize if do_resize is None else do_resize
1366
+ size = self.size if size is None else size
1367
+ size = get_size_dict(size=size, default_to_square=False)
1368
+ resample = self.resample if resample is None else resample
1369
+ do_rescale = self.do_rescale if do_rescale is None else do_rescale
1370
+ rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
1371
+ do_normalize = self.do_normalize if do_normalize is None else do_normalize
1372
+ image_mean = self.image_mean if image_mean is None else image_mean
1373
+ image_std = self.image_std if image_std is None else image_std
1374
+ do_convert_annotations = (
1375
+ self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations
1376
+ )
1377
+ do_pad = self.do_pad if do_pad is None else do_pad
1378
+ pad_size = self.pad_size if pad_size is None else pad_size
1379
+ format = self.format if format is None else format
1380
+
1381
+ images = make_list_of_images(images)
1382
+
1383
+ if not valid_images(images):
1384
+ raise ValueError(
1385
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
1386
+ "torch.Tensor, tf.Tensor or jax.ndarray."
1387
+ )
1388
+ validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
1389
+
1390
+ # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated.
1391
+ validate_preprocess_arguments(
1392
+ do_rescale=do_rescale,
1393
+ rescale_factor=rescale_factor,
1394
+ do_normalize=do_normalize,
1395
+ image_mean=image_mean,
1396
+ image_std=image_std,
1397
+ do_resize=do_resize,
1398
+ size=size,
1399
+ resample=resample,
1400
+ )
1401
+
1402
+ if annotations is not None and isinstance(annotations, dict):
1403
+ annotations = [annotations]
1404
+
1405
+ if annotations is not None and len(images) != len(annotations):
1406
+ raise ValueError(
1407
+ f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
1408
+ )
1409
+
1410
+ format = AnnotationFormat(format)
1411
+ if annotations is not None:
1412
+ validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
1413
+
1414
+ if (
1415
+ masks_path is not None
1416
+ and format == AnnotationFormat.COCO_PANOPTIC
1417
+ and not isinstance(masks_path, (pathlib.Path, str))
1418
+ ):
1419
+ raise ValueError(
1420
+ "The path to the directory containing the mask PNG files should be provided as a"
1421
+ f" `pathlib.Path` or string object, but is {type(masks_path)} instead."
1422
+ )
1423
+
1424
+ # All transformations expect numpy arrays
1425
+ images = [to_numpy_array(image) for image in images]
1426
+
1427
+ if is_scaled_image(images[0]) and do_rescale:
1428
+ logger.warning_once(
1429
+ "It looks like you are trying to rescale already rescaled images. If the input"
1430
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
1431
+ )
1432
+
1433
+ if input_data_format is None:
1434
+ # We assume that all images have the same channel dimension format.
1435
+ input_data_format = infer_channel_dimension_format(images[0])
1436
+
1437
+ # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
1438
+ if annotations is not None:
1439
+ prepared_images = []
1440
+ prepared_annotations = []
1441
+ for image, target in zip(images, annotations):
1442
+ target = self.prepare_annotation(
1443
+ image,
1444
+ target,
1445
+ format,
1446
+ return_segmentation_masks=return_segmentation_masks,
1447
+ masks_path=masks_path,
1448
+ input_data_format=input_data_format,
1449
+ )
1450
+ prepared_images.append(image)
1451
+ prepared_annotations.append(target)
1452
+ images = prepared_images
1453
+ annotations = prepared_annotations
1454
+ del prepared_images, prepared_annotations
1455
+
1456
+ # transformations
1457
+ if do_resize:
1458
+ if annotations is not None:
1459
+ resized_images, resized_annotations = [], []
1460
+ for image, target in zip(images, annotations):
1461
+ orig_size = get_image_size(image, input_data_format)
1462
+ resized_image = self.resize(
1463
+ image, size=size, resample=resample, input_data_format=input_data_format
1464
+ )
1465
+ resized_annotation = self.resize_annotation(
1466
+ target, orig_size, get_image_size(resized_image, input_data_format)
1467
+ )
1468
+ resized_images.append(resized_image)
1469
+ resized_annotations.append(resized_annotation)
1470
+ images = resized_images
1471
+ annotations = resized_annotations
1472
+ del resized_images, resized_annotations
1473
+ else:
1474
+ images = [
1475
+ self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
1476
+ for image in images
1477
+ ]
1478
+
1479
+ if do_rescale:
1480
+ images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]
1481
+
1482
+ if do_normalize:
1483
+ images = [
1484
+ self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
1485
+ ]
1486
+
1487
+ if do_convert_annotations and annotations is not None:
1488
+ annotations = [
1489
+ self.normalize_annotation(annotation, get_image_size(image, input_data_format))
1490
+ for annotation, image in zip(annotations, images)
1491
+ ]
1492
+
1493
+ if do_pad:
1494
+ # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
1495
+ encoded_inputs = self.pad(
1496
+ images,
1497
+ annotations=annotations,
1498
+ return_pixel_mask=True,
1499
+ data_format=data_format,
1500
+ input_data_format=input_data_format,
1501
+ update_bboxes=do_convert_annotations,
1502
+ return_tensors=return_tensors,
1503
+ pad_size=pad_size,
1504
+ )
1505
+ else:
1506
+ images = [
1507
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
1508
+ for image in images
1509
+ ]
1510
+ encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
1511
+ if annotations is not None:
1512
+ encoded_inputs["labels"] = [
1513
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
1514
+ ]
1515
+
1516
+ return encoded_inputs
1517
+
1518
+ # POSTPROCESSING METHODS - TODO: add support for other frameworks
1519
+ def post_process(self, outputs, target_sizes):
1520
+ """
1521
+ Converts the raw output of [`DiffusionDet`] into final bounding boxes in (top_left_x,
1522
+ top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch.
1523
+
1524
+ Args:
1525
+ outputs ([`DiffusionDetOutput`]):
1526
+ Raw outputs of the model.
1527
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
1528
+ Tensor containing the size (height, width) of each image of the batch. For evaluation, this must be the
1529
+ original image size (before any data augmentation). For visualization, this should be the image size
1530
+ after data augmentation, but before padding.
1531
+ Returns:
1532
+ `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
1533
+ in the batch as predicted by the model.
1534
+ """
1535
+ logger.warning_once(
1536
+ "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
1537
+ " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
1538
+ )
1539
+
1540
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
1541
+
1542
+ if len(out_logits) != len(target_sizes):
1543
+ raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
1544
+ if target_sizes.shape[1] != 2:
1545
+ raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
1546
+
1547
+ prob = out_logits.sigmoid()
1548
+ topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
1549
+ scores = topk_values
1550
+ topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
1551
+ labels = topk_indexes % out_logits.shape[2]
1552
+ boxes = center_to_corners_format(out_bbox)
1553
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
1554
+
1555
+ # and from relative [0, 1] to absolute [0, height] coordinates
1556
+ img_h, img_w = target_sizes.unbind(1)
1557
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
1558
+ boxes = boxes * scale_fct[:, None, :]
1559
+
1560
+ results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
1561
+
1562
+ return results
1563
+
1564
+ def post_process_object_detection(
1565
+ self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None, top_k: int = 100
1566
+ ):
1567
+ """
1568
+ Converts the raw output of [`DiffusionDet`] into final bounding boxes in (top_left_x,
1569
+ top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch.
1570
+
1571
+ Args:
1572
+ outputs ([`DiffusionDetOutput`]):
1573
+ Raw outputs of the model.
1574
+ threshold (`float`, *optional*, defaults to 0.5):
1575
+ Score threshold to keep object detection predictions.
1576
+ target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
1577
+ Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
1578
+ (height, width) of each image in the batch. If left to None, predictions will not be resized.
1579
+ top_k (`int`, *optional*, defaults to 100):
1580
+ Keep only top k bounding boxes before filtering by thresholding.
1581
+
1582
+ Returns:
1583
+ `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
1584
+ in the batch as predicted by the model.
1585
+ """
1586
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
1587
+
1588
+ if target_sizes is not None:
1589
+ if len(out_logits) != len(target_sizes):
1590
+ raise ValueError(
1591
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
1592
+ )
1593
+
1594
+ prob = out_logits.sigmoid()
1595
+ prob = prob.view(out_logits.shape[0], -1)
1596
+ k_value = min(top_k, prob.size(1))
1597
+ topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
1598
+ scores = topk_values
1599
+ topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
1600
+ labels = topk_indexes % out_logits.shape[2]
1601
+ boxes = center_to_corners_format(out_bbox)
1602
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
1603
+
1604
+ # and from relative [0, 1] to absolute [0, height] coordinates
1605
+ if target_sizes is not None:
1606
+ if isinstance(target_sizes, List):
1607
+ img_h = torch.Tensor([i[0] for i in target_sizes])
1608
+ img_w = torch.Tensor([i[1] for i in target_sizes])
1609
+ else:
1610
+ img_h, img_w = target_sizes.unbind(1)
1611
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
1612
+ boxes = boxes * scale_fct[:, None, :]
1613
+
1614
+ results = []
1615
+ for s, l, b in zip(scores, labels, boxes):
1616
+ score = s[s > threshold]
1617
+ label = l[s > threshold]
1618
+ box = b[s > threshold]
1619
+ results.append({"scores": score, "labels": label, "boxes": box})
1620
+
1621
+ return results
1622
+
1623
+
1624
+ __all__ = ["DiffusionDetImageProcessor"]
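As a quick orientation for readers of this upload, here is a minimal usage sketch of the processor defined above. The image path, the empty COCO-style annotation, and the commented-out model call are illustrative assumptions only, not part of the uploaded code:

    from PIL import Image

    processor = DiffusionDetImageProcessor()         # defaults: shortest_edge=800, longest_edge=1333
    image = Image.open("example.jpg")                # hypothetical local image
    annotation = {"image_id": 0, "annotations": []}  # COCO detection format; may be empty

    inputs = processor.preprocess(images=image, annotations=annotation, return_tensors="pt")
    # `inputs` holds "pixel_values", "pixel_mask" and, because annotations were passed, "labels"

    # After a forward pass of the detector (not shown here), the raw logits and boxes are converted
    # back to absolute (x1, y1, x2, y2) boxes per image, e.g.:
    # results = processor.post_process_object_detection(
    #     outputs, threshold=0.5, target_sizes=[image.size[::-1]]
    # )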
loss.py ADDED
@@ -0,0 +1,415 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from fvcore.nn import sigmoid_focal_loss_jit
4
+ from torch import nn
5
+
6
+ import torch.distributed as dist
7
+ from torch.distributed import get_world_size
8
+ from torchvision import ops
9
+
10
+
11
+ def is_dist_avail_and_initialized():
12
+ if not dist.is_available():
13
+ return False
14
+ if not dist.is_initialized():
15
+ return False
16
+ return True
17
+
18
+
19
+ def get_fed_loss_classes(gt_classes, num_fed_loss_classes, num_classes, weight):
20
+ """
21
+ Args:
22
+ gt_classes: a long tensor of shape R that contains the gt class label of each proposal.
23
+ num_fed_loss_classes: minimum number of classes to keep when calculating federated loss.
24
+ Will sample negative classes if number of unique gt_classes is smaller than this value.
25
+ num_classes: number of foreground classes
26
+ weight: probabilities used to sample negative classes
27
+ Returns:
28
+ Tensor:
29
+ classes to keep when calculating the federated loss, including both unique gt
30
+ classes and sampled negative classes.
31
+ """
32
+ unique_gt_classes = torch.unique(gt_classes)
33
+ prob = unique_gt_classes.new_ones(num_classes + 1).float()
34
+ prob[-1] = 0
35
+ if len(unique_gt_classes) < num_fed_loss_classes:
36
+ prob[:num_classes] = weight.float().clone()
37
+ prob[unique_gt_classes] = 0
38
+ sampled_negative_classes = torch.multinomial(
39
+ prob, num_fed_loss_classes - len(unique_gt_classes), replacement=False
40
+ )
41
+ fed_loss_classes = torch.cat([unique_gt_classes, sampled_negative_classes])
42
+ else:
43
+ fed_loss_classes = unique_gt_classes
44
+ return fed_loss_classes
45
+
46
+
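To make get_fed_loss_classes above concrete, here is a small self-contained illustration (all numbers are made up): with three unique ground-truth classes and num_fed_loss_classes=5, two extra negative classes are sampled according to `weight`.

    import torch

    gt_classes = torch.tensor([2, 7, 7, 2, 11])
    weight = torch.ones(20)  # uniform sampling weights over 20 foreground classes
    kept = get_fed_loss_classes(gt_classes, num_fed_loss_classes=5, num_classes=20, weight=weight)
    # `kept` contains the unique gt classes {2, 7, 11} plus two sampled negatives, 5 classes in total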
47
+ class CriterionDynamicK(nn.Module):
48
+ """ This class computes the loss for DiffusionDet.
49
+ The process happens in two steps:
50
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
51
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
52
+ """
53
+
54
+ def __init__(self, config, num_classes, weight_dict):
55
+ """ Create the criterion.
56
+ Parameters:
57
+ num_classes: number of object categories, omitting the special no-object category
58
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
59
+ """
60
+ super().__init__()
61
+ self.config = config
62
+ self.num_classes = num_classes
63
+ self.matcher = HungarianMatcherDynamicK(config)
64
+ self.weight_dict = weight_dict
65
+ self.eos_coef = config.no_object_weight
66
+ self.use_focal = config.use_focal
67
+ self.use_fed_loss = config.use_fed_loss
68
+
69
+ if self.use_focal:
70
+ self.focal_loss_alpha = config.alpha
71
+ self.focal_loss_gamma = config.gamma
72
+
73
+ # copy-paste from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/roi_heads/fast_rcnn.py#L356
74
+ """Classification loss (sigmoid focal loss or binary cross-entropy with logits)
75
+ """Classification loss (NLL)
76
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
77
+ """
78
+ assert 'pred_logits' in outputs
79
+ src_logits = outputs['pred_logits']
80
+ batch_size = len(targets)
81
+
82
+ # idx = self._get_src_permutation_idx(indices)
83
+ # target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
84
+ target_classes = torch.full(src_logits.shape[:2], self.num_classes,
85
+ dtype=torch.int64, device=src_logits.device)
86
+ src_logits_list = []
87
+ target_classes_o_list = []
88
+ # target_classes[idx] = target_classes_o
89
+ for batch_idx in range(batch_size):
90
+ valid_query = indices[batch_idx][0]
91
+ gt_multi_idx = indices[batch_idx][1]
92
+ if len(gt_multi_idx) == 0:
93
+ continue
94
+ bz_src_logits = src_logits[batch_idx]
95
+ target_classes_o = targets[batch_idx]["labels"]
96
+ target_classes[batch_idx, valid_query] = target_classes_o[gt_multi_idx]
97
+
98
+ src_logits_list.append(bz_src_logits[valid_query])
99
+ target_classes_o_list.append(target_classes_o[gt_multi_idx])
100
+
101
+ if self.use_focal or self.use_fed_loss:
102
+ num_boxes = torch.cat(target_classes_o_list).shape[0] if len(target_classes_o_list) != 0 else 1
103
+
104
+ target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], self.num_classes + 1],
105
+ dtype=src_logits.dtype, layout=src_logits.layout,
106
+ device=src_logits.device)
107
+ target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
108
+
109
+ gt_classes = torch.argmax(target_classes_onehot, dim=-1)
110
+ target_classes_onehot = target_classes_onehot[:, :, :-1]
111
+
112
+ src_logits = src_logits.flatten(0, 1)
113
+ target_classes_onehot = target_classes_onehot.flatten(0, 1)
114
+ if self.use_focal:
115
+ cls_loss = sigmoid_focal_loss_jit(src_logits, target_classes_onehot, alpha=self.focal_loss_alpha,
116
+ gamma=self.focal_loss_gamma, reduction="none")
117
+ else:
118
+ cls_loss = F.binary_cross_entropy_with_logits(src_logits, target_classes_onehot, reduction="none")
119
+ if self.use_fed_loss:
120
+ K = self.num_classes
121
+ N = src_logits.shape[0]
122
+ fed_loss_classes = get_fed_loss_classes(
123
+ gt_classes,
124
+ num_fed_loss_classes=self.fed_loss_num_classes,
125
+ num_classes=K,
126
+ weight=self.fed_loss_cls_weights,
127
+ )
128
+ fed_loss_classes_mask = fed_loss_classes.new_zeros(K + 1)
129
+ fed_loss_classes_mask[fed_loss_classes] = 1
130
+ fed_loss_classes_mask = fed_loss_classes_mask[:K]
131
+ weight = fed_loss_classes_mask.view(1, K).expand(N, K).float()
132
+
133
+ loss_ce = torch.sum(cls_loss * weight) / num_boxes
134
+ else:
135
+ loss_ce = torch.sum(cls_loss) / num_boxes
136
+
137
+ losses = {'loss_ce': loss_ce}
138
+ else:
139
+ raise NotImplementedError
140
+
141
+ return losses
142
+
143
+ def loss_boxes(self, outputs, targets, indices):
144
+ """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
145
+ targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
146
+ The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
147
+ """
148
+ assert 'pred_boxes' in outputs
149
+ # idx = self._get_src_permutation_idx(indices)
150
+ src_boxes = outputs['pred_boxes']
151
+
152
+ batch_size = len(targets)
153
+ pred_box_list = []
154
+ pred_norm_box_list = []
155
+ tgt_box_list = []
156
+ tgt_box_xyxy_list = []
157
+ for batch_idx in range(batch_size):
158
+ valid_query = indices[batch_idx][0]
159
+ gt_multi_idx = indices[batch_idx][1]
160
+ if len(gt_multi_idx) == 0:
161
+ continue
162
+ bz_image_whwh = targets[batch_idx]['image_size_xyxy']
163
+ bz_src_boxes = src_boxes[batch_idx]
164
+ bz_target_boxes = targets[batch_idx]["boxes"] # normalized (cx, cy, w, h)
165
+ bz_target_boxes_xyxy = targets[batch_idx]["boxes_xyxy"] # absolute (x1, y1, x2, y2)
166
+ pred_box_list.append(bz_src_boxes[valid_query])
167
+ pred_norm_box_list.append(bz_src_boxes[valid_query] / bz_image_whwh) # normalize (x1, y1, x2, y2)
168
+ tgt_box_list.append(bz_target_boxes[gt_multi_idx])
169
+ tgt_box_xyxy_list.append(bz_target_boxes_xyxy[gt_multi_idx])
170
+
171
+ if len(pred_box_list) != 0:
172
+ src_boxes = torch.cat(pred_box_list)
173
+ src_boxes_norm = torch.cat(pred_norm_box_list) # normalized (x1, y1, x2, y2)
174
+ target_boxes = torch.cat(tgt_box_list)
175
+ target_boxes_abs_xyxy = torch.cat(tgt_box_xyxy_list)
176
+ num_boxes = src_boxes.shape[0]
177
+
178
+ losses = {}
179
+ # require normalized (x1, y1, x2, y2)
180
+ loss_bbox = F.l1_loss(src_boxes_norm, ops.box_convert(target_boxes, 'cxcywh', 'xyxy'), reduction='none')
181
+ losses['loss_bbox'] = loss_bbox.sum() / num_boxes
182
+
183
+ # loss_giou = giou_loss(box_ops.box_cxcywh_to_xyxy(src_boxes), box_ops.box_cxcywh_to_xyxy(target_boxes))
184
+ loss_giou = 1 - torch.diag(ops.generalized_box_iou(src_boxes, target_boxes_abs_xyxy))
185
+ losses['loss_giou'] = loss_giou.sum() / num_boxes
186
+ else:
187
+ losses = {'loss_bbox': outputs['pred_boxes'].sum() * 0,
188
+ 'loss_giou': outputs['pred_boxes'].sum() * 0}
189
+
190
+ return losses
191
+
192
+ def get_loss(self, loss, outputs, targets, indices):
193
+ loss_map = {
194
+ 'labels': self.loss_labels,
195
+ 'boxes': self.loss_boxes,
196
+ }
197
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
198
+ return loss_map[loss](outputs, targets, indices)
199
+
200
+ def forward(self, outputs, targets):
201
+ """ This performs the loss computation.
202
+ Parameters:
203
+ outputs: dict of tensors, see the output specification of the model for the format
204
+ targets: list of dicts, such that len(targets) == batch_size.
205
+ The expected keys in each dict depend on the losses applied, see each loss's doc
206
+ """
207
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
208
+
209
+ # Retrieve the matching between the outputs of the last layer and the targets
210
+ indices, _ = self.matcher(outputs_without_aux, targets)
211
+
212
+ # Compute all the requested losses
213
+ losses = {}
214
+ for loss in ["labels", "boxes"]:
215
+ losses.update(self.get_loss(loss, outputs, targets, indices))
216
+
217
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
218
+ if 'aux_outputs' in outputs:
219
+ for i, aux_outputs in enumerate(outputs['aux_outputs']):
220
+ indices, _ = self.matcher(aux_outputs, targets)
221
+ for loss in ["labels", "boxes"]:
222
+ if loss == 'masks':
223
+ # Intermediate masks losses are too costly to compute, we ignore them.
224
+ continue
225
+
226
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices)
227
+ l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
228
+ losses.update(l_dict)
229
+
230
+ return losses
231
+
232
+
233
+ def get_in_boxes_info(boxes, target_gts):
234
+ xy_target_gts = ops.box_convert(target_gts, 'cxcywh', 'xyxy') # (x1, y1, x2, y2)
235
+
236
+ anchor_center_x = boxes[:, 0].unsqueeze(1)
237
+ anchor_center_y = boxes[:, 1].unsqueeze(1)
238
+
239
+ # whether the center of each anchor is inside a gt box
240
+ b_l = anchor_center_x > xy_target_gts[:, 0].unsqueeze(0)
241
+ b_r = anchor_center_x < xy_target_gts[:, 2].unsqueeze(0)
242
+ b_t = anchor_center_y > xy_target_gts[:, 1].unsqueeze(0)
243
+ b_b = anchor_center_y < xy_target_gts[:, 3].unsqueeze(0)
244
+ # all four conditions hold -> the proposal center lies inside the gt box; shape [num_proposals, num_gt]
245
+ is_in_boxes = ((b_l.long() + b_r.long() + b_t.long() + b_b.long()) == 4)
246
+ is_in_boxes_all = is_in_boxes.sum(1) > 0 # [num_query]
247
+ # in fixed center
248
+ center_radius = 2.5
249
+ # Modified to self-adapted sampling: the center region scales with the size of the gt boxes
250
+ # https://github.com/dulucas/UVO_Challenge/blob/main/Track1/detection/mmdet/core/bbox/assigners/rpn_sim_ota_assigner.py#L212
251
+ b_l = anchor_center_x > (
252
+ target_gts[:, 0] - (center_radius * (xy_target_gts[:, 2] - xy_target_gts[:, 0]))).unsqueeze(0)
253
+ b_r = anchor_center_x < (
254
+ target_gts[:, 0] + (center_radius * (xy_target_gts[:, 2] - xy_target_gts[:, 0]))).unsqueeze(0)
255
+ b_t = anchor_center_y > (
256
+ target_gts[:, 1] - (center_radius * (xy_target_gts[:, 3] - xy_target_gts[:, 1]))).unsqueeze(0)
257
+ b_b = anchor_center_y < (
258
+ target_gts[:, 1] + (center_radius * (xy_target_gts[:, 3] - xy_target_gts[:, 1]))).unsqueeze(0)
259
+
260
+ is_in_centers = ((b_l.long() + b_r.long() + b_t.long() + b_b.long()) == 4)
261
+ is_in_centers_all = is_in_centers.sum(1) > 0
262
+
263
+ is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
264
+ is_in_boxes_and_center = (is_in_boxes & is_in_centers)
265
+
266
+ return is_in_boxes_anchor, is_in_boxes_and_center
267
+
268
+
269
+ class HungarianMatcherDynamicK(nn.Module):
270
+ """This class computes an assignment between the targets and the predictions of the network
271
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
272
+ there are more predictions than targets. In this case, we do a 1-to-k (dynamic) matching of the best predictions,
273
+ while the others are un-matched (and thus treated as non-objects).
274
+ """
275
+
276
+ def __init__(self, config):
277
+ super().__init__()
278
+ self.use_focal = config.use_focal
279
+ self.use_fed_loss = config.use_fed_loss
280
+ self.cost_class = config.class_weight
281
+ self.cost_giou = config.giou_weight
282
+ self.cost_bbox = config.l1_weight
283
+ self.ota_k = config.ota_k
284
+
285
+ if self.use_focal:
286
+ self.focal_loss_alpha = config.alpha
287
+ self.focal_loss_gamma = config.gamma
288
+
289
+ assert self.cost_class != 0 or self.cost_bbox != 0 or self.cost_giou != 0, "all costs cant be 0"
290
+
291
+ def forward(self, outputs, targets):
292
+ """SimOTA-style dynamic matching between predictions and ground-truth boxes."""
293
+ with torch.no_grad():
294
+ bs, num_queries = outputs["pred_logits"].shape[:2]
295
+ # We flatten to compute the cost matrices in a batch
296
+ if self.use_focal or self.use_fed_loss:
297
+ out_prob = outputs["pred_logits"].sigmoid() # [batch_size, num_queries, num_classes]
298
+ out_bbox = outputs["pred_boxes"] # [batch_size, num_queries, 4]
299
+ else:
300
+ out_prob = outputs["pred_logits"].softmax(-1) # [batch_size, num_queries, num_classes]
301
+ out_bbox = outputs["pred_boxes"] # [batch_size, num_queries, 4]
302
+
303
+ indices = []
304
+ matched_ids = []
305
+ assert bs == len(targets)
306
+ for batch_idx in range(bs):
307
+ bz_boxes = out_bbox[batch_idx] # [num_proposals, 4]
308
+ bz_out_prob = out_prob[batch_idx]
309
+ bz_tgt_ids = targets[batch_idx]["labels"]
310
+ num_insts = len(bz_tgt_ids)
311
+ if num_insts == 0: # no ground-truth objects in this image
312
+ non_valid = torch.zeros(bz_out_prob.shape[0]).to(bz_out_prob) > 0
313
+ indices_batchi = (non_valid, torch.arange(0, 0).to(bz_out_prob))
314
+ matched_qidx = torch.arange(0, 0).to(bz_out_prob)
315
+ indices.append(indices_batchi)
316
+ matched_ids.append(matched_qidx)
317
+ continue
318
+
319
+ bz_gtboxs = targets[batch_idx]['boxes'] # [num_gt, 4] normalized (cx, cy, w, h)
320
+ bz_gtboxs_abs_xyxy = targets[batch_idx]['boxes_xyxy']
321
+ fg_mask, is_in_boxes_and_center = get_in_boxes_info(
322
+ ops.box_convert(bz_boxes, 'xyxy', 'cxcywh'), # absolute (cx, cy, w, h)
323
+ ops.box_convert(bz_gtboxs_abs_xyxy, 'xyxy', 'cxcywh') # absolute (cx, cy, w, h)
324
+ )
325
+
326
+ pair_wise_ious = ops.box_iou(bz_boxes, bz_gtboxs_abs_xyxy)
327
+
328
+ # Compute the classification cost.
329
+ if self.use_focal:
330
+ alpha = self.focal_loss_alpha
331
+ gamma = self.focal_loss_gamma
332
+ neg_cost_class = (1 - alpha) * (bz_out_prob ** gamma) * (-(1 - bz_out_prob + 1e-8).log())
333
+ pos_cost_class = alpha * ((1 - bz_out_prob) ** gamma) * (-(bz_out_prob + 1e-8).log())
334
+ cost_class = pos_cost_class[:, bz_tgt_ids] - neg_cost_class[:, bz_tgt_ids]
335
+ elif self.use_fed_loss:
336
+ # focal loss degenerates to naive one
337
+ neg_cost_class = (-(1 - bz_out_prob + 1e-8).log())
338
+ pos_cost_class = (-(bz_out_prob + 1e-8).log())
339
+ cost_class = pos_cost_class[:, bz_tgt_ids] - neg_cost_class[:, bz_tgt_ids]
340
+ else:
341
+ cost_class = -bz_out_prob[:, bz_tgt_ids]
342
+
343
+ # Compute the L1 cost between boxes
344
+ # image_size_out = torch.cat([v["image_size_xyxy"].unsqueeze(0) for v in targets])
345
+ # image_size_out = image_size_out.unsqueeze(1).repeat(1, num_queries, 1).flatten(0, 1)
346
+ # image_size_tgt = torch.cat([v["image_size_xyxy_tgt"] for v in targets])
347
+
348
+ bz_image_size_out = targets[batch_idx]['image_size_xyxy']
349
+ bz_image_size_tgt = targets[batch_idx]['image_size_xyxy_tgt']
350
+
351
+ bz_out_bbox_ = bz_boxes / bz_image_size_out # normalize (x1, y1, x2, y2)
352
+ bz_tgt_bbox_ = bz_gtboxs_abs_xyxy / bz_image_size_tgt # normalize (x1, y1, x2, y2)
353
+ cost_bbox = torch.cdist(bz_out_bbox_, bz_tgt_bbox_, p=1)
354
+
355
+ cost_giou = -ops.generalized_box_iou(bz_boxes, bz_gtboxs_abs_xyxy)
356
+
357
+ # Final cost matrix
358
+ cost = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou + 100.0 * (
359
+ ~is_in_boxes_and_center)
360
+ # cost = (cost_class + 3.0 * cost_giou + 100.0 * (~is_in_boxes_and_center)) # [num_query,num_gt]
361
+ cost[~fg_mask] = cost[~fg_mask] + 10000.0
362
+
363
+ # if bz_gtboxs.shape[0]>0:
364
+ indices_batchi, matched_qidx = self.dynamic_k_matching(cost, pair_wise_ious, bz_gtboxs.shape[0])
365
+
366
+ indices.append(indices_batchi)
367
+ matched_ids.append(matched_qidx)
368
+
369
+ return indices, matched_ids
370
+
371
+ def dynamic_k_matching(self, cost, pair_wise_ious, num_gt):
372
+ matching_matrix = torch.zeros_like(cost) # [300,num_gt]
373
+ ious_in_boxes_matrix = pair_wise_ious
374
+ n_candidate_k = self.ota_k
375
+
376
+ # For each gt, dynamic_k is the floored, clamped sum of its top n_candidate_k IoUs with the proposals
377
+ topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=0)
378
+ dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
379
+
380
+ for gt_idx in range(num_gt):
381
+ _, pos_idx = torch.topk(cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
382
+ matching_matrix[:, gt_idx][pos_idx] = 1.0
383
+
384
+ del topk_ious, dynamic_ks, pos_idx
385
+
386
+ anchor_matching_gt = matching_matrix.sum(1)
387
+
388
+ if (anchor_matching_gt > 1).sum() > 0:
389
+ _, cost_argmin = torch.min(cost[anchor_matching_gt > 1], dim=1)
390
+ matching_matrix[anchor_matching_gt > 1] *= 0
391
+ matching_matrix[anchor_matching_gt > 1, cost_argmin,] = 1
392
+
393
+ while (matching_matrix.sum(0) == 0).any():
394
+ num_zero_gt = (matching_matrix.sum(0) == 0).sum()
395
+ matched_query_id = matching_matrix.sum(1) > 0
396
+ cost[matched_query_id] += 100000.0
397
+ unmatch_id = torch.nonzero(matching_matrix.sum(0) == 0, as_tuple=False).squeeze(1)
398
+ for gt_idx in unmatch_id:
399
+ pos_idx = torch.argmin(cost[:, gt_idx])
400
+ matching_matrix[:, gt_idx][pos_idx] = 1.0
401
+ if (matching_matrix.sum(1) > 1).sum() > 0: # If a query matches more than one gt
402
+ _, cost_argmin = torch.min(cost[anchor_matching_gt > 1],
403
+ dim=1) # find gt for these queries with minimal cost
404
+ matching_matrix[anchor_matching_gt > 1] *= 0 # reset mapping relationship
405
+ matching_matrix[anchor_matching_gt > 1, cost_argmin,] = 1 # keep gt with minimal cost
406
+
407
+ assert not (matching_matrix.sum(0) == 0).any()
408
+ selected_query = matching_matrix.sum(1) > 0
409
+ gt_indices = matching_matrix[selected_query].max(1)[1]
410
+ assert selected_query.sum() == len(gt_indices)
411
+
412
+ cost[matching_matrix == 0] = cost[matching_matrix == 0] + float('inf')
413
+ matched_query_id = torch.min(cost, dim=0)[1]
414
+
415
+ return (selected_query, gt_indices), matched_query_id
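To illustrate the dynamic-k rule implemented in dynamic_k_matching above, the following standalone sketch (IoU values are made up) reproduces its first step: each ground-truth column receives a budget of positives equal to the floored, clamped sum of its top ota_k IoUs.

    import torch

    pair_wise_ious = torch.tensor([[0.70, 0.10],
                                   [0.60, 0.05],
                                   [0.20, 0.40],
                                   [0.10, 0.30]])           # [num_proposals, num_gt]
    topk_ious, _ = torch.topk(pair_wise_ious, k=3, dim=0)   # top ota_k IoUs per gt
    dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1) # -> [1, 1]
    # each gt then takes its dynamic_k lowest-cost proposals as positives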
modeling_diffusiondet.py ADDED
@@ -0,0 +1,433 @@
1
+ import math
2
+ import random
3
+ from collections import namedtuple, OrderedDict
4
+ from dataclasses import dataclass
5
+ from typing import Dict, List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ from torchvision import ops
11
+ from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork
12
+ from transformers import PreTrainedModel
13
+ import wandb
14
+
15
+ from transformers.utils.backbone_utils import load_backbone
16
+ from .configuration_diffusiondet import DiffusionDetConfig
17
+
18
+ from .head import HeadDynamicK
19
+ from .loss import CriterionDynamicK
20
+
21
+ from transformers.utils import ModelOutput
22
+
23
+ ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
24
+
25
+
26
+ def default(val, d):
27
+ if val is not None:
28
+ return val
29
+ return d() if callable(d) else d
30
+
31
+
32
+ def extract(a, t, x_shape):
33
+ """extract the appropriate t index for a batch of indices"""
34
+ batch_size = t.shape[0]
35
+ out = a.gather(-1, t)
36
+ return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
37
+
38
+
39
+ def cosine_beta_schedule(timesteps, s=0.008):
40
+ """
41
+ cosine schedule
42
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
43
+ """
44
+ steps = timesteps + 1
45
+ x = torch.linspace(0, timesteps, steps)
46
+ alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
47
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
48
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
49
+ return torch.clip(betas, 0, 0.999)
50
+
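A quick sanity check of the schedule above, assuming the 1000 steps hard-coded later in DiffusionDet.__init__; it only illustrates the shape of the output and how the model's alphas_cumprod buffer is derived from it.

    import torch

    betas = cosine_beta_schedule(1000)
    assert betas.shape == (1000,)                     # one beta per diffusion step
    alphas_cumprod = torch.cumprod(1 - betas, dim=0)  # the buffer registered by DiffusionDet
    # betas grow towards the 0.999 clip at the end of the schedule, so alphas_cumprod decays to ~0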
51
+ @dataclass
52
+ class DiffusionDetOutput(ModelOutput):
53
+ """
54
+ Output type of DiffusionDet.
55
+ """
56
+
57
+ loss: Optional[torch.FloatTensor] = None
58
+ loss_dict: Optional[Dict] = None
59
+ logits: torch.FloatTensor = None
60
+ labels: torch.IntTensor = None
61
+ pred_boxes: torch.FloatTensor = None
62
+
63
+ class DiffusionDet(PreTrainedModel):
64
+ """
65
+ Implementation of the DiffusionDet object detector.
66
+ """
67
+ config_class = DiffusionDetConfig
68
+ main_input_name = "pixel_values"
69
+
70
+ def __init__(self, config):
71
+ super(DiffusionDet, self).__init__(config)
72
+
73
+ self.in_features = config.roi_head_in_features
74
+ self.num_classes = config.num_labels
75
+ self.num_proposals = config.num_proposals
76
+ self.num_heads = config.num_heads
77
+
78
+ self.backbone = load_backbone(config)
79
+ self.fpn = FeaturePyramidNetwork(
80
+ in_channels_list=self.backbone.channels,
81
+ out_channels=config.fpn_out_channels,
82
+ # extra_blocks=LastLevelMaxPool(),
83
+ )
84
+
85
+ # build diffusion
86
+ betas = cosine_beta_schedule(1000)
87
+ alphas_cumprod = torch.cumprod(1 - betas, dim=0)
88
+
89
+ timesteps, = betas.shape
90
+ sampling_timesteps = config.sample_step
91
+
92
+ self.register_buffer('alphas_cumprod', alphas_cumprod)
93
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
94
+ self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
95
+ self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
96
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
97
+
98
+ self.num_timesteps = int(timesteps)
99
+ self.sampling_timesteps = default(sampling_timesteps, timesteps)
100
+ self.ddim_sampling_eta = 1.
101
+ self.scale = config.snr_scale
102
+ assert self.sampling_timesteps <= timesteps
103
+
104
+ roi_input_shape = {
105
+ 'p2': {'stride': 4},
106
+ 'p3': {'stride': 8},
107
+ 'p4': {'stride': 16},
108
+ 'p5': {'stride': 32},
109
+ 'p6': {'stride': 64}
110
+ }
111
+ self.head = HeadDynamicK(config, roi_input_shape=roi_input_shape)
112
+
113
+ self.deep_supervision = config.deep_supervision
114
+ self.use_focal = config.use_focal
115
+ self.use_fed_loss = config.use_fed_loss
116
+ self.use_nms = config.use_nms
117
+
118
+ weight_dict = {
119
+ "loss_ce": config.class_weight, "loss_bbox": config.l1_weight, "loss_giou": config.giou_weight
120
+ }
121
+ if self.deep_supervision:
122
+ aux_weight_dict = {}
123
+ for i in range(self.num_heads - 1):
124
+ aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
125
+ weight_dict.update(aux_weight_dict)
126
+
127
+ self.criterion = CriterionDynamicK(config, num_classes=self.num_classes, weight_dict=weight_dict)
128
+
129
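+     # The buffers registered in __init__ cache, for every timestep t (writing abar_t = alphas_cumprod[t]):
+     #   sqrt_alphas_cumprod[t]           = sqrt(abar_t)
+     #   sqrt_one_minus_alphas_cumprod[t] = sqrt(1 - abar_t)
+     #   sqrt_recip_alphas_cumprod[t]     = sqrt(1 / abar_t)
+     #   sqrt_recipm1_alphas_cumprod[t]   = sqrt(1 / abar_t - 1)
+     # They are consumed by q_sample (forward corruption) and predict_noise_from_start (noise recovery).
+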
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+             torch.nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
+             if module.bias is not None:
+                 torch.nn.init.constant_(module.bias, 0)
+         elif isinstance(module, nn.BatchNorm2d):
+             torch.nn.init.constant_(module.weight, 1)
+             torch.nn.init.constant_(module.bias, 0)
+
+     def predict_noise_from_start(self, x_t, t, x0):
+         # Recover the noise that would have produced x_t from the predicted clean signal x0.
+         return (
+             (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
+             extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+         )
+
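+     # predict_noise_from_start inverts the forward process x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps
+     # (with abar_t = alphas_cumprod[t]), i.e. it computes
+     #   eps = (sqrt(1 / abar_t) * x_t - x0) / sqrt(1 / abar_t - 1)
+     #       = (x_t - sqrt(abar_t) * x0) / sqrt(1 - abar_t).
+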
+     def model_predictions(self, backbone_feats, images_whwh, x, t):
+         # Map the noisy signal to image-space boxes: [-scale, scale] -> [0, 1] (cxcywh) -> absolute xyxy.
+         x_boxes = torch.clamp(x, min=-1 * self.scale, max=self.scale)
+         x_boxes = ((x_boxes / self.scale) + 1) / 2
+         x_boxes = ops.box_convert(x_boxes, 'cxcywh', 'xyxy')
+         x_boxes = x_boxes * images_whwh[:, None, :]
+         outputs_class, outputs_coord = self.head(backbone_feats, x_boxes, t)
+
+         # Re-encode the refined boxes as the predicted clean signal x_start.
+         x_start = outputs_coord[-1]  # (batch, num_proposals, 4), absolute (x1, y1, x2, y2) coordinates
+         x_start = x_start / images_whwh[:, None, :]
+         x_start = ops.box_convert(x_start, 'xyxy', 'cxcywh')
+         x_start = (x_start * 2 - 1.) * self.scale
+         x_start = torch.clamp(x_start, min=-1 * self.scale, max=self.scale)
+         pred_noise = self.predict_noise_from_start(x, t, x_start)
+
+         return ModelPrediction(pred_noise, x_start), outputs_class, outputs_coord
+
+     @torch.no_grad()
+     def ddim_sample(self, batched_inputs, backbone_feats, images_whwh):
+         bs = len(batched_inputs)
+         image_sizes = batched_inputs.shape
+         shape = (bs, self.num_proposals, 4)
+
+         # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
+         times = torch.linspace(-1, self.num_timesteps - 1, steps=self.sampling_timesteps + 1)
+         times = list(reversed(times.int().tolist()))
+         time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
+
+         img = torch.randn(shape, device=self.device)
+
+         ensemble_score, ensemble_label, ensemble_coord = [], [], []
+         outputs_class, outputs_coord = None, None
+         for time, time_next in time_pairs:
+             time_cond = torch.full((bs,), time, device=self.device, dtype=torch.long)
+
+             preds, outputs_class, outputs_coord = self.model_predictions(backbone_feats, images_whwh, img, time_cond)
+             pred_noise, x_start = preds.pred_noise, preds.pred_x_start
+
+             # Box renewal: keep only confident proposals; the filtered-out ones are re-sampled from noise below.
+             score_per_image, box_per_image = outputs_class[-1][0], outputs_coord[-1][0]
+             threshold = 0.5
+             score_per_image = torch.sigmoid(score_per_image)
+             value, _ = torch.max(score_per_image, -1, keepdim=False)
+             keep_idx = value > threshold
+             num_remain = torch.sum(keep_idx)
+
+             pred_noise = pred_noise[:, keep_idx, :]
+             x_start = x_start[:, keep_idx, :]
+             img = img[:, keep_idx, :]
+
+             if time_next < 0:
+                 img = x_start
+                 continue
+
+             alpha = self.alphas_cumprod[time]
+             alpha_next = self.alphas_cumprod[time_next]
+
+             sigma = self.ddim_sampling_eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+             c = (1 - alpha_next - sigma ** 2).sqrt()
+
+             noise = torch.randn_like(img)
+
+             # DDIM update: x_{t_next} = sqrt(alpha_next) * x_start + sqrt(1 - alpha_next - sigma^2) * eps + sigma * z
+             img = x_start * alpha_next.sqrt() + \
+                   c * pred_noise + \
+                   sigma * noise
+
+             # Replenish the filtered-out proposals with fresh Gaussian boxes.
+             img = torch.cat((img, torch.randn(1, self.num_proposals - num_remain, 4, device=img.device)), dim=1)
+
+             if self.sampling_timesteps > 1:
+                 box_pred_per_image, scores_per_image, labels_per_image = self.inference(outputs_class[-1],
+                                                                                         outputs_coord[-1])
+                 ensemble_score.append(scores_per_image)
+                 ensemble_label.append(labels_per_image)
+                 ensemble_coord.append(box_pred_per_image)
+
+         if self.sampling_timesteps > 1:
+             # Ensemble the per-step predictions and run NMS over their union.
+             box_pred_per_image = torch.cat(ensemble_coord, dim=0)
+             scores_per_image = torch.cat(ensemble_score, dim=0)
+             labels_per_image = torch.cat(ensemble_label, dim=0)
+
+             if self.use_nms:
+                 keep = ops.batched_nms(box_pred_per_image, scores_per_image, labels_per_image, 0.5)
+                 box_pred_per_image = box_pred_per_image[keep]
+                 scores_per_image = scores_per_image[keep]
+                 labels_per_image = labels_per_image[keep]
+
+             return box_pred_per_image, scores_per_image, labels_per_image
+         else:
+             return self.inference(outputs_class[-1], outputs_coord[-1])
+
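+     # Worked example of the schedule above: with num_timesteps = 1000 and sampling_timesteps = 1,
+     # times = [999, -1] and time_pairs = [(999, -1)], i.e. a single denoising step that directly
+     # keeps the predicted x_start; larger sample_step values add intermediate DDIM steps.
+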
+     def q_sample(self, x_start, t, noise=None):
+         """Forward diffusion: x_t = sqrt(abar_t) * x_start + sqrt(1 - abar_t) * noise."""
+         if noise is None:
+             noise = torch.randn_like(x_start)
+
+         sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
+         sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+
+         return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
+
+     def forward(self, pixel_values, labels):
+         """
+         Args:
+             pixel_values: batched images of shape (batch_size, 3, H, W).
+             labels: per-image targets providing `size`, `class_labels` and normalized `boxes`
+                 in (cx, cy, w, h) format, as consumed by `prepare_targets`.
+         """
+         images = pixel_values.to(self.device)
+         images_whwh = list()
+         for image in images:
+             h, w = image.shape[-2:]
+             images_whwh.append(torch.tensor([w, h, w, h], device=self.device))
+         images_whwh = torch.stack(images_whwh)
+
+         features = self.backbone(images)
+         features = OrderedDict(
+             [(key, feature) for key, feature in zip(self.backbone.out_features, features.feature_maps)]
+         )
+         features = self.fpn(features)  # [144, 72, 36, 18]
+         features = [features[f] for f in features.keys()]
+
+         # if self.training:
+         labels = list(map(lambda target: target.to(self.device), labels))
+         targets, x_boxes, noises, ts = self.prepare_targets(labels)
+
+         ts = ts.squeeze(-1)
+         x_boxes = x_boxes * images_whwh[:, None, :]
+
+         outputs_class, outputs_coord = self.head(features, x_boxes, ts)
+         output = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
+
+         if self.deep_supervision:
+             output['aux_outputs'] = [{'pred_logits': a, 'pred_boxes': b}
+                                      for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+
+         loss_dict = self.criterion(output, targets)
+         weight_dict = self.criterion.weight_dict
+         for k in loss_dict.keys():
+             if k in weight_dict:
+                 loss_dict[k] *= weight_dict[k]
+         loss_dict['loss'] = sum([loss_dict[k] for k in weight_dict.keys()])
+
+         wandb_logs_values = ["loss_ce", "loss_bbox", "loss_giou"]
+
+         if self.training:
+             wandb.log({f'train/{k}': v.detach().cpu().numpy() for k, v in loss_dict.items() if k in wandb_logs_values})
+         else:
+             wandb.log({f'eval/{k}': v.detach().cpu().numpy() for k, v in loss_dict.items() if k in wandb_logs_values})
+
+         if not self.training:
+             # ddim_sample returns (boxes, scores, labels); unpack in that order.
+             pred_boxes, pred_logits, pred_labels = self.ddim_sample(pixel_values, features, images_whwh)
+             return DiffusionDetOutput(
+                 loss=loss_dict['loss'],
+                 loss_dict=loss_dict,
+                 logits=pred_logits,
+                 labels=pred_labels,
+                 pred_boxes=pred_boxes,
+             )
+
+         return DiffusionDetOutput(
+             loss=loss_dict['loss'],
+             loss_dict=loss_dict,
+             logits=output['pred_logits'],
+             pred_boxes=output['pred_boxes']
+         )
+
+     def prepare_diffusion_concat(self, gt_boxes):
+         """
+         Pad (or subsample) the ground-truth boxes to `self.num_proposals` and corrupt them with the
+         forward diffusion process.
+
+         :param gt_boxes: (cx, cy, w, h), normalized
+         :return: (diffused boxes in xyxy format, the noise used, the sampled timestep)
+         """
+         t = torch.randint(0, self.num_timesteps, (1,), device=self.device).long()
+         noise = torch.randn(self.num_proposals, 4, device=self.device)
+
+         num_gt = gt_boxes.shape[0]
+         if not num_gt:  # generate fake gt boxes if empty gt boxes
+             gt_boxes = torch.as_tensor([[0.5, 0.5, 1., 1.]], dtype=torch.float, device=self.device)
+             num_gt = 1
+
+         if num_gt < self.num_proposals:
+             box_placeholder = torch.randn(self.num_proposals - num_gt, 4,
+                                           device=self.device) / 6. + 0.5  # 3sigma = 1/2 --> sigma: 1/6
+             box_placeholder[:, 2:] = torch.clip(box_placeholder[:, 2:], min=1e-4)
+             x_start = torch.cat((gt_boxes, box_placeholder), dim=0)
+         elif num_gt > self.num_proposals:
+             select_mask = [True] * self.num_proposals + [False] * (num_gt - self.num_proposals)
+             random.shuffle(select_mask)
+             x_start = gt_boxes[select_mask]
+         else:
+             x_start = gt_boxes
+
+         x_start = (x_start * 2. - 1.) * self.scale
+
+         # noise sample
+         x = self.q_sample(x_start=x_start, t=t, noise=noise)
+
+         x = torch.clamp(x, min=-1 * self.scale, max=self.scale)
+         x = ((x / self.scale) + 1) / 2.
+
+         diff_boxes = ops.box_convert(x, 'cxcywh', 'xyxy')
+
+         return diff_boxes, noise, t
+
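+     # Example of the padding behaviour above: with num_proposals = 300 and 5 ground-truth boxes,
+     # 295 placeholder boxes drawn from N(0.5, 1/6^2) per coordinate are appended before the signal
+     # is scaled to [-scale, scale] and corrupted by q_sample at a random timestep t.
+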
+     def prepare_targets(self, targets):
+         new_targets = []
+         diffused_boxes = []
+         noises = []
+         ts = []
+         for target in targets:
+             h, w = target.size
+             image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)
+             gt_classes = target.class_labels.to(self.device)
+             gt_boxes = target.boxes.to(self.device)
+             d_boxes, d_noise, d_t = self.prepare_diffusion_concat(gt_boxes)
+             image_size_xyxy_tgt = image_size_xyxy.unsqueeze(0).repeat(len(gt_boxes), 1)
+             gt_boxes = gt_boxes * image_size_xyxy
+             gt_boxes = ops.box_convert(gt_boxes, 'cxcywh', 'xyxy')
+
+             diffused_boxes.append(d_boxes)
+             noises.append(d_noise)
+             ts.append(d_t)
+             new_targets.append({
+                 "labels": gt_classes,
+                 "boxes": target.boxes.to(self.device),
+                 "boxes_xyxy": gt_boxes,
+                 "image_size_xyxy": image_size_xyxy.to(self.device),
+                 "image_size_xyxy_tgt": image_size_xyxy_tgt.to(self.device),
+                 "area": ops.box_area(target.boxes.to(self.device)),
+             })
+
+         return new_targets, torch.stack(diffused_boxes), torch.stack(noises), torch.stack(ts)
+
+     def inference(self, box_cls, box_pred):
+         """
+         Arguments:
+             box_cls (Tensor): tensor of shape (batch_size, num_proposals, K).
+                 The tensor predicts the classification probability for each proposal.
+             box_pred (Tensor): tensor of shape (batch_size, num_proposals, 4).
+                 The tensor predicts absolute (x1, y1, x2, y2) box coordinates for every proposal.
+
+         Returns:
+             Three per-image lists (boxes, scores, labels), or a single (boxes, scores, labels)
+             tuple when per-step results are being ensembled (sampling_timesteps > 1).
+         """
+         boxes_output = []
+         logits_output = []
+         labels_output = []
+
+         if self.use_focal or self.use_fed_loss:
+             scores = torch.sigmoid(box_cls)
+             labels = torch.arange(self.num_classes, device=self.device) \
+                 .unsqueeze(0).repeat(self.num_proposals, 1).flatten(0, 1)
+
+             for i, (scores_per_image, box_pred_per_image) in enumerate(zip(
+                     scores, box_pred
+             )):
+                 scores_per_image, topk_indices = scores_per_image.flatten(0, 1).topk(self.num_proposals, sorted=False)
+                 labels_per_image = labels[topk_indices]
+                 box_pred_per_image = box_pred_per_image.view(-1, 1, 4).repeat(1, self.num_classes, 1).view(-1, 4)
+                 box_pred_per_image = box_pred_per_image[topk_indices]
+
+                 if self.sampling_timesteps > 1:
+                     return box_pred_per_image, scores_per_image, labels_per_image
+
+                 if self.use_nms:
+                     keep = ops.batched_nms(box_pred_per_image, scores_per_image, labels_per_image, 0.5)
+                     box_pred_per_image = box_pred_per_image[keep]
+                     scores_per_image = scores_per_image[keep]
+                     labels_per_image = labels_per_image[keep]
+
+                 boxes_output.append(box_pred_per_image)
+                 logits_output.append(scores_per_image)
+                 labels_output.append(labels_per_image)
+         else:
+             # For each box we assign the best class, or the second best if the best one is `no_object`.
+             scores, labels = F.softmax(box_cls, dim=-1)[:, :, :-1].max(-1)
+
+             for i, (scores_per_image, labels_per_image, box_pred_per_image) in enumerate(zip(
+                     scores, labels, box_pred
+             )):
+                 if self.sampling_timesteps > 1:
+                     return box_pred_per_image, scores_per_image, labels_per_image
+
+                 if self.use_nms:
+                     keep = ops.batched_nms(box_pred_per_image, scores_per_image, labels_per_image, 0.5)
+                     box_pred_per_image = box_pred_per_image[keep]
+                     scores_per_image = scores_per_image[keep]
+                     labels_per_image = labels_per_image[keep]
+
+                 boxes_output.append(box_pred_per_image)
+                 logits_output.append(scores_per_image)
+                 labels_output.append(labels_per_image)
+
+         return boxes_output, logits_output, labels_output
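+
+     # Illustrative end-to-end sketch (assumptions: `DiffusionDetConfig` is constructible with its
+     # defaults, `<repo-id>` is a placeholder, and `targets` follow the structure consumed by
+     # `prepare_targets`, i.e. objects exposing `.size`, `.class_labels` and `.boxes`):
+     #
+     #   processor = AutoImageProcessor.from_pretrained("<repo-id>", trust_remote_code=True)
+     #   model = DiffusionDet(DiffusionDetConfig()).eval()
+     #   outputs = model(pixel_values, labels=targets)   # DiffusionDetOutput
+     #   boxes, scores, classes = outputs.pred_boxes, outputs.logits, outputs.labels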
preprocessor_config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "auto_map": {
+     "AutoImageProcessor": "image_processing_diffusiondet.DiffusionDetImageProcessor"
+   },
+   "do_convert_annotations": true,
+   "do_normalize": true,
+   "do_pad": true,
+   "do_rescale": true,
+   "do_resize": true,
+   "format": "coco_detection",
+   "image_mean": [
+     0.485,
+     0.456,
+     0.406
+   ],
+   "image_processor_type": "DiffusionDetImageProcessor",
+   "image_std": [
+     0.229,
+     0.224,
+     0.225
+   ],
+   "pad_size": null,
+   "resample": 2,
+   "rescale_factor": 0.00392156862745098,
+   "size": {
+     "longest_edge": 1333,
+     "shortest_edge": 800
+   }
+ }