Spaces:
Configuration error
Configuration error
| import abc | |
| from copy import deepcopy | |
| import cv2 | |
| import numpy as np | |
| from sklearn.decomposition import PCA | |
| from typing_extensions import Protocol | |
class TransformerInterface(Protocol):
    """Structural type for the scikit-learn-style transformers used here.

    Any object that exposes ``fit``, ``transform`` and ``inverse_transform``
    with these signatures satisfies it (``DomainAdapter`` special-cases
    ``sklearn.decomposition.PCA``).
    """

    def fit(self, X: np.ndarray, y=None):
        """Learn the transformation from the rows of ``X``."""

    def transform(self, X: np.ndarray, y=None) -> np.ndarray:
        """Map the rows of ``X`` into the transformer's representation."""

    def inverse_transform(self, X: np.ndarray) -> np.ndarray:
        """Map a representation back into the original feature space."""
class DomainAdapter:
    """Re-express images using a transformer fitted on a reference image.

    The transformer passed in is fitted on the pixels of ``ref_img`` and kept
    as the *target* model. Each call fits a separate copy (the *source* model)
    on the input image's pixels, projects the pixels with it, and maps the
    projection back to pixel space with the target model's
    ``inverse_transform``.
    """

    def __init__(self,
                 transformer: "TransformerInterface",
                 ref_img: np.ndarray,
                 color_conversions=(None, None),
                 ):
        """
        Args:
            transformer: transformer to use. A deep copy becomes the source
                model; the passed-in object itself is fitted on ``ref_img``
                and used as the target model, so the caller's object is
                mutated.
            ref_img: reference image; ``flatten`` reshapes it to ``(-1, 3)``,
                so it needs 3 channels. Presumably uint8 in [0, 255], given
                the /255 scaling (confirm against callers).
            color_conversions: ``(color_in, color_out)`` pair of
                ``cv2.cvtColor`` codes. ``color_in`` is applied before
                flattening and ``color_out`` after reconstruction; ``None``
                skips that conversion.
        """
        self.color_in, self.color_out = color_conversions
        # Copy taken before the target is fitted, so fitting the target does
        # not touch the source model.
        self.source_transformer = deepcopy(transformer)
        self.target_transformer = transformer
        self.target_transformer.fit(self.flatten(ref_img))

    def to_colorspace(self, img):
        """Apply the input colour conversion, or return ``img`` unchanged."""
        if self.color_in is None:
            return img
        return cv2.cvtColor(img, self.color_in)

    def from_colorspace(self, img):
        """Cast ``img`` to uint8 and apply the output colour conversion.

        When no output conversion is configured, ``img`` is returned as-is
        (no cast).
        """
        if self.color_out is None:
            return img
        return cv2.cvtColor(img.astype('uint8'), self.color_out)

    def flatten(self, img):
        """Convert an image into an ``(n_pixels, 3)`` float32 array in [0, 1]."""
        img = self.to_colorspace(img)
        img = img.astype('float32') / 255.
        return img.reshape(-1, 3)

    def reconstruct(self, pixels, h, w):
        """Turn an ``(h*w, 3)`` array of [0, 1] floats back into an image.

        Values are clipped to [0, 1], scaled to [0, 255] and rounded to the
        nearest integer before the uint8 cast. A bare ``astype`` would
        truncate (e.g. 254.9999 -> 254) and bias every channel downward.
        """
        pixels = np.rint(np.clip(pixels, 0, 1) * 255).astype('uint8')
        return self.from_colorspace(pixels.reshape(h, w, 3))

    @staticmethod
    def _pca_sign(x):
        """Return the sign (-1, 0 or 1) of the trace of ``x.components_``.

        Must be a ``staticmethod``: it takes no ``self`` but is invoked as
        ``self._pca_sign(...)``, which otherwise raises ``TypeError``.
        """
        return np.sign(np.trace(x.components_))

    def __call__(self, image: np.ndarray):
        """Project ``image`` with the source model and invert with the target.

        Args:
            image: ``(h, w, C)`` image in the same format as ``ref_img``.
                ``flatten`` requires C == 3.

        Returns:
            The output of ``reconstruct`` for the transformed pixels.
        """
        h, w, _ = image.shape
        pixels = self.flatten(image)
        self.source_transformer.fit(pixels)
        if isinstance(self.target_transformer, PCA):
            # HACK: keep colours from being inverted. Presumably this happens
            # because PCA components are only determined up to sign, so the
            # source and target fits can disagree. When the trace signs
            # differ, flip the target's components in place.
            if self._pca_sign(self.target_transformer) != self._pca_sign(self.source_transformer):
                self.target_transformer.components_ *= -1
        representation = self.source_transformer.transform(pixels)
        result = self.target_transformer.inverse_transform(representation)
        return self.reconstruct(result, h, w)