Spaces:
Paused
Paused
| from starvector.metrics.util import AverageMeter | |
| from tqdm import tqdm | |
| import math | |
class BaseMetric:
    """Base class for per-sample metrics averaged over a batch.

    Subclasses implement :meth:`metric`, which scores a single sample.
    :meth:`calculate_score` applies it to every sample in a batch, skips
    samples whose scoring raises or yields NaN, and (optionally) folds the
    batch mean into ``self.meter``, an ``AverageMeter``.
    """

    # Batch keys forwarded to ``metric`` as keyword arguments when present in
    # the batch. Subclasses may override this to consume additional fields.
    input_keys = ("gt_im", "gen_im", "gt_svg", "gen_svg", "caption")

    def __init__(self):
        self.meter = AverageMeter()

    def reset(self):
        """Reset ``self.meter`` (delegates to ``AverageMeter.reset``)."""
        self.meter.reset()

    def calculate_score(self, batch, update=True):
        """Score every sample in ``batch`` and return the aggregate.

        Args:
            batch: Mapping from field name to a per-sample sequence, e.g.
                ``{"gt_im": [PIL Image, ...], "gen_im": [Image, ...]}``. The
                length of the first value is used as the batch size; the
                sequences for the keys in ``input_keys`` are presumably all
                at least that long (not checked here).
            update: If true, pass this batch's mean to ``self.meter.update``
                and return ``self.meter.avg``; otherwise leave the meter
                untouched and return this batch's mean.

        Returns:
            A ``(score, values)`` tuple, where ``values`` is the list of
            per-sample scores that were kept. When no sample produced a valid
            score (including an empty batch) the result is
            ``(float("nan"), [])`` and the meter is not updated.
        """
        values = self._collect_values(batch)
        if not values:
            print("No valid values found for metric calculation.")
            # Return a tuple here as well so callers can always unpack the result.
            return float("nan"), values
        score = sum(values) / len(values)
        if update:
            # The sample count is passed as the second argument; presumably the
            # meter uses it as a weight (see AverageMeter.update) — TODO confirm.
            self.meter.update(score, len(values))
            return self.meter.avg, values
        return score, values

    def _collect_values(self, batch):
        """Return the per-sample scores for ``batch``, skipping errors and NaNs."""
        # Default to an empty sequence so an empty batch dict does not raise
        # StopIteration.
        batch_size = len(next(iter(batch.values()), ()))
        if batch_size == 0:
            return []
        values = []
        for index in tqdm(range(batch_size)):
            kwargs = {key: batch[key][index] for key in self.input_keys if key in batch}
            try:
                measure = self.metric(**kwargs)
            except Exception as e:
                # Best-effort: a failing sample is reported and skipped rather
                # than aborting the whole batch.
                print("Error calculating metric: {}".format(e))
                continue
            if math.isnan(measure):
                continue
            values.append(measure)
        return values

    def metric(self, **kwargs):
        """Score a single sample; must be overridden by subclasses.

        Receives, as keyword arguments, the entries of ``input_keys`` that are
        present in the batch, each indexed to the current sample. The returned
        value must be accepted by ``math.isnan``; a NaN result marks the
        sample as invalid and it is skipped.
        """
        raise NotImplementedError("The metric method must be implemented by subclasses.")

    def get_average_score(self):
        """Return ``self.meter.avg``, the meter's current average."""
        return self.meter.avg