|
import torch |
|
|
|
def standard_loss(self, model, inputs, return_outputs=False):
    """Run ``model`` on ``inputs`` and return the mean of its first output.

    Args:
        self: Object exposing ``args.past_index``; when that index is
            non-negative, the matching model output is stored on
            ``self._past``.
        model: Callable invoked as ``model(**inputs)``; its result is
            indexed positionally.
        inputs: Mapping of keyword arguments forwarded to ``model``.
        return_outputs: When true, also return the raw model outputs.

    Returns:
        ``loss`` alone, or the tuple ``(loss, outputs)`` when
        ``return_outputs`` is true.
    """
    model_out = model(**inputs)

    # Remember the cached state the model returned at the configured index.
    past_idx = self.args.past_index
    if past_idx >= 0:
        self._past = model_out[past_idx]

    # The first output is treated as the loss; ``mean`` reduces it to a scalar.
    loss = model_out[0].mean()
    print("loss : ", loss)

    if return_outputs:
        return (loss, model_out)
    return loss
|
|
|
def hinge_cost(m, a, b):
    """Return the mean squared hinge penalty for pairs closer than ``m``.

    For every row ``i`` this evaluates
    ``max(0, m - ||a[i] - b[i]||_2) ** 2`` and averages the result, so
    pairs whose Euclidean distance is at least ``m`` contribute nothing.

    Args:
        m: Margin subtracted from; a number or anything broadcastable
            against the per-row distances.
        a: Tensor whose dimension 1 is reduced when measuring distance
            (assumed to be ``(batch, dim)`` -- confirm against callers).
        b: Tensor with the same shape as ``a``.

    Returns:
        A scalar tensor holding the mean penalty.
    """
    # Use the built-in norm rather than sqrt(sum(square)): the manual form
    # backpropagates NaN when a row of ``a`` equals the matching row of
    # ``b`` (infinite sqrt derivative at 0 times a zero inner gradient),
    # while vector_norm yields a finite gradient there.
    distance = torch.linalg.vector_norm(a - b, ord=2, dim=1)
    shortfall = torch.clamp(m - distance, min=0)
    return torch.mean(shortfall ** 2)
|
|
|
def sim_metric_loss(self, model, inputs, return_outputs=False):
    """Compute a similarity-metric loss over a batch made of four equal groups.

    The rows returned by ``model(**inputs)`` are split, in order, into four
    consecutive groups of equal size: ``x_p``, ``x_n``, ``y_p`` and ``y_n``.
    The loss is the sum of three terms:

    * ``cost_p``: mean squared Euclidean distance between ``x_p`` and ``y_p``;
    * ``cost_n``: ``negative_importance`` times ``hinge_cost`` of
      ``x_n``/``y_n`` with margin ``negative_threshold``;
    * ``cost_e``: ``entropy_importance`` times the mean of the summed squared
      norms of ``x_p`` and ``y_p``.

    The three hyper-parameters are read from the model (or from
    ``model.module`` when it is wrapped in ``torch.nn.DataParallel``).

    Args:
        self: Unused; present so the function can be bound as a method.
        model: Callable invoked as ``model(**inputs)``. Its result is sliced
            along the first dimension and reduced along dimension 1
            (assumed to be a ``(4 * batch, dim)`` tensor -- confirm).
        inputs: Mapping forwarded to ``model``; ``inputs["input_ids"]`` gives
            the total number of rows.
        return_outputs: When true, return ``(loss, None)`` instead of just
            ``loss``; the model outputs themselves are not returned.

    Returns:
        The scalar loss tensor, or ``(loss, None)`` when ``return_outputs``
        is true.

    Raises:
        ValueError: If the number of rows is zero or not a multiple of four.
    """
    total = len(inputs["input_ids"])
    # Fail early: an uneven batch would leave the last group longer than the
    # others and either broadcast silently or die with an opaque shape error,
    # and an empty batch would produce a NaN loss.
    if total == 0 or total % 4 != 0:
        raise ValueError(
            "sim_metric_loss expects a non-empty batch whose size is a "
            f"multiple of 4 (x_p, x_n, y_p, y_n); got {total} rows"
        )
    batch_size = total // 4

    outputs = model(**inputs)
    x_p = outputs[:batch_size]
    x_n = outputs[batch_size:2 * batch_size]
    y_p = outputs[2 * batch_size:3 * batch_size]
    y_n = outputs[3 * batch_size:]

    # DataParallel hides custom attributes behind ``.module``.
    model_attr = model.module if isinstance(model, torch.nn.DataParallel) else model

    # Pull positive pairs together.
    cost_p = torch.mean(torch.sum((x_p - y_p) ** 2, dim=1))
    # Push negative pairs at least ``negative_threshold`` apart.
    cost_n = model_attr.negative_importance * hinge_cost(
        model_attr.negative_threshold, x_n, y_n)
    # Penalise the squared norms of the positive embeddings.
    cost_e = model_attr.entropy_importance * torch.mean(
        torch.sum(x_p ** 2, dim=1) + torch.sum(y_p ** 2, dim=1))
    loss = cost_p + cost_n + cost_e

    print(loss)
    return (loss, None) if return_outputs else loss
|
|