Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- README.md +69 -40
- config.json +41 -5
- config/diffsketcher-color.yaml +75 -0
- config/diffsketcher-style.yaml +78 -0
- config/diffsketcher-width.yaml +75 -0
- config/diffsketcher.yaml +76 -0
- handler.py +132 -111
- libs/__init__.py +15 -0
- libs/engine/__init__.py +12 -0
- libs/engine/config_processor.py +156 -0
- libs/engine/model_state.py +339 -0
- libs/metric/__init__.py +4 -0
- libs/metric/accuracy.py +31 -0
- libs/metric/clip_score/__init__.py +8 -0
- libs/metric/clip_score/openaiCLIP_loss.py +305 -0
- libs/metric/lpips_origin/__init__.py +3 -0
- libs/metric/lpips_origin/lpips.py +184 -0
- libs/metric/lpips_origin/pretrained_networks.py +196 -0
- libs/metric/lpips_origin/weights/v0.1/alex.pth +3 -0
- libs/metric/lpips_origin/weights/v0.1/squeeze.pth +3 -0
- libs/metric/lpips_origin/weights/v0.1/vgg.pth +3 -0
- libs/metric/piq/__init__.py +7 -0
- libs/metric/piq/functional/__init__.py +15 -0
- libs/metric/piq/functional/base.py +111 -0
- libs/metric/piq/functional/colour_conversion.py +136 -0
- libs/metric/piq/functional/filters.py +111 -0
- libs/metric/piq/functional/layers.py +33 -0
- libs/metric/piq/functional/resize.py +426 -0
- libs/metric/piq/perceptual.py +496 -0
- libs/metric/piq/utils/__init__.py +7 -0
- libs/metric/piq/utils/common.py +158 -0
- libs/metric/pytorch_fid/__init__.py +54 -0
- libs/metric/pytorch_fid/fid_score.py +322 -0
- libs/metric/pytorch_fid/inception.py +341 -0
- libs/modules/__init__.py +4 -0
- libs/modules/edge_map/DoG/XDoG.py +78 -0
- libs/modules/edge_map/DoG/__init__.py +8 -0
- libs/modules/edge_map/__init__.py +4 -0
- libs/modules/edge_map/canny/__init__.py +15 -0
- libs/modules/edge_map/image_grads/__init__.py +8 -0
- libs/modules/edge_map/image_grads/laplacian.py +13 -0
- libs/modules/ema.py +198 -0
- libs/modules/vision/__init__.py +12 -0
- libs/modules/vision/inception.py +482 -0
- libs/modules/vision/vgg.py +194 -0
- libs/modules/visual/__init__.py +4 -0
- libs/modules/visual/imshow.py +177 -0
- libs/modules/visual/video.py +38 -0
- libs/solver/__init__.py +4 -0
- libs/solver/lr_scheduler.py +350 -0
README.md
CHANGED
---
title: DiffSketcher
emoji: 🎨
colorFrom: blue
colorTo: purple
sdk: custom
app_file: handler.py
pinned: false
license: mit
tags:
- svg
- vector-graphics
- text-to-image
- diffusion
- sketch
pipeline_tag: image-generation
library_name: diffvg
---

# DiffSketcher: Text Guided Vector Sketch Synthesis

DiffSketcher is a novel method for generating high-quality vector sketches from text prompts using latent diffusion models. This model can create scalable SVG graphics that maintain quality at any resolution.

## Model Description

DiffSketcher leverages the power of Stable Diffusion to guide the optimization of vector paths, creating artistic sketches that are both semantically meaningful and visually appealing. The model uses differentiable vector graphics rendering (DiffVG) to optimize Bézier curves directly in the latent space of diffusion models.

## Usage

```python
import requests
import json

# API endpoint
url = "https://api-inference.huggingface.co/models/jree423/diffsketcher"

# Headers
headers = {"Authorization": "Bearer YOUR_HF_TOKEN"}

# Payload
payload = {
    "inputs": "a beautiful mountain landscape",
    "parameters": {
        "num_paths": 96,
        "num_iter": 500,
        "token_ind": 4,
        "guidance_scale": 7.5,
        "canvas_size": 224
    }
}

# Make request
response = requests.post(url, headers=headers, json=payload)
result = response.json()

# The result contains the SVG content
svg_content = result[0]["svg"]
```
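
The response can be used as plain SVG markup or decoded from the `svg_base64` field. Below is a minimal sketch of saving and optionally rasterizing the result; it assumes the response fields shown above and that the optional `cairosvg` package is installed (it is not required by the endpoint itself).

```python
import base64
import cairosvg  # optional, only needed for PNG rasterization

# Recover the SVG markup, either directly or from the base64 field
svg_markup = result[0].get("svg") or base64.b64decode(result[0]["svg_base64"]).decode("utf-8")

# Save the vector file
with open("sketch.svg", "w", encoding="utf-8") as f:
    f.write(svg_markup)

# Optionally rasterize at a higher resolution (the SVG itself scales losslessly)
cairosvg.svg2png(bytestring=svg_markup.encode("utf-8"),
                 write_to="sketch.png", output_width=1024, output_height=1024)
```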

## Parameters

- **num_paths** (int, default: 96): Number of paths/strokes in the generated SVG
- **num_iter** (int, default: 500): Number of optimization iterations
- **token_ind** (int, default: 4): Index of cross-attention maps to initialize strokes
- **guidance_scale** (float, default: 7.5): Guidance scale for diffusion
- **canvas_size** (int, default: 224): Canvas size for SVG generation

## Examples

### Simple Sketch
```
Input: "a cat sitting on a chair"
Parameters: {"num_paths": 48, "num_iter": 300}
```

### Detailed Artwork
```
Input: "a majestic eagle soaring through clouds"
Parameters: {"num_paths": 128, "num_iter": 800}
```

### Abstract Art
```
Input: "abstract geometric patterns in blue and gold"
Parameters: {"num_paths": 200, "num_iter": 1000}
```

## Citation

```bibtex
@inproceedings{xing2023diffsketcher,
  title={DiffSketcher: Text Guided Vector Sketch Synthesis through Latent Diffusion Models},
  author={Xing, XiMing and Wang, Chuang and Zhou, Haitao and Zhang, Jing and Yu, Qian and Xu, Dong},
  booktitle={Advances in Neural Information Processing Systems},
  year={2023}
}
```

## License

This model is released under the MIT License.
config.json
CHANGED

```json
{
  "architectures": ["DiffSketcher"],
  "model_type": "diffsketcher",
  "task": "text-to-svg",
  "framework": "pytorch",
  "pipeline_tag": "image-generation",
  "library_name": "diffvg",
  "tags": [
    "svg",
    "vector-graphics",
    "text-to-image",
    "diffusion",
    "sketch"
  ],
  "inference": {
    "parameters": {
      "num_paths": {
        "type": "integer",
        "default": 96,
        "description": "Number of paths/strokes in the generated SVG"
      },
      "num_iter": {
        "type": "integer",
        "default": 500,
        "description": "Number of optimization iterations"
      },
      "token_ind": {
        "type": "integer",
        "default": 4,
        "description": "Index of cross-attention maps to initialize strokes"
      },
      "guidance_scale": {
        "type": "float",
        "default": 7.5,
        "description": "Guidance scale for diffusion"
      },
      "canvas_size": {
        "type": "integer",
        "default": 224,
        "description": "Canvas size for SVG generation"
      }
    }
  }
}
```
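The `inference.parameters` block doubles as a machine-readable table of defaults. A minimal sketch of how a client could merge user-supplied values with these defaults (the helper name and the local `config.json` path are illustrative, not part of the repository API):

```python
import json

def resolve_parameters(user_params: dict, config_path: str = "config.json") -> dict:
    """Merge user-supplied parameters with the defaults declared in config.json."""
    with open(config_path, "r", encoding="utf-8") as f:
        schema = json.load(f)["inference"]["parameters"]
    # Start from the declared defaults, then apply only the known overrides
    resolved = {name: spec["default"] for name, spec in schema.items()}
    resolved.update({k: v for k, v in user_params.items() if k in schema})
    return resolved

# Example: override only the stroke count, keep every other default
print(resolve_parameters({"num_paths": 48}))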
config/diffsketcher-color.yaml
ADDED

```yaml
image_size: 224
path_svg: ~ # if you want to load a svg file and train from it
mask_object: False # if the target image contains background, it's better to mask it out
fix_scale: False # if the target image is not squared, it is recommended to fix the scale

# train
num_iter: 2000
batch_size: 1
num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
lr_scheduler: False
lr_decay_rate: 0.1
decay_steps: [ 1000, 1500 ]
lr: 1 # point lr
color_lr: 0.01
color_vars_threshold: 0.1
width_lr: 0.1 # stroke width lr
max_width: 50 # stroke width

# stroke attrs
num_paths: 128 # number of strokes
width: 1.5 # init stroke width
control_points_per_seg: 4
num_segments: 1
optim_opacity: False # if True, the stroke opacity is optimized
optim_width: True # if True, the stroke width is optimized
optim_rgba: True # if True, the stroke RGBA is optimized
opacity_delta: 0 # stroke pruning

# init strokes
attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes
xdog_intersec: False # initialize along the edge, mix XDoG and attn up
softmax_temp: 0.5 # the temperature of softmax
cross_attn_res: 16 # cross attn resolution
self_attn_res: 32 # self-attn resolution
max_com: 20 # select the number of the self-attn maps
mean_comp: False # the average of the self-attn maps
comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
log_cross_attn: False
u2net_path: "./checkpoint/u2net/u2net.pth"

# ldm
model_id: "sd15" # stable diffusion V1.5
ldm_speed_up: False
enable_xformers: True # speed up attn compute
gradient_checkpoint: False # this slows down the code, but saves GPU VRAM
token_ind: 1 # the index of CLIP prompt embedding, start from 1, 0 is start token
use_ddim: True
num_inference_steps: 100
guidance_scale: 7.5

# ASDS loss
sds:
  crop_size: 512
  augmentations: "affine"
  guidance_scale: 100
  grad_scale: 1e-6
  t_range: [ 0.05, 0.95 ]
  warmup: 3000

# JVSP
clip:
  model_name: "RN101" # RN101, ViT-L/14
  feats_loss_type: "l2" # clip visual loss type, conv layers
  feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based
  # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based
  fc_loss_weight: 0.1 # clip visual loss, fc layer weight
  augmentations: "affine_norm" # augmentation before clip visual computation, affine_norm_trivial
  num_aug: 4 # num of augmentation before clip visual computation
  vis_loss: 1 # 1 or 0 for use or disable clip visual loss
  text_visual_coeff: 0 # cosine similarity between text and img
perceptual:
  name: "lpips" # dists
  lpips_net: 'vgg'
  coeff: 0.2
```
config/diffsketcher-style.yaml
ADDED

```yaml
image_size: 224
path_svg: ~ # if you want to load a svg file and train from it
mask_object: False # if the target image contains background, it's better to mask it out
fix_scale: False # if the target image is not squared, it is recommended to fix the scale

# train
num_iter: 2000
batch_size: 1
num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
lr_scheduler: False
lr_decay_rate: 0.1
decay_steps: [ 1000, 1500 ]
lr: 1 # point lr
color_lr: 0.01
color_vars_threshold: 0.0 # uncomment the code
width_lr: 0.1 # stroke width lr
max_width: 50 # stroke width

# stroke attrs
num_paths: 2500 # number of strokes
width: 1.5 # init stroke width
control_points_per_seg: 4
num_segments: 1
optim_opacity: True # if True, the stroke opacity is optimized
optim_width: True # if True, the stroke width is optimized
optim_rgba: True # if True, the stroke RGBA is optimized
opacity_delta: 0 # stroke pruning

# init strokes
attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes
xdog_intersec: False # initialize along the edge, mix XDoG and attn up
softmax_temp: 0.5 # the temperature of softmax
cross_attn_res: 16 # cross attn resolution
self_attn_res: 32 # self-attn resolution
max_com: 20 # select the number of the self-attn maps
mean_comp: False # the average of the self-attn maps
comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
log_cross_attn: False
u2net_path: "./checkpoint/u2net/u2net.pth"

# ldm
model_id: "sd15" # stable diffusion V1.5
ldm_speed_up: False
enable_xformers: True # speed up attn compute
gradient_checkpoint: False # this slows down the code, but saves GPU VRAM
token_ind: 1 # the index of CLIP prompt embedding, start from 1, 0 is start token
use_ddim: True
num_inference_steps: 100
guidance_scale: 7.5

# ASDS loss
sds:
  crop_size: 512
  augmentations: "affine"
  guidance_scale: 100
  grad_scale: 1e-6
  t_range: [ 0.05, 0.95 ]
  warmup: 3000

# JVSP
clip:
  model_name: "RN101" # RN101, ViT-L/14
  feats_loss_type: "l2" # clip visual loss type, conv layers
  feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based
  # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based
  fc_loss_weight: 0.1 # clip visual loss, fc layer weight
  augmentations: "affine_norm" # augmentation before clip visual computation, affine_norm_trivial
  num_aug: 4 # num of augmentation before clip visual computation
  vis_loss: 1 # 1 or 0 for use or disable clip visual loss
  text_visual_coeff: 0 # cosine similarity between text and img
perceptual:
  name: "lpips" # dists
  lpips_net: 'vgg'
  coeff: 0.2

style_warmup: 1000 # add style loss after `style_warmup` step
style_strength: 1 # How strong the style should be. 100 (max) is a lot. 0 (min) is no style.
```
config/diffsketcher-width.yaml
ADDED

```yaml
image_size: 224
path_svg: ~ # if you want to load a svg file and train from it
mask_object: False # if the target image contains background, it's better to mask it out
fix_scale: False # if the target image is not squared, it is recommended to fix the scale

# train
num_iter: 500
batch_size: 1
num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
lr_scheduler: False
lr_decay_rate: 0.1
decay_steps: [ 1000, 1500 ]
lr: 1 # point lr
color_lr: 0.01
color_vars_threshold: 0.1
width_lr: 0.1 # stroke width lr
max_width: 50 # stroke width

# stroke attrs
num_paths: 128 # number of strokes
width: 3 # init stroke width
control_points_per_seg: 4
num_segments: 1
optim_opacity: True # if True, the stroke opacity is optimized
optim_width: True # if True, the stroke width is optimized
optim_rgba: False # if True, the stroke RGBA is optimized
opacity_delta: 0 # stroke pruning

# init strokes
attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes
xdog_intersec: True # initialize along the edge, mix XDoG and attn up
softmax_temp: 0.5 # the temperature of softmax
cross_attn_res: 16 # cross attn resolution
self_attn_res: 32 # self-attn resolution
max_com: 20 # select the number of the self-attn maps
mean_comp: False # the average of the self-attn maps
comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
log_cross_attn: False
u2net_path: "./checkpoint/u2net/u2net.pth"

# ldm
model_id: "sd15" # stable diffusion V1.5
ldm_speed_up: False
enable_xformers: True # speed up attn compute
gradient_checkpoint: False # this slows down the code, but saves GPU VRAM
token_ind: 1 # the index of CLIP prompt embedding, start from 1, 0 is start token
use_ddim: True
num_inference_steps: 100
guidance_scale: 7.5

# ASDS loss
sds:
  crop_size: 512
  augmentations: "affine"
  guidance_scale: 100
  grad_scale: 1e-5
  t_range: [ 0.05, 0.95 ]
  warmup: 2000

# JVSP
clip:
  model_name: "RN101" # RN101, ViT-L/14
  feats_loss_type: "l2" # clip visual loss type, conv layers
  feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based
  # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based
  fc_loss_weight: 0.1 # clip visual loss, fc layer weight
  augmentations: "affine_norm" # augmentation before clip visual computation, affine_norm_trivial
  num_aug: 4 # num of augmentation before clip visual computation
  vis_loss: 1 # 1 or 0 for use or disable clip visual loss
  text_visual_coeff: 0 # cosine similarity between text and img
perceptual:
  name: "lpips" # dists
  lpips_net: 'vgg'
  coeff: 0.2
```
config/diffsketcher.yaml
ADDED

```yaml
image_size: 224
path_svg: ~ # if you want to load a svg file and train from it
mask_object: False # if the target image contains background, it's better to mask it out
fix_scale: False # if the target image is not squared, it is recommended to fix the scale

# train
num_iter: 2000
batch_size: 1
num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
lr_scheduler: False
lr_decay_rate: 0.1
decay_steps: [ 1000, 1500 ]
lr: 1 # point lr
color_lr: 0.01
pruning_freq: 50
color_vars_threshold: 0.1
width_lr: 0.1
max_width: 50 # stroke width

# stroke attrs
num_paths: 128 # number of strokes
width: 1.5 # stroke width
control_points_per_seg: 4
num_segments: 1
optim_opacity: True # if True, the stroke opacity is optimized
optim_width: False # if True, the stroke width is optimized
optim_rgba: False # if True, the stroke RGBA is optimized
opacity_delta: 0 # stroke pruning

# init strokes
attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes
xdog_intersec: True # initialize along the edge, mix XDoG and attn up
softmax_temp: 0.5 # the temperature of softmax
cross_attn_res: 16 # cross attn resolution
self_attn_res: 32 # self-attn resolution
max_com: 20 # select the number of the self-attn maps
mean_comp: False # the average of the self-attn maps
comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
log_cross_attn: False # True if cross attn every step
u2net_path: "./checkpoint/u2net/u2net.pth"

# ldm
model_id: "sd15" # stable diffusion V1.5
ldm_speed_up: False
enable_xformers: True # speed up attn compute
gradient_checkpoint: False # this slows down the code, but saves GPU VRAM
token_ind: 1 # the index of CLIP prompt embedding, start from 1, 0 is start token
use_ddim: True
num_inference_steps: 100
guidance_scale: 7.5 # sdxl default 5.0

# ASDS loss
sds:
  crop_size: 512
  augmentations: "affine"
  guidance_scale: 100
  grad_scale: 1e-6
  t_range: [ 0.05, 0.95 ]
  warmup: 2000

# JVSP
clip:
  model_name: "RN101" # RN101, ViT-L/14
  feats_loss_type: "l2" # clip visual loss type, conv layers
  feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based
  # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based
  fc_loss_weight: 0.1 # clip visual loss, fc layer weight
  augmentations: "affine" # augmentation before clip visual computation
  num_aug: 4 # num of augmentation before clip visual computation
  vis_loss: 1 # 1 or 0 for use or disable clip visual loss
  text_visual_coeff: 0 # cosine similarity between text and img
perceptual:
  name: "lpips" # dists, lpips
  lpips_net: 'vgg'
  coeff: 0.2
```
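These YAML files are consumed through OmegaConf, so they can also be loaded and overridden programmatically. A minimal sketch, assuming the file lives at `config/diffsketcher.yaml` relative to the working directory:

```python
from omegaconf import OmegaConf

# Load the base configuration shipped with the repository
cfg = OmegaConf.load("config/diffsketcher.yaml")

# Override a few fields from a dotlist, e.g. for a quicker, sparser sketch
overrides = OmegaConf.from_dotlist(["num_paths=48", "num_iter=300", "xdog_intersec=False"])
cfg = OmegaConf.merge(cfg, overrides)

print(cfg.num_paths, cfg.sds.guidance_scale, cfg.clip.model_name)
```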
handler.py
CHANGED

```python
import os
import sys
import torch
import base64
import io
from PIL import Image
import tempfile
import shutil
from typing import Dict, Any, List
import json

# Add current directory to path for imports
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

try:
    import pydiffvg
    from diffusers import StableDiffusionPipeline
    from omegaconf import OmegaConf
    DEPENDENCIES_AVAILABLE = True
except ImportError as e:
    print(f"Warning: Some dependencies not available: {e}")
    DEPENDENCIES_AVAILABLE = False


class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the handler for DiffSketcher model.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if not DEPENDENCIES_AVAILABLE:
            print("Warning: Dependencies not available, handler will return mock responses")
            return

        # Create a minimal config
        self.cfg = OmegaConf.create({
            'method': 'diffsketcher',
            'num_paths': 96,
            'num_iter': 500,
            'token_ind': 4,
            'guidance_scale': 7.5,
            'diffuser': {
                'model_id': 'stabilityai/stable-diffusion-2-1-base',
                'download': True
            },
            'painter': {
                'canvas_size': 224,
                'lr_scheduler': True,
                'lr': 0.01
            }
        })

        # Initialize the diffusion pipeline
        try:
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.cfg.diffuser.model_id,
                torch_dtype=torch.float32,
                safety_checker=None,
                requires_safety_checker=False
            ).to(self.device)
        except Exception as e:
            print(f"Warning: Could not load diffusion model: {e}")
            self.pipe = None

        # Set up pydiffvg
        try:
            pydiffvg.set_print_timing(False)
            pydiffvg.set_device(self.device)
        except Exception as e:
            print(f"Warning: Could not initialize pydiffvg: {e}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the input data and return the generated SVG.

        Args:
            data: Dictionary containing:
                - inputs: Text prompt for SVG generation
                - parameters: Optional parameters like num_paths, num_iter, etc.

        Returns:
            List containing the generated SVG as base64 encoded string
        """
        try:
            # Extract inputs
            prompt = data.get("inputs", "")
            if not prompt:
                return [{"error": "No prompt provided"}]

            # If dependencies aren't available, return a mock response
            if not DEPENDENCIES_AVAILABLE:
                mock_svg = f'''<svg width="224" height="224" xmlns="http://www.w3.org/2000/svg">
    <rect width="224" height="224" fill="white"/>
    <text x="112" y="112" text-anchor="middle" font-family="Arial" font-size="12" fill="black">
        Mock SVG for: {prompt}
    </text>
</svg>'''
                return [{
                    "svg": mock_svg,
                    "svg_base64": base64.b64encode(mock_svg.encode()).decode(),
                    "prompt": prompt,
                    "status": "mock_response",
                    "message": "This is a mock response. Full model not available."
                }]

            # Extract parameters
            parameters = data.get("parameters", {})
            num_paths = parameters.get("num_paths", self.cfg.num_paths)
            num_iter = parameters.get("num_iter", self.cfg.num_iter)
            token_ind = parameters.get("token_ind", self.cfg.token_ind)
            guidance_scale = parameters.get("guidance_scale", self.cfg.guidance_scale)
            canvas_size = parameters.get("canvas_size", self.cfg.painter.canvas_size)

            # For now, return a simple SVG since the full implementation requires
            # the complete DiffSketcher pipeline which is complex to set up
            simple_svg = f'''<svg width="{canvas_size}" height="{canvas_size}" xmlns="http://www.w3.org/2000/svg">
    <rect width="{canvas_size}" height="{canvas_size}" fill="white"/>
    <circle cx="{canvas_size//2}" cy="{canvas_size//2}" r="{canvas_size//4}"
            fill="none" stroke="black" stroke-width="2"/>
    <text x="{canvas_size//2}" y="{canvas_size//2}" text-anchor="middle"
          font-family="Arial" font-size="14" fill="black">
        {prompt[:20]}...
    </text>
</svg>'''

            return [{
                "svg": simple_svg,
                "svg_base64": base64.b64encode(simple_svg.encode()).decode(),
                "prompt": prompt,
                "parameters": {
                    "num_paths": num_paths,
                    "num_iter": num_iter,
                    "token_ind": token_ind,
                    "guidance_scale": guidance_scale,
                    "canvas_size": canvas_size
                },
                "status": "simplified_response",
                "message": "Simplified SVG generated. Full DiffSketcher pipeline requires additional setup."
            }]

        except Exception as e:
            return [{"error": f"Error during SVG generation: {str(e)}"}]


# For testing
if __name__ == "__main__":
    handler = EndpointHandler()
    test_data = {
        "inputs": "a beautiful mountain landscape",
        "parameters": {
            "num_paths": 48,
            "num_iter": 100
        }
    }
    result = handler(test_data)
    print(result)
```
libs/__init__.py
ADDED

```python
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: a self consistent system,
# including runner, trainer, loss function, EMA, optimizer, lr scheduler, and common utils.

from .utils import lazy

__getattr__, __dir__, __all__ = lazy.attach(
    __name__,
    submodules={'engine', 'metric', 'modules', 'solver', 'utils'},
    submod_attrs={}
)

__version__ = '0.0.1'
```
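Assuming `libs.utils.lazy` follows the usual `lazy_loader`-style `attach` convention, the snippet above defers submodule imports until first attribute access rather than loading them at package import time. A minimal usage sketch:

```python
import libs  # importing the package does not yet import its submodules

# The engine submodule (and its heavier dependencies, e.g. accelerate)
# is only loaded when first accessed:
ModelState = libs.engine.ModelState
print(libs.__version__)
```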
libs/engine/__init__.py
ADDED

```python
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:

from .model_state import ModelState
from .config_processor import merge_and_update_config

__all__ = [
    'ModelState',
    'merge_and_update_config'
]
```
libs/engine/config_processor.py
ADDED

```python
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:

import os
from typing import Tuple
from functools import reduce

from argparse import Namespace
from omegaconf import DictConfig, OmegaConf


#################################################################################
#                            merge yaml and argparse                            #
#################################################################################

def register_resolver():
    OmegaConf.register_new_resolver(
        "add", lambda *numbers: sum(numbers)
    )
    OmegaConf.register_new_resolver(
        "multiply", lambda *numbers: reduce(lambda x, y: x * y, numbers)
    )
    OmegaConf.register_new_resolver(
        "sub", lambda n1, n2: n1 - n2
    )


def _merge_args_and_config(
        cmd_args: Namespace,
        yaml_config: DictConfig,
        read_only: bool = False
) -> Tuple[DictConfig, DictConfig, DictConfig]:
    # convert cmd line args to OmegaConf
    cmd_args_dict = vars(cmd_args)
    cmd_args_list = []
    for k, v in cmd_args_dict.items():
        cmd_args_list.append(f"{k}={v}")
    cmd_args_conf = OmegaConf.from_cli(cmd_args_list)

    # The following overrides the previous configuration
    # cmd_args_list > configs
    args_ = OmegaConf.merge(yaml_config, cmd_args_conf)

    if read_only:
        OmegaConf.set_readonly(args_, True)

    return args_, cmd_args_conf, yaml_config


def merge_configs(args, method_cfg_path):
    """merge command line args (argparse) and config file (OmegaConf)"""
    yaml_config_path = os.path.join("./", "config", method_cfg_path)
    try:
        yaml_config = OmegaConf.load(yaml_config_path)
    except FileNotFoundError as e:
        print(f"error: {e}")
        print(f"input file path: `{method_cfg_path}`")
        print(f"config path: `{yaml_config_path}` not found.")
        raise FileNotFoundError(e)
    return _merge_args_and_config(args, yaml_config, read_only=False)


def update_configs(source_args, update_nodes, strict=True, remove_update_nodes=True):
    """update config file (OmegaConf) with dotlist"""
    if update_nodes is None:
        return source_args

    update_args_list = str(update_nodes).split()
    if len(update_args_list) < 1:
        return source_args

    # check update_args
    for item in update_args_list:
        item_key_ = str(item).split('=')[0]  # get key
        # item_val_ = str(item).split('=')[1]  # get value

        if strict:
            # Tests if a key is existing
            # assert OmegaConf.select(source_args, item_key_) is not None, f"{item_key_} is not existing."

            # Tests if a value is missing
            assert not OmegaConf.is_missing(source_args, item_key_), f"the value of {item_key_} is missing."

        # if keys is None, then add key and set the value
        if OmegaConf.select(source_args, item_key_) is None:
            source_args.item_key_ = item_key_

    # update original yaml params
    update_nodes = OmegaConf.from_dotlist(update_args_list)
    merged_args = OmegaConf.merge(source_args, update_nodes)

    # remove update_args
    if remove_update_nodes:
        OmegaConf.update(merged_args, 'update', '')
    return merged_args


def update_if_exist(source_args, update_nodes):
    """update config file (OmegaConf) with dotlist"""
    if update_nodes is None:
        return source_args

    upd_args_list = str(update_nodes).split()
    if len(upd_args_list) < 1:
        return source_args

    update_args_list = []
    for item in upd_args_list:
        item_key_ = str(item).split('=')[0]  # get key

        # if a key is existing
        # if OmegaConf.select(source_args, item_key_) is not None:
        #     update_args_list.append(item)

        update_args_list.append(item)

    # update source_args if key be selected
    if len(update_args_list) < 1:
        merged_args = source_args
    else:
        update_nodes = OmegaConf.from_dotlist(update_args_list)
        merged_args = OmegaConf.merge(source_args, update_nodes)

    return merged_args


def merge_and_update_config(args):
    register_resolver()

    # if yaml_config is existing, then merge command line args and yaml_config
    # if os.path.isfile(args.config) and args.config is not None:
    if args.config is not None and str(args.config).endswith('.yaml'):
        merged_args, cmd_args, yaml_config = merge_configs(args, args.config)
    else:
        merged_args, cmd_args, yaml_config = args, args, None

    # update the yaml_config with the cmd '-update' flag
    update_nodes = args.update
    final_args = update_configs(merged_args, update_nodes)

    # to simplify log output, we empty this
    yaml_config_update = update_if_exist(yaml_config, update_nodes)
    cmd_args_update = update_if_exist(cmd_args, update_nodes)
    cmd_args_update.update = ""  # clear update params

    final_args.yaml_config = yaml_config_update
    final_args.cmd_args = cmd_args_update

    # update seed
    if final_args.seed < 0:
        import random
        final_args.seed = random.randint(0, 65535)

    return final_args
```
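A minimal sketch of how `merge_and_update_config` is typically driven from an argparse front end. The flag names mirror the attributes the function reads (`config`, `update`, `seed`); the exact CLI of the original project is not shown in this diff, and the example assumes `./config/diffsketcher.yaml` exists and `libs` is importable:

```python
import argparse
from libs.engine import merge_and_update_config

parser = argparse.ArgumentParser()
parser.add_argument("--config", default="diffsketcher.yaml")  # resolved under ./config/
parser.add_argument("--update", default="num_paths=48 num_iter=300")  # space-separated dotlist overrides
parser.add_argument("--seed", type=int, default=-1)  # negative -> a random seed is drawn
args = parser.parse_args()

cfg = merge_and_update_config(args)  # returns an OmegaConf DictConfig
print(cfg.num_paths, cfg.num_iter, cfg.seed)
```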
libs/engine/model_state.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) XiMing Xing. All rights reserved.
|
| 3 |
+
# Author: XiMing Xing
|
| 4 |
+
# Description:
|
| 5 |
+
import gc
|
| 6 |
+
from functools import partial
|
| 7 |
+
from typing import Union, List
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from datetime import datetime, timedelta
|
| 10 |
+
|
| 11 |
+
from omegaconf import DictConfig
|
| 12 |
+
from pprint import pprint
|
| 13 |
+
import torch
|
| 14 |
+
from accelerate.utils import LoggerType
|
| 15 |
+
from accelerate import (
|
| 16 |
+
Accelerator,
|
| 17 |
+
GradScalerKwargs,
|
| 18 |
+
DistributedDataParallelKwargs,
|
| 19 |
+
InitProcessGroupKwargs
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
from ..modules.ema import EMA
|
| 23 |
+
from ..utils.logging import get_logger
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ModelState:
|
| 27 |
+
"""
|
| 28 |
+
Handling logger and `hugging face` accelerate training
|
| 29 |
+
|
| 30 |
+
features:
|
| 31 |
+
- Mixed Precision
|
| 32 |
+
- Gradient Scaler
|
| 33 |
+
- Gradient Accumulation
|
| 34 |
+
- Optimizer
|
| 35 |
+
- EMA
|
| 36 |
+
- Logger (default: python print)
|
| 37 |
+
- Monitor (default: wandb, tensorboard)
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
args,
|
| 43 |
+
log_path_suffix: str = None,
|
| 44 |
+
ignore_log=False, # whether to create log file or not
|
| 45 |
+
) -> None:
|
| 46 |
+
self.args: DictConfig = args
|
| 47 |
+
|
| 48 |
+
"""check valid"""
|
| 49 |
+
mixed_precision = self.args.get("mixed_precision")
|
| 50 |
+
# Bug: omegaconf convert 'no' to false
|
| 51 |
+
mixed_precision = "no" if type(mixed_precision) == bool else mixed_precision
|
| 52 |
+
split_batches = self.args.get("split_batches", False)
|
| 53 |
+
gradient_accumulate_step = self.args.get("gradient_accumulate_step", 1)
|
| 54 |
+
assert gradient_accumulate_step >= 1, f"except gradient_accumulate_step >= 1, get {gradient_accumulate_step}"
|
| 55 |
+
|
| 56 |
+
"""create working space"""
|
| 57 |
+
# rule: ['./config'. 'method_name', 'exp_name.yaml']
|
| 58 |
+
# -> results_path: ./runs/{method_name}-{exp_name}, as a base folder
|
| 59 |
+
# config_prefix, config_name = str(self.args.get("config")).split('/')
|
| 60 |
+
# config_name_only = str(config_name).split(".")[0]
|
| 61 |
+
|
| 62 |
+
config_name_only = str(self.args.get("config")).split(".")[0]
|
| 63 |
+
results_folder = self.args.get("results_path", None)
|
| 64 |
+
if results_folder is None:
|
| 65 |
+
# self.results_path = Path("./workdir") / f"{config_prefix}-{config_name_only}"
|
| 66 |
+
self.results_path = Path("./workdir") / f"{config_name_only}"
|
| 67 |
+
else:
|
| 68 |
+
# self.results_path = Path(results_folder) / f"{config_prefix}-{config_name_only}"
|
| 69 |
+
self.results_path = Path(results_folder) / f"{config_name_only}"
|
| 70 |
+
|
| 71 |
+
# update results_path: ./runs/{method_name}-{exp_name}/{log_path_suffix}
|
| 72 |
+
# noting: can be understood as "results dir / methods / ablation study / your result"
|
| 73 |
+
if log_path_suffix is not None:
|
| 74 |
+
self.results_path = self.results_path / log_path_suffix
|
| 75 |
+
|
| 76 |
+
kwargs_handlers = []
|
| 77 |
+
"""mixed precision training"""
|
| 78 |
+
if args.mixed_precision == "no":
|
| 79 |
+
scaler_handler = GradScalerKwargs(
|
| 80 |
+
init_scale=args.init_scale,
|
| 81 |
+
growth_factor=args.growth_factor,
|
| 82 |
+
backoff_factor=args.backoff_factor,
|
| 83 |
+
growth_interval=args.growth_interval,
|
| 84 |
+
enabled=True
|
| 85 |
+
)
|
| 86 |
+
kwargs_handlers.append(scaler_handler)
|
| 87 |
+
|
| 88 |
+
"""distributed training"""
|
| 89 |
+
ddp_handler = DistributedDataParallelKwargs(
|
| 90 |
+
dim=0,
|
| 91 |
+
broadcast_buffers=True,
|
| 92 |
+
static_graph=False,
|
| 93 |
+
bucket_cap_mb=25,
|
| 94 |
+
find_unused_parameters=False,
|
| 95 |
+
check_reduction=False,
|
| 96 |
+
gradient_as_bucket_view=False
|
| 97 |
+
)
|
| 98 |
+
kwargs_handlers.append(ddp_handler)
|
| 99 |
+
|
| 100 |
+
init_handler = InitProcessGroupKwargs(timeout=timedelta(seconds=1200))
|
| 101 |
+
kwargs_handlers.append(init_handler)
|
| 102 |
+
|
| 103 |
+
"""init visualized tracker"""
|
| 104 |
+
log_with = []
|
| 105 |
+
self.args.visual = False
|
| 106 |
+
if args.use_wandb:
|
| 107 |
+
log_with.append(LoggerType.WANDB)
|
| 108 |
+
if args.tensorboard:
|
| 109 |
+
log_with.append(LoggerType.TENSORBOARD)
|
| 110 |
+
|
| 111 |
+
"""hugging face Accelerator"""
|
| 112 |
+
self.accelerator = Accelerator(
|
| 113 |
+
device_placement=True,
|
| 114 |
+
split_batches=split_batches,
|
| 115 |
+
mixed_precision=mixed_precision,
|
| 116 |
+
gradient_accumulation_steps=args.gradient_accumulate_step,
|
| 117 |
+
cpu=True if args.use_cpu else False,
|
| 118 |
+
log_with=None if len(log_with) == 0 else log_with,
|
| 119 |
+
project_dir=self.results_path / "vis",
|
| 120 |
+
kwargs_handlers=kwargs_handlers,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
"""logs"""
|
| 124 |
+
if self.accelerator.is_local_main_process:
|
| 125 |
+
# for logging results in a folder periodically
|
| 126 |
+
self.results_path.mkdir(parents=True, exist_ok=True)
|
| 127 |
+
if not ignore_log:
|
| 128 |
+
now_time = datetime.now().strftime('%Y-%m-%d-%H-%M')
|
| 129 |
+
self.logger = get_logger(
|
| 130 |
+
logs_dir=self.results_path.as_posix(),
|
| 131 |
+
file_name=f"{now_time}-log-{args.seed}.txt"
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
print("==> command line args: ")
|
| 135 |
+
print(args.cmd_args)
|
| 136 |
+
print("==> yaml config args: ")
|
| 137 |
+
print(args.yaml_config)
|
| 138 |
+
|
| 139 |
+
print("\n***** Model State *****")
|
| 140 |
+
if self.accelerator.distributed_type != "NO":
|
| 141 |
+
print(f"-> Distributed Type: {self.accelerator.distributed_type}")
|
| 142 |
+
# print(f"-> Split Batch Size: {split_batches}, Total Batch Size: {self.actual_batch_size}")
|
| 143 |
+
print(f"-> Mixed Precision: {mixed_precision}, AMP: {self.accelerator.native_amp},"
|
| 144 |
+
f" Gradient Accumulate Step: {gradient_accumulate_step}")
|
| 145 |
+
print(f"-> Weight dtype: {self.weight_dtype}")
|
| 146 |
+
|
| 147 |
+
if self.accelerator.scaler_handler is not None and self.accelerator.scaler_handler.enabled:
|
| 148 |
+
print(f"-> Enabled GradScaler: {self.accelerator.scaler_handler.to_kwargs()}")
|
| 149 |
+
|
| 150 |
+
if args.use_wandb:
|
| 151 |
+
print(f"-> Init trackers: 'wandb' ")
|
| 152 |
+
self.args.visual = True
|
| 153 |
+
self.__init_tracker(project_name="my_project", tags=None, entity="")
|
| 154 |
+
|
| 155 |
+
print(f"-> Working Space: '{self.results_path}'")
|
| 156 |
+
|
| 157 |
+
"""EMA"""
|
| 158 |
+
self.use_ema = args.get('ema', False)
|
| 159 |
+
self.ema_wrapper = self.__build_ema_wrapper()
|
| 160 |
+
|
| 161 |
+
"""glob step"""
|
| 162 |
+
self.step = 0
|
| 163 |
+
|
| 164 |
+
"""log process"""
|
| 165 |
+
self.accelerator.wait_for_everyone()
|
| 166 |
+
print(f'Process {self.accelerator.process_index} using device: {self.accelerator.device}')
|
| 167 |
+
|
| 168 |
+
self.print("-> state initialization complete \n")
|
| 169 |
+
|
| 170 |
+
def __init_tracker(self, project_name, tags, entity):
|
| 171 |
+
self.accelerator.init_trackers(
|
| 172 |
+
project_name=project_name,
|
| 173 |
+
config=dict(self.args),
|
| 174 |
+
init_kwargs={
|
| 175 |
+
"wandb": {
|
| 176 |
+
"notes": "accelerate trainer pipeline",
|
| 177 |
+
"tags": [
|
| 178 |
+
f"total batch_size: {self.actual_batch_size}"
|
| 179 |
+
],
|
| 180 |
+
"entity": entity,
|
| 181 |
+
}}
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
def __build_ema_wrapper(self):
|
| 185 |
+
if self.use_ema:
|
| 186 |
+
self.print(f"-> EMA: {self.use_ema}, decay: {self.args.ema_decay}, "
|
| 187 |
+
f"update_after_step: {self.args.ema_update_after_step}, "
|
| 188 |
+
f"update_every: {self.args.ema_update_every}")
|
| 189 |
+
ema_wrapper = partial(
|
| 190 |
+
                EMA, beta=self.args.ema_decay,
                update_after_step=self.args.ema_update_after_step,
                update_every=self.args.ema_update_every
            )
        else:
            ema_wrapper = None

        return ema_wrapper

    @property
    def device(self):
        return self.accelerator.device

    @property
    def weight_dtype(self):
        weight_dtype = torch.float32
        if self.accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16
        return weight_dtype

    @property
    def actual_batch_size(self):
        if self.accelerator.split_batches is False:
            actual_batch_size = self.args.batch_size * self.accelerator.num_processes * self.accelerator.gradient_accumulation_steps
        else:
            # assert on args.batch_size (not the property itself) to avoid infinite recursion
            assert self.args.batch_size % self.accelerator.num_processes == 0
            actual_batch_size = self.args.batch_size
        return actual_batch_size

    @property
    def n_gpus(self):
        return self.accelerator.num_processes

    @property
    def no_decay_params_names(self):
        no_decay = [
            "bn", "LayerNorm", "GroupNorm",
        ]
        return no_decay

    def no_decay_params(self, model, weight_decay):
        """optimization tricks"""
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in self.no_decay_params_names)
                ],
                "weight_decay": weight_decay,
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in self.no_decay_params_names)
                ],
                "weight_decay": 0.0,
            },
        ]
        return optimizer_grouped_parameters

    def optimized_params(self, model: torch.nn.Module, verbose=True) -> List:
        """return parameters if `requires_grad` is True

        Args:
            model: pytorch models
            verbose: log optimized parameters

        Examples:
            >>> self.params_optimized = self.optimized_params(uvit, verbose=True)
            >>> optimizer = torch.optim.AdamW(self.params_optimized, lr=args.lr)

        Returns:
            a list of parameters
        """
        params_optimized = []
        for key, value in model.named_parameters():
            if value.requires_grad:
                params_optimized.append(value)
                if verbose:
                    self.print("\t {}, {}, {}".format(key, value.numel(), value.shape))
        return params_optimized

    def save_everything(self, fpath: str):
        """Saving and loading the model, optimizer, RNG generators, and the GradScaler."""
        if not self.accelerator.is_main_process:
            return
        self.accelerator.save_state(fpath)

    def load_save_everything(self, fpath: str):
        """Loading the model, optimizer, RNG generators, and the GradScaler."""
        self.accelerator.load_state(fpath)

    def save(self, milestone: Union[str, float, int], checkpoint: object) -> None:
        if not self.accelerator.is_main_process:
            return

        torch.save(checkpoint, self.results_path / f'model-{milestone}.pt')

    def save_in(self, root: Union[str, Path], checkpoint: object) -> None:
        if not self.accelerator.is_main_process:
            return

        torch.save(checkpoint, root)

    def load_ckpt_model_only(self, model: torch.nn.Module, path: Union[str, Path], rm_module_prefix: bool = False):
        ckpt = torch.load(path, map_location=self.accelerator.device)

        unwrapped_model = self.accelerator.unwrap_model(model)
        if rm_module_prefix:
            unwrapped_model.load_state_dict({k.replace('module.', ''): v for k, v in ckpt.items()})
        else:
            unwrapped_model.load_state_dict(ckpt)
        return unwrapped_model

    def load_shared_weights(self, model: torch.nn.Module, path: Union[str, Path]):
        ckpt = torch.load(path, map_location=self.accelerator.device)
        self.print(f"pretrained_dict len: {len(ckpt)}")
        unwrapped_model = self.accelerator.unwrap_model(model)
        model_dict = unwrapped_model.state_dict()
        pretrained_dict = {k: v for k, v in ckpt.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        unwrapped_model.load_state_dict(model_dict, strict=False)
        self.print(f"selected pretrained_dict: {len(model_dict)}")
        return unwrapped_model

    def print(self, *args, **kwargs):
        """Use in replacement of `print()` to only print once per server."""
        self.accelerator.print(*args, **kwargs)

    def pretty_print(self, msg):
        if self.accelerator.is_local_main_process:
            pprint(dict(msg))

    def close_tracker(self):
        self.accelerator.end_training()

    def free_memory(self):
        self.accelerator.clear()

    def close(self, msg: str = "Training complete."):
        """Use in end of training."""
        self.free_memory()

        if torch.cuda.is_available():
            self.print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')
        if self.args.visual:
            self.close_tracker()
        self.print(msg)
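For orientation only (not part of the uploaded files), here is a minimal sketch of how the parameter-grouping helpers above are typically wired into an optimizer. The state wrapper defined in `model_state.py` is assumed here to be called `ModelState`; the toy model and hyper-parameter values are placeholders.

```python
# Hypothetical usage sketch; `ModelState` construction and hyper-parameters are assumptions.
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.LayerNorm(32),
    torch.nn.Linear(32, 8),
)

# With the wrapper (assumed interface):
# state = ModelState(args)
# grouped = state.no_decay_params(model, weight_decay=1e-2)
# optimizer = torch.optim.AdamW(grouped, lr=1e-4)

# The equivalent grouping without the wrapper:
no_decay = ["bn", "LayerNorm", "GroupNorm"]
grouped = [
    {"params": [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)], "weight_decay": 1e-2},
    {"params": [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(grouped, lr=1e-4)  # norm layers are excluded from weight decay
```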
libs/metric/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
libs/metric/accuracy.py
ADDED
|
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:


def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified values of k.

    Args
        output: logits or probs (num of batch, num of classes)
        target: (num of batch, 1) or (num of batch, )
        topk: list of returned k

    refer: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """
    maxK = max(topk)  # get k in top-k
    batch_size = target.size(0)

    _, pred = output.topk(k=maxK, dim=1, largest=True, sorted=True)  # pred: [num of batch, k]
    pred = pred.t()  # pred: [k, num of batch]

    # [1, num of batch] -> [k, num_of_batch] : bool
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res  # np.shape(res): [k, 1]
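A quick sanity check for `accuracy`, assuming the `libs.metric` package is importable from the repository root (dummy logits and labels):

```python
# Minimal usage sketch for accuracy(); shapes and values are illustrative only.
import torch
from libs.metric.accuracy import accuracy

logits = torch.randn(8, 10)           # (batch, num_classes)
target = torch.randint(0, 10, (8,))   # (batch,)
top1, top5 = accuracy(logits, target, topk=(1, 5))
print(top1.item(), top5.item())       # percentages in [0, 100]
```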
libs/metric/clip_score/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:

from .openaiCLIP_loss import CLIPScoreWrapper

__all__ = ['CLIPScoreWrapper']
libs/metric/clip_score/openaiCLIP_loss.py
ADDED
|
@@ -0,0 +1,305 @@
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) XiMing Xing. All rights reserved.
|
| 3 |
+
# Author: XiMing Xing
|
| 4 |
+
# Description:
|
| 5 |
+
|
| 6 |
+
from typing import Union, List, Tuple
|
| 7 |
+
from collections import OrderedDict
|
| 8 |
+
from functools import partial
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torchvision.transforms as transforms
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class CLIPScoreWrapper(nn.Module):
|
| 17 |
+
|
| 18 |
+
def __init__(self,
|
| 19 |
+
clip_model_name: str,
|
| 20 |
+
download_root: str = None,
|
| 21 |
+
device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
|
| 22 |
+
jit: bool = False,
|
| 23 |
+
# additional params
|
| 24 |
+
visual_score: bool = False,
|
| 25 |
+
feats_loss_type: str = None,
|
| 26 |
+
feats_loss_weights: List[float] = None,
|
| 27 |
+
fc_loss_weight: float = None,
|
| 28 |
+
context_length: int = 77):
|
| 29 |
+
super().__init__()
|
| 30 |
+
|
| 31 |
+
import clip # local import
|
| 32 |
+
|
| 33 |
+
# check model info
|
| 34 |
+
self.clip_model_name = clip_model_name
|
| 35 |
+
self.device = device
|
| 36 |
+
self.available_models = clip.available_models()
|
| 37 |
+
assert clip_model_name in self.available_models, f"A model backbone: {clip_model_name} that does not exist"
|
| 38 |
+
|
| 39 |
+
# load CLIP
|
| 40 |
+
self.model, self.preprocess = clip.load(clip_model_name, device, jit=jit, download_root=download_root)
|
| 41 |
+
self.model.eval()
|
| 42 |
+
|
| 43 |
+
# load tokenize
|
| 44 |
+
self.tokenize_fn = partial(clip.tokenize, context_length=context_length)
|
| 45 |
+
|
| 46 |
+
# load CLIP visual
|
| 47 |
+
self.visual_encoder = VisualEncoderWrapper(self.model, clip_model_name).to(device)
|
| 48 |
+
self.visual_encoder.eval()
|
| 49 |
+
|
| 50 |
+
# check loss weights
|
| 51 |
+
self.visual_score = visual_score
|
| 52 |
+
if visual_score:
|
| 53 |
+
assert feats_loss_type in ["l1", "l2", "cosine"], f"{feats_loss_type} is not exist."
|
| 54 |
+
if clip_model_name.startswith("ViT"): assert len(feats_loss_weights) == 12
|
| 55 |
+
if clip_model_name.startswith("RN"): assert len(feats_loss_weights) == 5
|
| 56 |
+
|
| 57 |
+
# load visual loss wrapper
|
| 58 |
+
self.visual_loss_fn = CLIPVisualLossWrapper(self.visual_encoder, feats_loss_type,
|
| 59 |
+
feats_loss_weights,
|
| 60 |
+
fc_loss_weight)
|
| 61 |
+
|
| 62 |
+
@property
|
| 63 |
+
def input_resolution(self):
|
| 64 |
+
return self.model.visual.input_resolution # default: 224
|
| 65 |
+
|
| 66 |
+
@property
|
| 67 |
+
def resize(self): # Resize only
|
| 68 |
+
return transforms.Compose([self.preprocess.transforms[0]])
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def normalize(self):
|
| 72 |
+
return transforms.Compose([
|
| 73 |
+
self.preprocess.transforms[0], # Resize
|
| 74 |
+
self.preprocess.transforms[1], # CenterCrop
|
| 75 |
+
self.preprocess.transforms[-1], # Normalize
|
| 76 |
+
])
|
| 77 |
+
|
| 78 |
+
@property
|
| 79 |
+
def norm_(self): # Normalize only
|
| 80 |
+
return transforms.Compose([self.preprocess.transforms[-1]])
|
| 81 |
+
|
| 82 |
+
def encode_image_layer_wise(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
| 83 |
+
semantic_vec, feature_maps = self.visual_encoder(x)
|
| 84 |
+
return semantic_vec, feature_maps
|
| 85 |
+
|
| 86 |
+
def encode_text(self, text: Union[str, List[str]], norm: bool = True) -> torch.Tensor:
|
| 87 |
+
tokens = self.tokenize_fn(text).to(self.device)
|
| 88 |
+
text_features = self.model.encode_text(tokens)
|
| 89 |
+
if norm:
|
| 90 |
+
text_features = text_features.mean(axis=0, keepdim=True)
|
| 91 |
+
text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
|
| 92 |
+
return text_features_norm
|
| 93 |
+
return text_features
|
| 94 |
+
|
| 95 |
+
def encode_image(self, image: torch.Tensor, norm: bool = True) -> torch.Tensor:
|
| 96 |
+
image_features = self.model.encode_image(image)
|
| 97 |
+
if norm:
|
| 98 |
+
image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
|
| 99 |
+
return image_features_norm
|
| 100 |
+
return image_features
|
| 101 |
+
|
| 102 |
+
@torch.no_grad()
|
| 103 |
+
def predict(self,
|
| 104 |
+
image: torch.Tensor,
|
| 105 |
+
text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]:
|
| 106 |
+
image_features = self.model.encode_image(image)
|
| 107 |
+
text_tokenize = self.tokenize_fn(text).to(self.device)
|
| 108 |
+
text_features = self.model.encode_text(text_tokenize)
|
| 109 |
+
logits_per_image, logits_per_text = self.model(image, text)
|
| 110 |
+
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
|
| 111 |
+
return image_features, text_features, probs
|
| 112 |
+
|
| 113 |
+
def compute_text_visual_distance(
|
| 114 |
+
self, image: torch.Tensor, text: Union[str, List[str]]
|
| 115 |
+
) -> torch.Tensor:
|
| 116 |
+
image_features = self.model.encode_image(image)
|
| 117 |
+
text_tokenize = self.tokenize_fn(text).to(self.device)
|
| 118 |
+
text_features = self.model.encode_text(text_tokenize)
|
| 119 |
+
text_features = text_features.to(self.device)
|
| 120 |
+
|
| 121 |
+
image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
|
| 122 |
+
text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
|
| 123 |
+
loss = - (image_features_norm @ text_features_norm.T)
|
| 124 |
+
return loss.mean()
|
| 125 |
+
|
| 126 |
+
def directional_loss(self, src_text, src_img, tar_text, tar_img, thresh=None):
|
| 127 |
+
# delta img
|
| 128 |
+
img_direction = (tar_img - src_img)
|
| 129 |
+
img_direction_norm = img_direction / img_direction.norm(dim=-1, keepdim=True)
|
| 130 |
+
# # delta text
|
| 131 |
+
text_direction = (1 * tar_text - src_text).repeat(tar_img.size(0), 1)
|
| 132 |
+
text_direction_norm = text_direction / text_direction.norm(dim=-1, keepdim=True)
|
| 133 |
+
# Directional CLIP Loss
|
| 134 |
+
loss_dir = (1 - torch.cosine_similarity(img_direction_norm, text_direction_norm, dim=1))
|
| 135 |
+
if thresh is not None:
|
| 136 |
+
loss_dir[loss_dir < thresh] = 0 # set value=0 when lt 0
|
| 137 |
+
loss_dir = loss_dir.mean()
|
| 138 |
+
return loss_dir
|
| 139 |
+
else:
|
| 140 |
+
return loss_dir.mean()
|
| 141 |
+
|
| 142 |
+
def compute_visual_distance(
|
| 143 |
+
self, x: torch.Tensor, y: torch.Tensor, clip_norm: bool = True,
|
| 144 |
+
) -> Tuple[torch.Tensor, List]:
|
| 145 |
+
# return a fc loss and the list of feat loss
|
| 146 |
+
assert self.visual_score is True
|
| 147 |
+
assert x.shape == y.shape
|
| 148 |
+
assert x.shape[-1] == self.input_resolution and x.shape[-2] == self.input_resolution
|
| 149 |
+
assert y.shape[-1] == self.input_resolution and y.shape[-2] == self.input_resolution
|
| 150 |
+
|
| 151 |
+
if clip_norm:
|
| 152 |
+
return self.visual_loss_fn(self.normalize(x), self.normalize(y))
|
| 153 |
+
else:
|
| 154 |
+
return self.visual_loss_fn(x, y)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class VisualEncoderWrapper(nn.Module):
|
| 158 |
+
"""
|
| 159 |
+
semantic features and layer by layer feature maps are obtained from CLIP visual encoder.
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
def __init__(self, clip_model: nn.Module, clip_model_name: str):
|
| 163 |
+
super().__init__()
|
| 164 |
+
self.clip_model = clip_model
|
| 165 |
+
self.clip_model_name = clip_model_name
|
| 166 |
+
|
| 167 |
+
if clip_model_name.startswith("ViT"):
|
| 168 |
+
self.feature_maps = OrderedDict()
|
| 169 |
+
for i in range(12): # 12 ResBlocks in ViT visual transformer
|
| 170 |
+
self.clip_model.visual.transformer.resblocks[i].register_forward_hook(
|
| 171 |
+
self.make_hook(i)
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
if clip_model_name.startswith("RN"):
|
| 175 |
+
layers = list(self.clip_model.visual.children())
|
| 176 |
+
init_layers = torch.nn.Sequential(*layers)[:8]
|
| 177 |
+
self.layer1 = layers[8]
|
| 178 |
+
self.layer2 = layers[9]
|
| 179 |
+
self.layer3 = layers[10]
|
| 180 |
+
self.layer4 = layers[11]
|
| 181 |
+
self.att_pool2d = layers[12]
|
| 182 |
+
|
| 183 |
+
def make_hook(self, name):
|
| 184 |
+
def hook(module, input, output):
|
| 185 |
+
if len(output.shape) == 3:
|
| 186 |
+
# LND -> NLD (B, 77, 768)
|
| 187 |
+
self.feature_maps[name] = output.permute(1, 0, 2)
|
| 188 |
+
else:
|
| 189 |
+
self.feature_maps[name] = output
|
| 190 |
+
|
| 191 |
+
return hook
|
| 192 |
+
|
| 193 |
+
def _forward_vit(self, x: torch.Tensor) -> Tuple[torch.Tensor, List]:
|
| 194 |
+
fc_feature = self.clip_model.encode_image(x).float()
|
| 195 |
+
feature_maps = [self.feature_maps[k] for k in range(12)]
|
| 196 |
+
|
| 197 |
+
# fc_feature len: 1 ,feature_maps len: 12
|
| 198 |
+
return fc_feature, feature_maps
|
| 199 |
+
|
| 200 |
+
def _forward_resnet(self, x: torch.Tensor) -> Tuple[torch.Tensor, List]:
|
| 201 |
+
def stem(m, x):
|
| 202 |
+
for conv, bn, relu in [(m.conv1, m.bn1, m.relu1), (m.conv2, m.bn2, m.relu2), (m.conv3, m.bn3, m.relu3)]:
|
| 203 |
+
x = torch.relu(bn(conv(x)))
|
| 204 |
+
x = m.avgpool(x)
|
| 205 |
+
return x
|
| 206 |
+
|
| 207 |
+
x = x.type(self.clip_model.visual.conv1.weight.dtype)
|
| 208 |
+
x = stem(self.clip_model.visual, x)
|
| 209 |
+
x1 = self.layer1(x)
|
| 210 |
+
x2 = self.layer2(x1)
|
| 211 |
+
x3 = self.layer3(x2)
|
| 212 |
+
x4 = self.layer4(x3)
|
| 213 |
+
y = self.att_pool2d(x4)
|
| 214 |
+
|
| 215 |
+
# fc_features len: 1 ,feature_maps len: 5
|
| 216 |
+
return y, [x, x1, x2, x3, x4]
|
| 217 |
+
|
| 218 |
+
def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
| 219 |
+
if self.clip_model_name.startswith("ViT"):
|
| 220 |
+
fc_feat, visual_feat_maps = self._forward_vit(x)
|
| 221 |
+
if self.clip_model_name.startswith("RN"):
|
| 222 |
+
fc_feat, visual_feat_maps = self._forward_resnet(x)
|
| 223 |
+
|
| 224 |
+
return fc_feat, visual_feat_maps
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class CLIPVisualLossWrapper(nn.Module):
|
| 228 |
+
"""
|
| 229 |
+
Visual Feature Loss + FC loss
|
| 230 |
+
"""
|
| 231 |
+
|
| 232 |
+
def __init__(
|
| 233 |
+
self,
|
| 234 |
+
visual_encoder: nn.Module,
|
| 235 |
+
feats_loss_type: str = None,
|
| 236 |
+
feats_loss_weights: List[float] = None,
|
| 237 |
+
fc_loss_weight: float = None,
|
| 238 |
+
):
|
| 239 |
+
super().__init__()
|
| 240 |
+
self.visual_encoder = visual_encoder
|
| 241 |
+
self.feats_loss_weights = feats_loss_weights
|
| 242 |
+
self.fc_loss_weight = fc_loss_weight
|
| 243 |
+
|
| 244 |
+
self.layer_criterion = layer_wise_distance(feats_loss_type)
|
| 245 |
+
|
| 246 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
| 247 |
+
x_fc_feature, x_feat_maps = self.visual_encoder(x)
|
| 248 |
+
y_fc_feature, y_feat_maps = self.visual_encoder(y)
|
| 249 |
+
|
| 250 |
+
# visual feature loss
|
| 251 |
+
if sum(self.feats_loss_weights) == 0:
|
| 252 |
+
feats_loss_list = [torch.tensor(0, device=x.device)]
|
| 253 |
+
else:
|
| 254 |
+
feats_loss = self.layer_criterion(x_feat_maps, y_feat_maps, self.visual_encoder.clip_model_name)
|
| 255 |
+
feats_loss_list = []
|
| 256 |
+
for layer, w in enumerate(self.feats_loss_weights):
|
| 257 |
+
if w:
|
| 258 |
+
feats_loss_list.append(feats_loss[layer] * w)
|
| 259 |
+
|
| 260 |
+
# visual fc loss, default: cosine similarity
|
| 261 |
+
if self.fc_loss_weight == 0:
|
| 262 |
+
fc_loss = torch.tensor(0, device=x.device)
|
| 263 |
+
else:
|
| 264 |
+
fc_loss = (1 - torch.cosine_similarity(x_fc_feature, y_fc_feature, dim=1)).mean()
|
| 265 |
+
fc_loss = fc_loss * self.fc_loss_weight
|
| 266 |
+
|
| 267 |
+
return fc_loss, feats_loss_list
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
#################################################################################
|
| 271 |
+
# layer wise metric #
|
| 272 |
+
#################################################################################
|
| 273 |
+
|
| 274 |
+
def layer_wise_distance(metric_name: str):
|
| 275 |
+
return {
|
| 276 |
+
"l1": l1_layer_wise,
|
| 277 |
+
"l2": l2_layer_wise,
|
| 278 |
+
"cosine": cosine_layer_wise
|
| 279 |
+
}.get(metric_name.lower())
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def l2_layer_wise(x_features, y_features, clip_model_name):
|
| 283 |
+
return [
|
| 284 |
+
torch.square(x_conv - y_conv).mean()
|
| 285 |
+
for x_conv, y_conv in zip(x_features, y_features)
|
| 286 |
+
]
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def l1_layer_wise(x_features, y_features, clip_model_name):
|
| 290 |
+
return [
|
| 291 |
+
torch.abs(x_conv - y_conv).mean()
|
| 292 |
+
for x_conv, y_conv in zip(x_features, y_features)
|
| 293 |
+
]
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def cosine_layer_wise(x_features, y_features, clip_model_name):
|
| 297 |
+
if clip_model_name.startswith("RN"):
|
| 298 |
+
return [
|
| 299 |
+
(1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean()
|
| 300 |
+
for x_conv, y_conv in zip(x_features, y_features)
|
| 301 |
+
]
|
| 302 |
+
return [
|
| 303 |
+
(1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean()
|
| 304 |
+
for x_conv, y_conv in zip(x_features, y_features)
|
| 305 |
+
]
|
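As a hedged usage sketch (not part of the upload), `CLIPScoreWrapper` can score a rendered image against a prompt; this assumes OpenAI's `clip` package is installed and uses a placeholder prompt and a random image tensor in place of a real render:

```python
# Sketch under assumptions: `pip install git+https://github.com/openai/CLIP.git`,
# and an image tensor in [0, 1] standing in for the rendered sketch.
import torch
from libs.metric.clip_score import CLIPScoreWrapper

clip_score = CLIPScoreWrapper(clip_model_name="ViT-B/32")
image = torch.rand(1, 3, 224, 224, device=clip_score.device)  # placeholder render
loss = clip_score.compute_text_visual_distance(
    clip_score.norm_(image), "a beautiful mountain landscape"
)
print(loss.item())  # lower (more negative) means better image-text agreement
```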
libs/metric/lpips_origin/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
from .lpips import LPIPS

__all__ = ['LPIPS']
libs/metric/lpips_origin/lpips.py
ADDED
|
@@ -0,0 +1,184 @@
| 1 |
+
from __future__ import absolute_import
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from . import pretrained_networks as pretrained_torch_models
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def spatial_average(x, keepdim=True):
|
| 12 |
+
return x.mean([2, 3], keepdim=keepdim)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def upsample(x):
|
| 16 |
+
return nn.Upsample(size=x.shape[2:], mode='bilinear', align_corners=False)(x)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def normalize_tensor(in_feat, eps=1e-10):
|
| 20 |
+
norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True))
|
| 21 |
+
return in_feat / (norm_factor + eps)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Learned perceptual metric
|
| 25 |
+
class LPIPS(nn.Module):
|
| 26 |
+
|
| 27 |
+
def __init__(self,
|
| 28 |
+
pretrained=True,
|
| 29 |
+
net='alex',
|
| 30 |
+
version='0.1',
|
| 31 |
+
lpips=True,
|
| 32 |
+
spatial=False,
|
| 33 |
+
pnet_rand=False,
|
| 34 |
+
pnet_tune=False,
|
| 35 |
+
use_dropout=True,
|
| 36 |
+
model_path=None,
|
| 37 |
+
eval_mode=True,
|
| 38 |
+
verbose=True):
|
| 39 |
+
""" Initializes a perceptual loss torch.nn.Module
|
| 40 |
+
|
| 41 |
+
Parameters (default listed first)
|
| 42 |
+
---------------------------------
|
| 43 |
+
lpips : bool
|
| 44 |
+
[True] use linear layers on top of base/trunk network
|
| 45 |
+
[False] means no linear layers; each layer is averaged together
|
| 46 |
+
pretrained : bool
|
| 47 |
+
This flag controls the linear layers, which are only in effect when lpips=True above
|
| 48 |
+
[True] means linear layers are calibrated with human perceptual judgments
|
| 49 |
+
[False] means linear layers are randomly initialized
|
| 50 |
+
pnet_rand : bool
|
| 51 |
+
[False] means trunk loaded with ImageNet classification weights
|
| 52 |
+
[True] means randomly initialized trunk
|
| 53 |
+
net : str
|
| 54 |
+
['alex','vgg','squeeze'] are the base/trunk networks available
|
| 55 |
+
version : str
|
| 56 |
+
['v0.1'] is the default and latest
|
| 57 |
+
['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1)
|
| 58 |
+
model_path : 'str'
|
| 59 |
+
[None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1
|
| 60 |
+
|
| 61 |
+
The following parameters should only be changed if training the network:
|
| 62 |
+
|
| 63 |
+
eval_mode : bool
|
| 64 |
+
[True] is for test mode (default)
|
| 65 |
+
[False] is for training mode
|
| 66 |
+
pnet_tune
|
| 67 |
+
[False] keep base/trunk frozen
|
| 68 |
+
[True] tune the base/trunk network
|
| 69 |
+
use_dropout : bool
|
| 70 |
+
[True] to use dropout when training linear layers
|
| 71 |
+
[False] for no dropout when training linear layers
|
| 72 |
+
"""
|
| 73 |
+
super(LPIPS, self).__init__()
|
| 74 |
+
if verbose:
|
| 75 |
+
print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' %
|
| 76 |
+
('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))
|
| 77 |
+
|
| 78 |
+
self.pnet_type = net
|
| 79 |
+
self.pnet_tune = pnet_tune
|
| 80 |
+
self.pnet_rand = pnet_rand
|
| 81 |
+
self.spatial = spatial
|
| 82 |
+
self.lpips = lpips # false means baseline of just averaging all layers
|
| 83 |
+
self.version = version
|
| 84 |
+
self.scaling_layer = ScalingLayer()
|
| 85 |
+
|
| 86 |
+
if self.pnet_type in ['vgg', 'vgg16']:
|
| 87 |
+
net_type = pretrained_torch_models.vgg16
|
| 88 |
+
self.chns = [64, 128, 256, 512, 512]
|
| 89 |
+
elif self.pnet_type == 'alex':
|
| 90 |
+
net_type = pretrained_torch_models.alexnet
|
| 91 |
+
self.chns = [64, 192, 384, 256, 256]
|
| 92 |
+
elif self.pnet_type == 'squeeze':
|
| 93 |
+
net_type = pretrained_torch_models.squeezenet
|
| 94 |
+
self.chns = [64, 128, 256, 384, 384, 512, 512]
|
| 95 |
+
self.L = len(self.chns)
|
| 96 |
+
|
| 97 |
+
self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
|
| 98 |
+
|
| 99 |
+
if lpips:
|
| 100 |
+
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
| 101 |
+
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
| 102 |
+
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
| 103 |
+
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
| 104 |
+
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
| 105 |
+
self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
| 106 |
+
if self.pnet_type == 'squeeze': # 7 layers for squeezenet
|
| 107 |
+
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
|
| 108 |
+
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
|
| 109 |
+
self.lins += [self.lin5, self.lin6]
|
| 110 |
+
self.lins = nn.ModuleList(self.lins)
|
| 111 |
+
|
| 112 |
+
if pretrained:
|
| 113 |
+
if model_path is None:
|
| 114 |
+
model_path = os.path.join(
|
| 115 |
+
os.path.dirname(os.path.abspath(__file__)),
|
| 116 |
+
f"weights/v{version}/{net}.pth"
|
| 117 |
+
)
|
| 118 |
+
if verbose:
|
| 119 |
+
print('Loading model from: %s' % model_path)
|
| 120 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)
|
| 121 |
+
|
| 122 |
+
if eval_mode:
|
| 123 |
+
self.eval()
|
| 124 |
+
|
| 125 |
+
def forward(self, in0, in1, return_per_layer=False, normalize=False):
|
| 126 |
+
if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, 1]
|
| 127 |
+
in0 = 2 * in0 - 1
|
| 128 |
+
in1 = 2 * in1 - 1
|
| 129 |
+
|
| 130 |
+
# Noting: v0.0 - original release had a bug, where input was not scaled
|
| 131 |
+
if self.version == '0.1':
|
| 132 |
+
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1))
|
| 133 |
+
else:
|
| 134 |
+
in0_input, in1_input = in0, in1
|
| 135 |
+
|
| 136 |
+
# model forward
|
| 137 |
+
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
|
| 138 |
+
|
| 139 |
+
feats0, feats1, diffs = {}, {}, {}
|
| 140 |
+
for kk in range(self.L):
|
| 141 |
+
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
|
| 142 |
+
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
| 143 |
+
|
| 144 |
+
if self.lpips:
|
| 145 |
+
if self.spatial:
|
| 146 |
+
res = [upsample(self.lins[kk](diffs[kk])) for kk in range(self.L)]
|
| 147 |
+
else:
|
| 148 |
+
res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
|
| 149 |
+
else:
|
| 150 |
+
if self.spatial:
|
| 151 |
+
res = [upsample(diffs[kk].sum(dim=1, keepdim=True)) for kk in range(self.L)]
|
| 152 |
+
else:
|
| 153 |
+
res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]
|
| 154 |
+
|
| 155 |
+
loss = sum(res)
|
| 156 |
+
|
| 157 |
+
if return_per_layer:
|
| 158 |
+
return loss, res
|
| 159 |
+
else:
|
| 160 |
+
return loss
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class ScalingLayer(nn.Module):
|
| 164 |
+
def __init__(self):
|
| 165 |
+
super(ScalingLayer, self).__init__()
|
| 166 |
+
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
|
| 167 |
+
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
|
| 168 |
+
|
| 169 |
+
def forward(self, inp):
|
| 170 |
+
return (inp - self.shift) / self.scale
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class NetLinLayer(nn.Module):
|
| 174 |
+
"""A single linear layer which does a 1x1 conv"""
|
| 175 |
+
|
| 176 |
+
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
| 177 |
+
super(NetLinLayer, self).__init__()
|
| 178 |
+
|
| 179 |
+
layers = [nn.Dropout(), ] if (use_dropout) else []
|
| 180 |
+
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
|
| 181 |
+
self.model = nn.Sequential(*layers)
|
| 182 |
+
|
| 183 |
+
def forward(self, x):
|
| 184 |
+
return self.model(x)
|
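A minimal sketch of the LPIPS metric above, assuming RGB inputs in [0, 1] (hence `normalize=True` to map them to [-1, 1]); the random tensors stand in for real image pairs:

```python
# Usage sketch; the bundled weights under weights/v0.1/ are loaded automatically.
import torch
from libs.metric.lpips_origin import LPIPS

lpips_fn = LPIPS(net='vgg', verbose=False)
x = torch.rand(1, 3, 256, 256)
y = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    d = lpips_fn(x, y, normalize=True)
print(d.shape)  # (1, 1, 1, 1): one perceptual distance per image pair
```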
libs/metric/lpips_origin/pretrained_networks.py
ADDED
|
@@ -0,0 +1,196 @@
| 1 |
+
from collections import namedtuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torchvision.models as tv_models
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class squeezenet(torch.nn.Module):
|
| 8 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
| 9 |
+
super(squeezenet, self).__init__()
|
| 10 |
+
pretrained_features = tv_models.squeezenet1_1(weights=pretrained).features
|
| 11 |
+
self.slice1 = torch.nn.Sequential()
|
| 12 |
+
self.slice2 = torch.nn.Sequential()
|
| 13 |
+
self.slice3 = torch.nn.Sequential()
|
| 14 |
+
self.slice4 = torch.nn.Sequential()
|
| 15 |
+
self.slice5 = torch.nn.Sequential()
|
| 16 |
+
self.slice6 = torch.nn.Sequential()
|
| 17 |
+
self.slice7 = torch.nn.Sequential()
|
| 18 |
+
self.N_slices = 7
|
| 19 |
+
for x in range(2):
|
| 20 |
+
self.slice1.add_module(str(x), pretrained_features[x])
|
| 21 |
+
for x in range(2, 5):
|
| 22 |
+
self.slice2.add_module(str(x), pretrained_features[x])
|
| 23 |
+
for x in range(5, 8):
|
| 24 |
+
self.slice3.add_module(str(x), pretrained_features[x])
|
| 25 |
+
for x in range(8, 10):
|
| 26 |
+
self.slice4.add_module(str(x), pretrained_features[x])
|
| 27 |
+
for x in range(10, 11):
|
| 28 |
+
self.slice5.add_module(str(x), pretrained_features[x])
|
| 29 |
+
for x in range(11, 12):
|
| 30 |
+
self.slice6.add_module(str(x), pretrained_features[x])
|
| 31 |
+
for x in range(12, 13):
|
| 32 |
+
self.slice7.add_module(str(x), pretrained_features[x])
|
| 33 |
+
if not requires_grad:
|
| 34 |
+
for param in self.parameters():
|
| 35 |
+
param.requires_grad = False
|
| 36 |
+
|
| 37 |
+
def forward(self, X):
|
| 38 |
+
h = self.slice1(X)
|
| 39 |
+
h_relu1 = h
|
| 40 |
+
h = self.slice2(h)
|
| 41 |
+
h_relu2 = h
|
| 42 |
+
h = self.slice3(h)
|
| 43 |
+
h_relu3 = h
|
| 44 |
+
h = self.slice4(h)
|
| 45 |
+
h_relu4 = h
|
| 46 |
+
h = self.slice5(h)
|
| 47 |
+
h_relu5 = h
|
| 48 |
+
h = self.slice6(h)
|
| 49 |
+
h_relu6 = h
|
| 50 |
+
h = self.slice7(h)
|
| 51 |
+
h_relu7 = h
|
| 52 |
+
vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7'])
|
| 53 |
+
out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
|
| 54 |
+
|
| 55 |
+
return out
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class alexnet(torch.nn.Module):
|
| 59 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
| 60 |
+
super(alexnet, self).__init__()
|
| 61 |
+
weights = tv_models.AlexNet_Weights.IMAGENET1K_V1 if pretrained else None
|
| 62 |
+
alexnet_pretrained_features = tv_models.alexnet(weights=weights).features
|
| 63 |
+
self.slice1 = torch.nn.Sequential()
|
| 64 |
+
self.slice2 = torch.nn.Sequential()
|
| 65 |
+
self.slice3 = torch.nn.Sequential()
|
| 66 |
+
self.slice4 = torch.nn.Sequential()
|
| 67 |
+
self.slice5 = torch.nn.Sequential()
|
| 68 |
+
self.N_slices = 5
|
| 69 |
+
for x in range(2):
|
| 70 |
+
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
|
| 71 |
+
for x in range(2, 5):
|
| 72 |
+
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
|
| 73 |
+
for x in range(5, 8):
|
| 74 |
+
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
|
| 75 |
+
for x in range(8, 10):
|
| 76 |
+
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
|
| 77 |
+
for x in range(10, 12):
|
| 78 |
+
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
|
| 79 |
+
|
| 80 |
+
if not requires_grad:
|
| 81 |
+
for param in self.parameters():
|
| 82 |
+
param.requires_grad = False
|
| 83 |
+
|
| 84 |
+
def forward(self, X):
|
| 85 |
+
h = self.slice1(X)
|
| 86 |
+
h_relu1 = h
|
| 87 |
+
h = self.slice2(h)
|
| 88 |
+
h_relu2 = h
|
| 89 |
+
h = self.slice3(h)
|
| 90 |
+
h_relu3 = h
|
| 91 |
+
h = self.slice4(h)
|
| 92 |
+
h_relu4 = h
|
| 93 |
+
h = self.slice5(h)
|
| 94 |
+
h_relu5 = h
|
| 95 |
+
alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
|
| 96 |
+
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
|
| 97 |
+
|
| 98 |
+
return out
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class vgg16(torch.nn.Module):
|
| 102 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
| 103 |
+
super(vgg16, self).__init__()
|
| 104 |
+
weights = tv_models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None
|
| 105 |
+
vgg_pretrained_features = tv_models.vgg16(weights=weights).features
|
| 106 |
+
self.slice1 = torch.nn.Sequential()
|
| 107 |
+
self.slice2 = torch.nn.Sequential()
|
| 108 |
+
self.slice3 = torch.nn.Sequential()
|
| 109 |
+
self.slice4 = torch.nn.Sequential()
|
| 110 |
+
self.slice5 = torch.nn.Sequential()
|
| 111 |
+
self.N_slices = 5
|
| 112 |
+
for x in range(4):
|
| 113 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
| 114 |
+
for x in range(4, 9):
|
| 115 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
| 116 |
+
for x in range(9, 16):
|
| 117 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
| 118 |
+
for x in range(16, 23):
|
| 119 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
| 120 |
+
for x in range(23, 30):
|
| 121 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
| 122 |
+
|
| 123 |
+
if not requires_grad:
|
| 124 |
+
for param in self.parameters():
|
| 125 |
+
param.requires_grad = False
|
| 126 |
+
|
| 127 |
+
def forward(self, X):
|
| 128 |
+
h = self.slice1(X)
|
| 129 |
+
h_relu1_2 = h
|
| 130 |
+
h = self.slice2(h)
|
| 131 |
+
h_relu2_2 = h
|
| 132 |
+
h = self.slice3(h)
|
| 133 |
+
h_relu3_3 = h
|
| 134 |
+
h = self.slice4(h)
|
| 135 |
+
h_relu4_3 = h
|
| 136 |
+
h = self.slice5(h)
|
| 137 |
+
h_relu5_3 = h
|
| 138 |
+
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
|
| 139 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
| 140 |
+
|
| 141 |
+
return out
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class resnet(torch.nn.Module):
|
| 145 |
+
def __init__(self, requires_grad=False, pretrained=True, num=18):
|
| 146 |
+
super(resnet, self).__init__()
|
| 147 |
+
|
| 148 |
+
if num == 18:
|
| 149 |
+
weights = tv_models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
|
| 150 |
+
self.net = tv_models.resnet18(weights=weights)
|
| 151 |
+
elif num == 34:
|
| 152 |
+
weights = tv_models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
|
| 153 |
+
self.net = tv_models.resnet34(weights=weights)
|
| 154 |
+
elif num == 50:
|
| 155 |
+
weights = tv_models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
|
| 156 |
+
self.net = tv_models.resnet50(weights=weights)
|
| 157 |
+
elif num == 101:
|
| 158 |
+
weights = tv_models.ResNet101_Weights.IMAGENET1K_V2 if pretrained else None
|
| 159 |
+
self.net = tv_models.resnet101(weights=weights)
|
| 160 |
+
elif num == 152:
|
| 161 |
+
weights = tv_models.ResNet152_Weights.IMAGENET1K_V2 if pretrained else None
|
| 162 |
+
self.net = tv_models.resnet152(weights=weights)
|
| 163 |
+
self.N_slices = 5
|
| 164 |
+
|
| 165 |
+
if not requires_grad:
|
| 166 |
+
for param in self.net.parameters():
|
| 167 |
+
param.requires_grad = False
|
| 168 |
+
|
| 169 |
+
self.conv1 = self.net.conv1
|
| 170 |
+
self.bn1 = self.net.bn1
|
| 171 |
+
self.relu = self.net.relu
|
| 172 |
+
self.maxpool = self.net.maxpool
|
| 173 |
+
self.layer1 = self.net.layer1
|
| 174 |
+
self.layer2 = self.net.layer2
|
| 175 |
+
self.layer3 = self.net.layer3
|
| 176 |
+
self.layer4 = self.net.layer4
|
| 177 |
+
|
| 178 |
+
def forward(self, X):
|
| 179 |
+
h = self.conv1(X)
|
| 180 |
+
h = self.bn1(h)
|
| 181 |
+
h = self.relu(h)
|
| 182 |
+
h_relu1 = h
|
| 183 |
+
h = self.maxpool(h)
|
| 184 |
+
h = self.layer1(h)
|
| 185 |
+
h_conv2 = h
|
| 186 |
+
h = self.layer2(h)
|
| 187 |
+
h_conv3 = h
|
| 188 |
+
h = self.layer3(h)
|
| 189 |
+
h_conv4 = h
|
| 190 |
+
h = self.layer4(h)
|
| 191 |
+
h_conv5 = h
|
| 192 |
+
|
| 193 |
+
outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])
|
| 194 |
+
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
|
| 195 |
+
|
| 196 |
+
return out
|
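For reference, a short sketch of the `vgg16` trunk wrapper above, which returns the five intermediate activations as a namedtuple (pretrained torchvision weights are downloaded on first use):

```python
# Sketch: extracting per-slice features with the vgg16 wrapper.
import torch
from libs.metric.lpips_origin.pretrained_networks import vgg16

backbone = vgg16(pretrained=True, requires_grad=False)
feats = backbone(torch.rand(1, 3, 224, 224))
print([f.shape for f in feats])  # relu1_2 ... relu5_3 feature maps
```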
libs/metric/lpips_origin/weights/v0.1/alex.pth
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0
size 6009
libs/metric/lpips_origin/weights/v0.1/squeeze.pth
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76
size 10811
libs/metric/lpips_origin/weights/v0.1/vgg.pth
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
size 7289
libs/metric/piq/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:

# install: pip install piq
# repo: https://github.com/photosynthesis-team/piq
libs/metric/piq/functional/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
from .base import ifftshift, get_meshgrid, similarity_map, gradient_map, pow_for_complex, crop_patches
from .colour_conversion import rgb2lmn, rgb2xyz, xyz2lab, rgb2lab, rgb2yiq, rgb2lhm
from .filters import haar_filter, hann_filter, scharr_filter, prewitt_filter, gaussian_filter
from .filters import binomial_filter1d, average_filter2d
from .layers import L2Pool2d
from .resize import imresize

__all__ = [
    'ifftshift', 'get_meshgrid', 'similarity_map', 'gradient_map', 'pow_for_complex', 'crop_patches',
    'rgb2lmn', 'rgb2xyz', 'xyz2lab', 'rgb2lab', 'rgb2yiq', 'rgb2lhm',
    'haar_filter', 'hann_filter', 'scharr_filter', 'prewitt_filter', 'gaussian_filter',
    'binomial_filter1d', 'average_filter2d',
    'L2Pool2d',
    'imresize',
]
libs/metric/piq/functional/base.py
ADDED
|
@@ -0,0 +1,111 @@
| 1 |
+
r"""General purpose functions"""
|
| 2 |
+
from typing import Tuple, Union, Optional
|
| 3 |
+
import torch
|
| 4 |
+
from ..utils import _parse_version
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def ifftshift(x: torch.Tensor) -> torch.Tensor:
|
| 8 |
+
r""" Similar to np.fft.ifftshift but applies to PyTorch Tensors"""
|
| 9 |
+
shift = [-(ax // 2) for ax in x.size()]
|
| 10 |
+
return torch.roll(x, shift, tuple(range(len(shift))))
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_meshgrid(size: Tuple[int, int], device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
|
| 14 |
+
r"""Return coordinate grid matrices centered at zero point.
|
| 15 |
+
Args:
|
| 16 |
+
size: Shape of meshgrid to create
|
| 17 |
+
device: device to use for creation
|
| 18 |
+
dtype: dtype to use for creation
|
| 19 |
+
Returns:
|
| 20 |
+
Meshgrid of size on device with dtype values.
|
| 21 |
+
"""
|
| 22 |
+
if size[0] % 2:
|
| 23 |
+
# Odd
|
| 24 |
+
x = torch.arange(-(size[0] - 1) / 2, size[0] / 2, device=device, dtype=dtype) / (size[0] - 1)
|
| 25 |
+
else:
|
| 26 |
+
# Even
|
| 27 |
+
x = torch.arange(- size[0] / 2, size[0] / 2, device=device, dtype=dtype) / size[0]
|
| 28 |
+
|
| 29 |
+
if size[1] % 2:
|
| 30 |
+
# Odd
|
| 31 |
+
y = torch.arange(-(size[1] - 1) / 2, size[1] / 2, device=device, dtype=dtype) / (size[1] - 1)
|
| 32 |
+
else:
|
| 33 |
+
# Even
|
| 34 |
+
y = torch.arange(- size[1] / 2, size[1] / 2, device=device, dtype=dtype) / size[1]
|
| 35 |
+
# Use indexing param depending on torch version
|
| 36 |
+
recommended_torch_version = _parse_version("1.10.0")
|
| 37 |
+
torch_version = _parse_version(torch.__version__)
|
| 38 |
+
if len(torch_version) > 0 and torch_version >= recommended_torch_version:
|
| 39 |
+
return torch.meshgrid(x, y, indexing='ij')
|
| 40 |
+
return torch.meshgrid(x, y)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def similarity_map(map_x: torch.Tensor, map_y: torch.Tensor, constant: float, alpha: float = 0.0) -> torch.Tensor:
|
| 44 |
+
r""" Compute similarity_map between two tensors using Dice-like equation.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
map_x: Tensor with map to be compared
|
| 48 |
+
map_y: Tensor with map to be compared
|
| 49 |
+
constant: Used for numerical stability
|
| 50 |
+
alpha: Masking coefficient. Subtracts - `alpha` * map_x * map_y from denominator and nominator
|
| 51 |
+
"""
|
| 52 |
+
return (2.0 * map_x * map_y - alpha * map_x * map_y + constant) / \
|
| 53 |
+
(map_x ** 2 + map_y ** 2 - alpha * map_x * map_y + constant)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def gradient_map(x: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor:
|
| 57 |
+
r""" Compute gradient map for a given tensor and stack of kernels.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
x: Tensor with shape (N, C, H, W).
|
| 61 |
+
kernels: Stack of tensors for gradient computation with shape (k_N, k_H, k_W)
|
| 62 |
+
Returns:
|
| 63 |
+
Gradients of x per-channel with shape (N, C, H, W)
|
| 64 |
+
"""
|
| 65 |
+
padding = kernels.size(-1) // 2
|
| 66 |
+
grads = torch.nn.functional.conv2d(x, kernels, padding=padding)
|
| 67 |
+
|
| 68 |
+
return torch.sqrt(torch.sum(grads ** 2, dim=-3, keepdim=True))
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def pow_for_complex(base: torch.Tensor, exp: Union[int, float]) -> torch.Tensor:
|
| 72 |
+
r""" Takes the power of each element in a 4D tensor with negative values or 5D tensor with complex values.
|
| 73 |
+
Complex numbers are represented by modulus and argument: r * \exp(i * \phi).
|
| 74 |
+
|
| 75 |
+
It will likely to be redundant with introduction of torch.ComplexTensor.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
base: Tensor with shape (N, C, H, W) or (N, C, H, W, 2).
|
| 79 |
+
exp: Exponent
|
| 80 |
+
Returns:
|
| 81 |
+
Complex tensor with shape (N, C, H, W, 2).
|
| 82 |
+
"""
|
| 83 |
+
if base.dim() == 4:
|
| 84 |
+
x_complex_r = base.abs()
|
| 85 |
+
x_complex_phi = torch.atan2(torch.zeros_like(base), base)
|
| 86 |
+
elif base.dim() == 5 and base.size(-1) == 2:
|
| 87 |
+
x_complex_r = base.pow(2).sum(dim=-1).sqrt()
|
| 88 |
+
x_complex_phi = torch.atan2(base[..., 1], base[..., 0])
|
| 89 |
+
else:
|
| 90 |
+
raise ValueError(f'Expected real or complex tensor, got {base.size()}')
|
| 91 |
+
|
| 92 |
+
x_complex_pow_r = x_complex_r ** exp
|
| 93 |
+
x_complex_pow_phi = x_complex_phi * exp
|
| 94 |
+
x_real_pow = x_complex_pow_r * torch.cos(x_complex_pow_phi)
|
| 95 |
+
x_imag_pow = x_complex_pow_r * torch.sin(x_complex_pow_phi)
|
| 96 |
+
return torch.stack((x_real_pow, x_imag_pow), dim=-1)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def crop_patches(x: torch.Tensor, size=64, stride=32) -> torch.Tensor:
|
| 100 |
+
r"""Crop tensor with images into small patches
|
| 101 |
+
Args:
|
| 102 |
+
x: Tensor with shape (N, C, H, W), expected to be images-like entities
|
| 103 |
+
size: Size of a square patch
|
| 104 |
+
stride: Step between patches
|
| 105 |
+
"""
|
| 106 |
+
assert (x.shape[2] >= size) and (x.shape[3] >= size), \
|
| 107 |
+
f"Images must be bigger than patch size. Got ({x.shape[2], x.shape[3]}) and ({size}, {size})"
|
| 108 |
+
channels = x.shape[1]
|
| 109 |
+
patches = x.unfold(1, channels, channels).unfold(2, size, stride).unfold(3, size, stride)
|
| 110 |
+
patches = patches.reshape(-1, channels, size, size)
|
| 111 |
+
return patches
|
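A small sketch of `crop_patches` from the module above, splitting a batch of images into overlapping 64x64 tiles (random tensors as stand-ins):

```python
# Usage sketch for crop_patches(); input sizes are illustrative.
import torch
from libs.metric.piq.functional import crop_patches

imgs = torch.rand(2, 3, 128, 128)
patches = crop_patches(imgs, size=64, stride=32)
print(patches.shape)  # (N_patches, 3, 64, 64)
```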
libs/metric/piq/functional/colour_conversion.py
ADDED
|
@@ -0,0 +1,136 @@
| 1 |
+
r"""Colour space conversion functions"""
|
| 2 |
+
from typing import Union, Dict
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def rgb2lmn(x: torch.Tensor) -> torch.Tensor:
|
| 7 |
+
r"""Convert a batch of RGB images to a batch of LMN images
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
x: Batch of images with shape (N, 3, H, W). RGB colour space.
|
| 11 |
+
|
| 12 |
+
Returns:
|
| 13 |
+
Batch of images with shape (N, 3, H, W). LMN colour space.
|
| 14 |
+
"""
|
| 15 |
+
weights_rgb_to_lmn = torch.tensor([[0.06, 0.63, 0.27],
|
| 16 |
+
[0.30, 0.04, -0.35],
|
| 17 |
+
[0.34, -0.6, 0.17]], dtype=x.dtype, device=x.device).t()
|
| 18 |
+
x_lmn = torch.matmul(x.permute(0, 2, 3, 1), weights_rgb_to_lmn).permute(0, 3, 1, 2)
|
| 19 |
+
return x_lmn
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def rgb2xyz(x: torch.Tensor) -> torch.Tensor:
|
| 23 |
+
r"""Convert a batch of RGB images to a batch of XYZ images
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
x: Batch of images with shape (N, 3, H, W). RGB colour space.
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
Batch of images with shape (N, 3, H, W). XYZ colour space.
|
| 30 |
+
"""
|
| 31 |
+
mask_below = (x <= 0.04045).type(x.dtype)
|
| 32 |
+
mask_above = (x > 0.04045).type(x.dtype)
|
| 33 |
+
|
| 34 |
+
tmp = x / 12.92 * mask_below + torch.pow((x + 0.055) / 1.055, 2.4) * mask_above
|
| 35 |
+
|
| 36 |
+
weights_rgb_to_xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
|
| 37 |
+
[0.2126729, 0.7151522, 0.0721750],
|
| 38 |
+
[0.0193339, 0.1191920, 0.9503041]], dtype=x.dtype, device=x.device)
|
| 39 |
+
|
| 40 |
+
x_xyz = torch.matmul(tmp.permute(0, 2, 3, 1), weights_rgb_to_xyz.t()).permute(0, 3, 1, 2)
|
| 41 |
+
return x_xyz
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def xyz2lab(x: torch.Tensor, illuminant: str = 'D50', observer: str = '2') -> torch.Tensor:
|
| 45 |
+
r"""Convert a batch of XYZ images to a batch of LAB images
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
x: Batch of images with shape (N, 3, H, W). XYZ colour space.
|
| 49 |
+
illuminant: {“A”, “D50”, “D55”, “D65”, “D75”, “E”}, optional. The name of the illuminant.
|
| 50 |
+
observer: {“2”, “10”}, optional. The aperture angle of the observer.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
Batch of images with shape (N, 3, H, W). LAB colour space.
|
| 54 |
+
"""
|
| 55 |
+
epsilon = 0.008856
|
| 56 |
+
kappa = 903.3
|
| 57 |
+
illuminants: Dict[str, Dict] = \
|
| 58 |
+
{"A": {'2': (1.098466069456375, 1, 0.3558228003436005),
|
| 59 |
+
'10': (1.111420406956693, 1, 0.3519978321919493)},
|
| 60 |
+
"D50": {'2': (0.9642119944211994, 1, 0.8251882845188288),
|
| 61 |
+
'10': (0.9672062750333777, 1, 0.8142801513128616)},
|
| 62 |
+
"D55": {'2': (0.956797052643698, 1, 0.9214805860173273),
|
| 63 |
+
'10': (0.9579665682254781, 1, 0.9092525159847462)},
|
| 64 |
+
"D65": {'2': (0.95047, 1., 1.08883), # This was: `lab_ref_white`
|
| 65 |
+
'10': (0.94809667673716, 1, 1.0730513595166162)},
|
| 66 |
+
"D75": {'2': (0.9497220898840717, 1, 1.226393520724154),
|
| 67 |
+
'10': (0.9441713925645873, 1, 1.2064272211720228)},
|
| 68 |
+
"E": {'2': (1.0, 1.0, 1.0),
|
| 69 |
+
'10': (1.0, 1.0, 1.0)}}
|
| 70 |
+
|
| 71 |
+
illuminants_to_use = torch.tensor(illuminants[illuminant][observer],
|
| 72 |
+
dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
|
| 73 |
+
|
| 74 |
+
tmp = x / illuminants_to_use
|
| 75 |
+
|
| 76 |
+
mask_below = (tmp <= epsilon).type(x.dtype)
|
| 77 |
+
mask_above = (tmp > epsilon).type(x.dtype)
|
| 78 |
+
tmp = torch.pow(tmp, 1. / 3.) * mask_above + (kappa * tmp + 16.) / 116. * mask_below
|
| 79 |
+
|
| 80 |
+
weights_xyz_to_lab = torch.tensor([[0, 116., 0],
|
| 81 |
+
[500., -500., 0],
|
| 82 |
+
[0, 200., -200.]], dtype=x.dtype, device=x.device)
|
| 83 |
+
bias_xyz_to_lab = torch.tensor([-16., 0., 0.], dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
|
| 84 |
+
|
| 85 |
+
x_lab = torch.matmul(tmp.permute(0, 2, 3, 1), weights_xyz_to_lab.t()).permute(0, 3, 1, 2) + bias_xyz_to_lab
|
| 86 |
+
return x_lab
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def rgb2lab(x: torch.Tensor, data_range: Union[int, float] = 255) -> torch.Tensor:
|
| 90 |
+
r"""Convert a batch of RGB images to a batch of LAB images
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
x: Batch of images with shape (N, 3, H, W). RGB colour space.
|
| 94 |
+
data_range: dynamic range of the input image.
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
Batch of images with shape (N, 3, H, W). LAB colour space.
|
| 98 |
+
"""
|
| 99 |
+
return xyz2lab(rgb2xyz(x / float(data_range)))
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def rgb2yiq(x: torch.Tensor) -> torch.Tensor:
|
| 103 |
+
r"""Convert a batch of RGB images to a batch of YIQ images
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
x: Batch of images with shape (N, 3, H, W). RGB colour space.
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
Batch of images with shape (N, 3, H, W). YIQ colour space.
|
| 110 |
+
"""
|
| 111 |
+
yiq_weights = torch.tensor([
|
| 112 |
+
[0.299, 0.587, 0.114],
|
| 113 |
+
[0.5959, -0.2746, -0.3213],
|
| 114 |
+
[0.2115, -0.5227, 0.3112]], dtype=x.dtype, device=x.device).t()
|
| 115 |
+
x_yiq = torch.matmul(x.permute(0, 2, 3, 1), yiq_weights).permute(0, 3, 1, 2)
|
| 116 |
+
return x_yiq
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def rgb2lhm(x: torch.Tensor) -> torch.Tensor:
|
| 120 |
+
r"""Convert a batch of RGB images to a batch of LHM images
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
x: Batch of images with shape (N, 3, H, W). RGB colour space.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
Batch of images with shape (N, 3, H, W). LHM colour space.
|
| 127 |
+
|
| 128 |
+
Reference:
|
| 129 |
+
https://arxiv.org/pdf/1608.07433.pdf
|
| 130 |
+
"""
|
| 131 |
+
lhm_weights = torch.tensor([
|
| 132 |
+
[0.2989, 0.587, 0.114],
|
| 133 |
+
[0.3, 0.04, -0.35],
|
| 134 |
+
[0.34, -0.6, 0.17]], dtype=x.dtype, device=x.device).t()
|
| 135 |
+
x_lhm = torch.matmul(x.permute(0, 2, 3, 1), lhm_weights).permute(0, 3, 1, 2)
|
| 136 |
+
return x_lhm
|
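A brief sketch of the colour-space helpers above on a random RGB batch in [0, 1]; `data_range=1.` reflects that assumption:

```python
# Usage sketch for the colour conversions; values are random placeholders.
import torch
from libs.metric.piq.functional import rgb2lab, rgb2yiq

rgb = torch.rand(1, 3, 64, 64)
lab = rgb2lab(rgb, data_range=1.)   # L roughly in [0, 100]
yiq = rgb2yiq(rgb)
print(lab.shape, yiq.shape)         # both (1, 3, 64, 64)
```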
libs/metric/piq/functional/filters.py
ADDED
|
@@ -0,0 +1,111 @@
| 1 |
+
r"""Filters for gradient computation, bluring, etc."""
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def haar_filter(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
|
| 8 |
+
r"""Creates Haar kernel
|
| 9 |
+
|
| 10 |
+
Args:
|
| 11 |
+
kernel_size: size of the kernel
|
| 12 |
+
device: target device for kernel generation
|
| 13 |
+
dtype: target data type for kernel generation
|
| 14 |
+
Returns:
|
| 15 |
+
kernel: Tensor with shape (1, kernel_size, kernel_size)
|
| 16 |
+
"""
|
| 17 |
+
kernel = torch.ones((kernel_size, kernel_size), device=device, dtype=dtype) / kernel_size
|
| 18 |
+
kernel[kernel_size // 2:, :] = - kernel[kernel_size // 2:, :]
|
| 19 |
+
return kernel.unsqueeze(0)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def hann_filter(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
|
| 23 |
+
r"""Creates Hann kernel
|
| 24 |
+
Args:
|
| 25 |
+
kernel_size: size of the kernel
|
| 26 |
+
device: target device for kernel generation
|
| 27 |
+
dtype: target data type for kernel generation
|
| 28 |
+
Returns:
|
| 29 |
+
kernel: Tensor with shape (1, kernel_size, kernel_size)
|
| 30 |
+
"""
|
| 31 |
+
# Take bigger window and drop borders
|
| 32 |
+
window = torch.hann_window(kernel_size + 2, periodic=False, device=device, dtype=dtype)[1:-1]
|
| 33 |
+
kernel = window[:, None] * window[None, :]
|
| 34 |
+
# Normalize and reshape kernel
|
| 35 |
+
return kernel.view(1, kernel_size, kernel_size) / kernel.sum()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def gaussian_filter(kernel_size: int, sigma: float, device: Optional[str] = None,
|
| 39 |
+
dtype: Optional[type] = None) -> torch.Tensor:
|
| 40 |
+
r"""Returns 2D Gaussian kernel N(0,`sigma`^2)
|
| 41 |
+
Args:
|
| 42 |
+
size: Size of the kernel
|
| 43 |
+
sigma: Std of the distribution
|
| 44 |
+
device: target device for kernel generation
|
| 45 |
+
dtype: target data type for kernel generation
|
| 46 |
+
Returns:
|
| 47 |
+
gaussian_kernel: Tensor with shape (1, kernel_size, kernel_size)
|
| 48 |
+
"""
|
| 49 |
+
coords = torch.arange(kernel_size, dtype=dtype, device=device)
|
| 50 |
+
coords -= (kernel_size - 1) / 2.
|
| 51 |
+
|
| 52 |
+
g = coords ** 2
|
| 53 |
+
g = (- (g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma ** 2)).exp()
|
| 54 |
+
|
| 55 |
+
g /= g.sum()
|
| 56 |
+
return g.unsqueeze(0)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Gradient operator kernels
|
| 60 |
+
def scharr_filter(device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
|
| 61 |
+
r"""Utility function that returns a normalized 3x3 Scharr kernel in X direction
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
device: target device for kernel generation
|
| 65 |
+
dtype: target data type for kernel generation
|
| 66 |
+
Returns:
|
| 67 |
+
kernel: Tensor with shape (1, 3, 3)
|
| 68 |
+
"""
|
| 69 |
+
return torch.tensor([[[-3., 0., 3.], [-10., 0., 10.], [-3., 0., 3.]]], device=device, dtype=dtype) / 16
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def prewitt_filter(device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
|
| 73 |
+
r"""Utility function that returns a normalized 3x3 Prewitt kernel in X direction
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
device: target device for kernel generation
|
| 77 |
+
dtype: target data type for kernel generation
|
| 78 |
+
Returns:
|
| 79 |
+
kernel: Tensor with shape (1, 3, 3)"""
|
| 80 |
+
return torch.tensor([[[-1., 0., 1.], [-1., 0., 1.], [-1., 0., 1.]]], device=device, dtype=dtype) / 3
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def binomial_filter1d(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
|
| 84 |
+
r"""Creates 1D normalized binomial filter
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
kernel_size (int): kernel size
|
| 88 |
+
device: target device for kernel generation
|
| 89 |
+
dtype: target data type for kernel generation
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
Binomial kernel with shape (1, 1, kernel_size)
|
| 93 |
+
"""
|
| 94 |
+
kernel = np.poly1d([0.5, 0.5]) ** (kernel_size - 1)
|
| 95 |
+
return torch.tensor(kernel.c, dtype=dtype, device=device).view(1, 1, kernel_size)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def average_filter2d(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
|
| 99 |
+
r"""Creates 2D normalized average filter
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
kernel_size (int): kernel size
|
| 103 |
+
device: target device for kernel generation
|
| 104 |
+
dtype: target data type for kernel generation
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
kernel: Tensor with shape (1, kernel_size, kernel_size)
|
| 108 |
+
"""
|
| 109 |
+
window = torch.ones(kernel_size, dtype=dtype, device=device) / kernel_size
|
| 110 |
+
kernel = window[:, None] * window[None, :]
|
| 111 |
+
return kernel.unsqueeze(0)
|
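Usage sketch (not from the uploaded files): every kernel above is returned with a leading singleton dimension, so it can be repeated per channel and used as a depthwise weight for `torch.nn.functional.conv2d`. A minimal sketch, assuming the repository root is on `PYTHONPATH` so `libs` is importable:

```python
# Hypothetical example: blur a batch with the Gaussian kernel defined above.
import torch
import torch.nn.functional as F

from libs.metric.piq.functional.filters import gaussian_filter

x = torch.rand(2, 3, 64, 64)  # (N, C, H, W) images in [0, 1]
kernel = gaussian_filter(kernel_size=5, sigma=1.5, dtype=torch.float32)  # (1, 5, 5), sums to 1
kernel = kernel.repeat(x.size(1), 1, 1, 1)  # one copy per channel -> (C, 1, 5, 5)
blurred = F.conv2d(x, kernel, padding=2, groups=x.size(1))  # depthwise convolution
print(blurred.shape)  # torch.Size([2, 3, 64, 64])
```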
libs/metric/piq/functional/layers.py
ADDED
|
@@ -0,0 +1,33 @@
| 1 |
+
r"""Custom layers used in metrics computations"""
|
| 2 |
+
import torch
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
from .filters import hann_filter
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class L2Pool2d(torch.nn.Module):
|
| 9 |
+
r"""Applies L2 pooling with Hann window of size 3x3
|
| 10 |
+
Args:
|
| 11 |
+
x: Tensor with shape (N, C, H, W)"""
|
| 12 |
+
EPS = 1e-12
|
| 13 |
+
|
| 14 |
+
def __init__(self, kernel_size: int = 3, stride: int = 2, padding=1) -> None:
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.kernel_size = kernel_size
|
| 17 |
+
self.stride = stride
|
| 18 |
+
self.padding = padding
|
| 19 |
+
|
| 20 |
+
self.kernel: Optional[torch.Tensor] = None
|
| 21 |
+
|
| 22 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 23 |
+
if self.kernel is None:
|
| 24 |
+
C = x.size(1)
|
| 25 |
+
self.kernel = hann_filter(self.kernel_size).repeat((C, 1, 1, 1)).to(x)
|
| 26 |
+
|
| 27 |
+
out = torch.nn.functional.conv2d(
|
| 28 |
+
x ** 2, self.kernel,
|
| 29 |
+
stride=self.stride,
|
| 30 |
+
padding=self.padding,
|
| 31 |
+
groups=x.shape[1]
|
| 32 |
+
)
|
| 33 |
+
return (out + self.EPS).sqrt()
|
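Usage sketch (not from the uploaded files): `L2Pool2d` averages squared activations under a 3x3 Hann window and takes the square root, so it halves the spatial resolution like `MaxPool2d` with the default stride of 2; DISTS in `perceptual.py` below swaps it in for max pooling. A minimal sketch, assuming the repository root is importable:

```python
import torch
from libs.metric.piq.functional.layers import L2Pool2d

pool = L2Pool2d(kernel_size=3, stride=2, padding=1)
x = torch.rand(1, 64, 32, 32)
y = pool(x)  # energy-style pooling over a Hann window
print(y.shape)  # torch.Size([1, 64, 16, 16])
```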
libs/metric/piq/functional/resize.py
ADDED
|
@@ -0,0 +1,426 @@
| 1 |
+
"""
|
| 2 |
+
A standalone PyTorch implementation for fast and efficient bicubic resampling.
|
| 3 |
+
The resulting values are the same as those of the MATLAB function imresize('bicubic').
|
| 4 |
+
## Author: Sanghyun Son
|
| 5 |
+
## Email: [email protected] (primary), [email protected] (secondary)
|
| 6 |
+
## Version: 1.2.0
|
| 7 |
+
## Last update: July 9th, 2020 (KST)
|
| 8 |
+
Dependency: torch
|
| 9 |
+
Example::
|
| 10 |
+
>>> import torch
|
| 11 |
+
>>> import core
|
| 12 |
+
>>> x = torch.arange(16).float().view(1, 1, 4, 4)
|
| 13 |
+
>>> y = core.imresize(x, sizes=(3, 3))
|
| 14 |
+
>>> print(y)
|
| 15 |
+
tensor([[[[ 0.7506, 2.1004, 3.4503],
|
| 16 |
+
[ 6.1505, 7.5000, 8.8499],
|
| 17 |
+
[11.5497, 12.8996, 14.2494]]]])
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import math
|
| 21 |
+
import typing
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from torch.nn import functional as F
|
| 25 |
+
|
| 26 |
+
__all__ = ['imresize']
|
| 27 |
+
|
| 28 |
+
_I = typing.Optional[int]
|
| 29 |
+
_D = typing.Optional[torch.dtype]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def nearest_contribution(x: torch.Tensor) -> torch.Tensor:
|
| 33 |
+
range_around_0 = torch.logical_and(x.gt(-0.5), x.le(0.5))
|
| 34 |
+
cont = range_around_0.to(dtype=x.dtype)
|
| 35 |
+
return cont
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def linear_contribution(x: torch.Tensor) -> torch.Tensor:
|
| 39 |
+
ax = x.abs()
|
| 40 |
+
range_01 = ax.le(1)
|
| 41 |
+
cont = (1 - ax) * range_01.to(dtype=x.dtype)
|
| 42 |
+
return cont
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def cubic_contribution(x: torch.Tensor, a: float = -0.5) -> torch.Tensor:
|
| 46 |
+
ax = x.abs()
|
| 47 |
+
ax2 = ax * ax
|
| 48 |
+
ax3 = ax * ax2
|
| 49 |
+
|
| 50 |
+
range_01 = ax.le(1)
|
| 51 |
+
range_12 = torch.logical_and(ax.gt(1), ax.le(2))
|
| 52 |
+
|
| 53 |
+
cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1
|
| 54 |
+
cont_01 = cont_01 * range_01.to(dtype=x.dtype)
|
| 55 |
+
|
| 56 |
+
cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a)
|
| 57 |
+
cont_12 = cont_12 * range_12.to(dtype=x.dtype)
|
| 58 |
+
|
| 59 |
+
cont = cont_01 + cont_12
|
| 60 |
+
return cont
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor:
|
| 64 |
+
range_3sigma = (x.abs() <= 3 * sigma + 1)
|
| 65 |
+
# Normalization will be done after
|
| 66 |
+
cont = torch.exp(-x.pow(2) / (2 * sigma ** 2))
|
| 67 |
+
cont = cont * range_3sigma.to(dtype=x.dtype)
|
| 68 |
+
return cont
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def discrete_kernel(
|
| 72 |
+
kernel: str, scale: float, antialiasing: bool = True) -> torch.Tensor:
|
| 73 |
+
'''
|
| 74 |
+
For downsampling with integer scale only.
|
| 75 |
+
'''
|
| 76 |
+
downsampling_factor = int(1 / scale)
|
| 77 |
+
if kernel == 'cubic':
|
| 78 |
+
kernel_size_orig = 4
|
| 79 |
+
else:
|
| 80 |
+
raise ValueError('Pass!')
|
| 81 |
+
|
| 82 |
+
if antialiasing:
|
| 83 |
+
kernel_size = kernel_size_orig * downsampling_factor
|
| 84 |
+
else:
|
| 85 |
+
kernel_size = kernel_size_orig
|
| 86 |
+
|
| 87 |
+
if downsampling_factor % 2 == 0:
|
| 88 |
+
a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size))
|
| 89 |
+
else:
|
| 90 |
+
kernel_size -= 1
|
| 91 |
+
a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1))
|
| 92 |
+
|
| 93 |
+
with torch.no_grad():
|
| 94 |
+
r = torch.linspace(-a, a, steps=kernel_size)
|
| 95 |
+
k = cubic_contribution(r).view(-1, 1)
|
| 96 |
+
k = torch.matmul(k, k.t())
|
| 97 |
+
k /= k.sum()
|
| 98 |
+
|
| 99 |
+
return k
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def reflect_padding(
|
| 103 |
+
x: torch.Tensor,
|
| 104 |
+
dim: int,
|
| 105 |
+
pad_pre: int,
|
| 106 |
+
pad_post: int) -> torch.Tensor:
|
| 107 |
+
'''
|
| 108 |
+
Apply reflect padding to the given Tensor.
|
| 109 |
+
Note that it is slightly different from the PyTorch functional.pad,
|
| 110 |
+
where boundary elements are used only once.
|
| 111 |
+
Instead, we follow the MATLAB implementation
|
| 112 |
+
which uses boundary elements twice.
|
| 113 |
+
For example,
|
| 114 |
+
[a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation,
|
| 115 |
+
while our implementation yields [a, a, b, c, d, d].
|
| 116 |
+
'''
|
| 117 |
+
b, c, h, w = x.size()
|
| 118 |
+
if dim == 2 or dim == -2:
|
| 119 |
+
padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w)
|
| 120 |
+
padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x)
|
| 121 |
+
for p in range(pad_pre):
|
| 122 |
+
padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :])
|
| 123 |
+
for p in range(pad_post):
|
| 124 |
+
padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :])
|
| 125 |
+
else:
|
| 126 |
+
padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post)
|
| 127 |
+
padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x)
|
| 128 |
+
for p in range(pad_pre):
|
| 129 |
+
padding_buffer[..., pad_pre - p - 1].copy_(x[..., p])
|
| 130 |
+
for p in range(pad_post):
|
| 131 |
+
padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)])
|
| 132 |
+
|
| 133 |
+
return padding_buffer
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def padding(
|
| 137 |
+
x: torch.Tensor,
|
| 138 |
+
dim: int,
|
| 139 |
+
pad_pre: int,
|
| 140 |
+
pad_post: int,
|
| 141 |
+
padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor:
|
| 142 |
+
if padding_type is None:
|
| 143 |
+
return x
|
| 144 |
+
elif padding_type == 'reflect':
|
| 145 |
+
x_pad = reflect_padding(x, dim, pad_pre, pad_post)
|
| 146 |
+
else:
|
| 147 |
+
raise ValueError('{} padding is not supported!'.format(padding_type))
|
| 148 |
+
|
| 149 |
+
return x_pad
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def get_padding(
|
| 153 |
+
base: torch.Tensor,
|
| 154 |
+
kernel_size: int,
|
| 155 |
+
x_size: int) -> typing.Tuple[int, int, torch.Tensor]:
|
| 156 |
+
base = base.long()
|
| 157 |
+
r_min = base.min()
|
| 158 |
+
r_max = base.max() + kernel_size - 1
|
| 159 |
+
|
| 160 |
+
if r_min <= 0:
|
| 161 |
+
pad_pre = -r_min
|
| 162 |
+
pad_pre = pad_pre.item()
|
| 163 |
+
base += pad_pre
|
| 164 |
+
else:
|
| 165 |
+
pad_pre = 0
|
| 166 |
+
|
| 167 |
+
if r_max >= x_size:
|
| 168 |
+
pad_post = r_max - x_size + 1
|
| 169 |
+
pad_post = pad_post.item()
|
| 170 |
+
else:
|
| 171 |
+
pad_post = 0
|
| 172 |
+
|
| 173 |
+
return pad_pre, pad_post, base
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def get_weight(
|
| 177 |
+
dist: torch.Tensor,
|
| 178 |
+
kernel_size: int,
|
| 179 |
+
kernel: str = 'cubic',
|
| 180 |
+
sigma: float = 2.0,
|
| 181 |
+
antialiasing_factor: float = 1) -> torch.Tensor:
|
| 182 |
+
buffer_pos = dist.new_zeros(kernel_size, len(dist))
|
| 183 |
+
for idx, buffer_sub in enumerate(buffer_pos):
|
| 184 |
+
buffer_sub.copy_(dist - idx)
|
| 185 |
+
|
| 186 |
+
# Expand (downsampling) / Shrink (upsampling) the receptive field.
|
| 187 |
+
buffer_pos *= antialiasing_factor
|
| 188 |
+
if kernel == 'cubic':
|
| 189 |
+
weight = cubic_contribution(buffer_pos)
|
| 190 |
+
elif kernel == 'gaussian':
|
| 191 |
+
weight = gaussian_contribution(buffer_pos, sigma=sigma)
|
| 192 |
+
else:
|
| 193 |
+
raise ValueError('{} kernel is not supported!'.format(kernel))
|
| 194 |
+
|
| 195 |
+
weight /= weight.sum(dim=0, keepdim=True)
|
| 196 |
+
return weight
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor:
|
| 200 |
+
# Resize height
|
| 201 |
+
if dim == 2 or dim == -2:
|
| 202 |
+
k = (kernel_size, 1)
|
| 203 |
+
h_out = x.size(-2) - kernel_size + 1
|
| 204 |
+
w_out = x.size(-1)
|
| 205 |
+
# Resize width
|
| 206 |
+
else:
|
| 207 |
+
k = (1, kernel_size)
|
| 208 |
+
h_out = x.size(-2)
|
| 209 |
+
w_out = x.size(-1) - kernel_size + 1
|
| 210 |
+
|
| 211 |
+
unfold = F.unfold(x, k)
|
| 212 |
+
unfold = unfold.view(unfold.size(0), -1, h_out, w_out)
|
| 213 |
+
return unfold
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def reshape_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, int, int]:
|
| 217 |
+
if x.dim() == 4:
|
| 218 |
+
b, c, h, w = x.size()
|
| 219 |
+
elif x.dim() == 3:
|
| 220 |
+
c, h, w = x.size()
|
| 221 |
+
b = None
|
| 222 |
+
elif x.dim() == 2:
|
| 223 |
+
h, w = x.size()
|
| 224 |
+
b = c = None
|
| 225 |
+
else:
|
| 226 |
+
raise ValueError('{}-dim Tensor is not supported!'.format(x.dim()))
|
| 227 |
+
|
| 228 |
+
x = x.view(-1, 1, h, w)
|
| 229 |
+
return x, b, c, h, w
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def reshape_output(x: torch.Tensor, b: _I, c: _I) -> torch.Tensor:
|
| 233 |
+
rh = x.size(-2)
|
| 234 |
+
rw = x.size(-1)
|
| 235 |
+
# Back to the original dimension
|
| 236 |
+
if b is not None:
|
| 237 |
+
x = x.view(b, c, rh, rw) # 4-dim
|
| 238 |
+
else:
|
| 239 |
+
if c is not None:
|
| 240 |
+
x = x.view(c, rh, rw) # 3-dim
|
| 241 |
+
else:
|
| 242 |
+
x = x.view(rh, rw) # 2-dim
|
| 243 |
+
|
| 244 |
+
return x
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]:
|
| 248 |
+
if x.dtype != torch.float32 and x.dtype != torch.float64:
|
| 249 |
+
dtype = x.dtype
|
| 250 |
+
x = x.float()
|
| 251 |
+
else:
|
| 252 |
+
dtype = None
|
| 253 |
+
|
| 254 |
+
return x, dtype
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor:
|
| 258 |
+
if dtype is not None:
|
| 259 |
+
if not dtype.is_floating_point:
|
| 260 |
+
x = x.round()
|
| 261 |
+
# To prevent over/underflow when converting types
|
| 262 |
+
if dtype is torch.uint8:
|
| 263 |
+
x = x.clamp(0, 255)
|
| 264 |
+
|
| 265 |
+
x = x.to(dtype=dtype)
|
| 266 |
+
|
| 267 |
+
return x
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def resize_1d(
|
| 271 |
+
x: torch.Tensor,
|
| 272 |
+
dim: int,
|
| 273 |
+
size: int,
|
| 274 |
+
scale: float,
|
| 275 |
+
kernel: str = 'cubic',
|
| 276 |
+
sigma: float = 2.0,
|
| 277 |
+
padding_type: str = 'reflect',
|
| 278 |
+
antialiasing: bool = True) -> torch.Tensor:
|
| 279 |
+
'''
|
| 280 |
+
Args:
|
| 281 |
+
x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W).
|
| 282 |
+
dim (int):
|
| 283 |
+
scale (float):
|
| 284 |
+
size (int):
|
| 285 |
+
Return:
|
| 286 |
+
'''
|
| 287 |
+
# Identity case
|
| 288 |
+
if scale == 1:
|
| 289 |
+
return x
|
| 290 |
+
|
| 291 |
+
# Default bicubic kernel with antialiasing (only when downsampling)
|
| 292 |
+
if kernel == 'cubic':
|
| 293 |
+
kernel_size = 4
|
| 294 |
+
else:
|
| 295 |
+
kernel_size = math.floor(6 * sigma)
|
| 296 |
+
|
| 297 |
+
if antialiasing and (scale < 1):
|
| 298 |
+
antialiasing_factor = scale
|
| 299 |
+
kernel_size = math.ceil(kernel_size / antialiasing_factor)
|
| 300 |
+
else:
|
| 301 |
+
antialiasing_factor = 1
|
| 302 |
+
|
| 303 |
+
# We allow margin to both sizes
|
| 304 |
+
kernel_size += 2
|
| 305 |
+
|
| 306 |
+
# Weights only depend on the shape of input and output,
|
| 307 |
+
# so we do not calculate gradients here.
|
| 308 |
+
with torch.no_grad():
|
| 309 |
+
pos = torch.linspace(
|
| 310 |
+
0, size - 1, steps=size, dtype=x.dtype, device=x.device,
|
| 311 |
+
)
|
| 312 |
+
pos = (pos + 0.5) / scale - 0.5
|
| 313 |
+
base = pos.floor() - (kernel_size // 2) + 1
|
| 314 |
+
dist = pos - base
|
| 315 |
+
weight = get_weight(
|
| 316 |
+
dist,
|
| 317 |
+
kernel_size,
|
| 318 |
+
kernel=kernel,
|
| 319 |
+
sigma=sigma,
|
| 320 |
+
antialiasing_factor=antialiasing_factor,
|
| 321 |
+
)
|
| 322 |
+
pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim))
|
| 323 |
+
|
| 324 |
+
# To backpropagate through x
|
| 325 |
+
x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type)
|
| 326 |
+
unfold = reshape_tensor(x_pad, dim, kernel_size)
|
| 327 |
+
# Subsampling first
|
| 328 |
+
if dim == 2 or dim == -2:
|
| 329 |
+
sample = unfold[..., base, :]
|
| 330 |
+
weight = weight.view(1, kernel_size, sample.size(2), 1)
|
| 331 |
+
else:
|
| 332 |
+
sample = unfold[..., base]
|
| 333 |
+
weight = weight.view(1, kernel_size, 1, sample.size(3))
|
| 334 |
+
|
| 335 |
+
# Apply the kernel
|
| 336 |
+
x = sample * weight
|
| 337 |
+
x = x.sum(dim=1, keepdim=True)
|
| 338 |
+
return x
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def downsampling_2d(
|
| 342 |
+
x: torch.Tensor,
|
| 343 |
+
k: torch.Tensor,
|
| 344 |
+
scale: int,
|
| 345 |
+
padding_type: str = 'reflect') -> torch.Tensor:
|
| 346 |
+
c = x.size(1)
|
| 347 |
+
k_h = k.size(-2)
|
| 348 |
+
k_w = k.size(-1)
|
| 349 |
+
|
| 350 |
+
k = k.to(dtype=x.dtype, device=x.device)
|
| 351 |
+
k = k.view(1, 1, k_h, k_w)
|
| 352 |
+
k = k.repeat(c, c, 1, 1)
|
| 353 |
+
e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False)
|
| 354 |
+
e = e.view(c, c, 1, 1)
|
| 355 |
+
k = k * e
|
| 356 |
+
|
| 357 |
+
pad_h = (k_h - scale) // 2
|
| 358 |
+
pad_w = (k_w - scale) // 2
|
| 359 |
+
x = padding(x, -2, pad_h, pad_h, padding_type=padding_type)
|
| 360 |
+
x = padding(x, -1, pad_w, pad_w, padding_type=padding_type)
|
| 361 |
+
y = F.conv2d(x, k, padding=0, stride=scale)
|
| 362 |
+
return y
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def imresize(
|
| 366 |
+
x: torch.Tensor,
|
| 367 |
+
scale: typing.Optional[float] = None,
|
| 368 |
+
sizes: typing.Optional[typing.Tuple[int, int]] = None,
|
| 369 |
+
kernel: typing.Union[str, torch.Tensor] = 'cubic',
|
| 370 |
+
sigma: float = 2,
|
| 371 |
+
rotation_degree: float = 0,
|
| 372 |
+
padding_type: str = 'reflect',
|
| 373 |
+
antialiasing: bool = True) -> torch.Tensor:
|
| 374 |
+
"""
|
| 375 |
+
Args:
|
| 376 |
+
x (torch.Tensor):
|
| 377 |
+
scale (float):
|
| 378 |
+
sizes (tuple(int, int)):
|
| 379 |
+
kernel (str, default='cubic'):
|
| 380 |
+
sigma (float, default=2):
|
| 381 |
+
rotation_degree (float, default=0):
|
| 382 |
+
padding_type (str, default='reflect'):
|
| 383 |
+
antialiasing (bool, default=True):
|
| 384 |
+
Return:
|
| 385 |
+
torch.Tensor:
|
| 386 |
+
"""
|
| 387 |
+
if scale is None and sizes is None:
|
| 388 |
+
raise ValueError('One of scale or sizes must be specified!')
|
| 389 |
+
if scale is not None and sizes is not None:
|
| 390 |
+
raise ValueError('Please specify scale or sizes to avoid conflict!')
|
| 391 |
+
|
| 392 |
+
x, b, c, h, w = reshape_input(x)
|
| 393 |
+
|
| 394 |
+
if sizes is None and scale is not None:
|
| 395 |
+
'''
|
| 396 |
+
# Check if we can apply the convolution algorithm
|
| 397 |
+
scale_inv = 1 / scale
|
| 398 |
+
if isinstance(kernel, str) and scale_inv.is_integer():
|
| 399 |
+
kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing)
|
| 400 |
+
elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer():
|
| 401 |
+
raise ValueError(
|
| 402 |
+
'An integer downsampling factor '
|
| 403 |
+
'should be used with a predefined kernel!'
|
| 404 |
+
)
|
| 405 |
+
'''
|
| 406 |
+
# Determine output size
|
| 407 |
+
sizes = (math.ceil(h * scale), math.ceil(w * scale))
|
| 408 |
+
scales = (scale, scale)
|
| 409 |
+
|
| 410 |
+
if scale is None and sizes is not None:
|
| 411 |
+
scales = (sizes[0] / h, sizes[1] / w)
|
| 412 |
+
|
| 413 |
+
x, dtype = cast_input(x)
|
| 414 |
+
|
| 415 |
+
if isinstance(kernel, str) and sizes is not None:
|
| 416 |
+
# Core resizing module
|
| 417 |
+
x = resize_1d(x, -2, size=sizes[0], scale=scales[0], kernel=kernel, sigma=sigma, padding_type=padding_type,
|
| 418 |
+
antialiasing=antialiasing)
|
| 419 |
+
x = resize_1d(x, -1, size=sizes[1], scale=scales[1], kernel=kernel, sigma=sigma, padding_type=padding_type,
|
| 420 |
+
antialiasing=antialiasing)
|
| 421 |
+
elif isinstance(kernel, torch.Tensor) and scale is not None:
|
| 422 |
+
x = downsampling_2d(x, kernel, scale=int(1 / scale))
|
| 423 |
+
|
| 424 |
+
x = reshape_output(x, b, c)
|
| 425 |
+
x = cast_output(x, dtype)
|
| 426 |
+
return x
|
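Usage sketch (not from the uploaded files): `imresize` accepts either a `scale` factor or an explicit `sizes` tuple (but not both), and antialiases only when downscaling. A minimal sketch, assuming the repository root is importable:

```python
import torch
from libs.metric.piq.functional.resize import imresize

x = torch.rand(1, 3, 256, 256)
y_half = imresize(x, scale=0.5)           # bicubic, antialiased -> (1, 3, 128, 128)
y_fixed = imresize(x, sizes=(224, 224))   # explicit output size -> (1, 3, 224, 224)
print(y_half.shape, y_fixed.shape)
```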
libs/metric/piq/perceptual.py
ADDED
|
@@ -0,0 +1,496 @@
| 1 |
+
"""
|
| 2 |
+
Implementation of Content loss, Style loss, LPIPS and DISTS metrics
|
| 3 |
+
References:
|
| 4 |
+
.. [1] Gatys, Leon and Ecker, Alexander and Bethge, Matthias
|
| 5 |
+
(2016). A Neural Algorithm of Artistic Style
|
| 6 |
+
Association for Research in Vision and Ophthalmology (ARVO)
|
| 7 |
+
https://arxiv.org/abs/1508.06576
|
| 8 |
+
.. [2] Zhang, Richard and Isola, Phillip and Efros, et al.
|
| 9 |
+
(2018) The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
|
| 10 |
+
2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition
|
| 11 |
+
https://arxiv.org/abs/1801.03924
|
| 12 |
+
"""
|
| 13 |
+
from typing import List, Union, Collection
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch.nn.modules.loss import _Loss
|
| 18 |
+
from torchvision.models import vgg16, vgg19, VGG16_Weights, VGG19_Weights
|
| 19 |
+
|
| 20 |
+
from .utils import _validate_input, _reduce
|
| 21 |
+
from .functional import similarity_map, L2Pool2d
|
| 22 |
+
|
| 23 |
+
# Map VGG names to corresponding number in torchvision layer
|
| 24 |
+
VGG16_LAYERS = {
|
| 25 |
+
"conv1_1": '0', "relu1_1": '1',
|
| 26 |
+
"conv1_2": '2', "relu1_2": '3',
|
| 27 |
+
"pool1": '4',
|
| 28 |
+
"conv2_1": '5', "relu2_1": '6',
|
| 29 |
+
"conv2_2": '7', "relu2_2": '8',
|
| 30 |
+
"pool2": '9',
|
| 31 |
+
"conv3_1": '10', "relu3_1": '11',
|
| 32 |
+
"conv3_2": '12', "relu3_2": '13',
|
| 33 |
+
"conv3_3": '14', "relu3_3": '15',
|
| 34 |
+
"pool3": '16',
|
| 35 |
+
"conv4_1": '17', "relu4_1": '18',
|
| 36 |
+
"conv4_2": '19', "relu4_2": '20',
|
| 37 |
+
"conv4_3": '21', "relu4_3": '22',
|
| 38 |
+
"pool4": '23',
|
| 39 |
+
"conv5_1": '24', "relu5_1": '25',
|
| 40 |
+
"conv5_2": '26', "relu5_2": '27',
|
| 41 |
+
"conv5_3": '28', "relu5_3": '29',
|
| 42 |
+
"pool5": '30',
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
VGG19_LAYERS = {
|
| 46 |
+
"conv1_1": '0', "relu1_1": '1',
|
| 47 |
+
"conv1_2": '2', "relu1_2": '3',
|
| 48 |
+
"pool1": '4',
|
| 49 |
+
"conv2_1": '5', "relu2_1": '6',
|
| 50 |
+
"conv2_2": '7', "relu2_2": '8',
|
| 51 |
+
"pool2": '9',
|
| 52 |
+
"conv3_1": '10', "relu3_1": '11',
|
| 53 |
+
"conv3_2": '12', "relu3_2": '13',
|
| 54 |
+
"conv3_3": '14', "relu3_3": '15',
|
| 55 |
+
"conv3_4": '16', "relu3_4": '17',
|
| 56 |
+
"pool3": '18',
|
| 57 |
+
"conv4_1": '19', "relu4_1": '20',
|
| 58 |
+
"conv4_2": '21', "relu4_2": '22',
|
| 59 |
+
"conv4_3": '23', "relu4_3": '24',
|
| 60 |
+
"conv4_4": '25', "relu4_4": '26',
|
| 61 |
+
"pool4": '27',
|
| 62 |
+
"conv5_1": '28', "relu5_1": '29',
|
| 63 |
+
"conv5_2": '30', "relu5_2": '31',
|
| 64 |
+
"conv5_3": '32', "relu5_3": '33',
|
| 65 |
+
"conv5_4": '34', "relu5_4": '35',
|
| 66 |
+
"pool5": '36',
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
| 70 |
+
IMAGENET_STD = [0.229, 0.224, 0.225]
|
| 71 |
+
|
| 72 |
+
# Constant used in feature normalization to avoid zero division
|
| 73 |
+
EPS = 1e-10
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class ContentLoss(_Loss):
|
| 77 |
+
r"""Creates Content loss that can be used for image style transfer or as a measure for image to image tasks.
|
| 78 |
+
Uses pretrained VGG models from torchvision.
|
| 79 |
+
Expects input to be in range [0, 1] or normalized with ImageNet statistics into range [-1, 1]
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``.
|
| 83 |
+
layers: List of strings with layer names. Default: ``'relu3_3'``
|
| 84 |
+
weights: List of float weight to balance different layers
|
| 85 |
+
replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details.
|
| 86 |
+
distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
|
| 87 |
+
reduction: Specifies the reduction type:
|
| 88 |
+
``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
|
| 89 |
+
mean: List of float values used for data standardization. Default: ImageNet mean.
|
| 90 |
+
If there is no need to normalize data, use [0., 0., 0.].
|
| 91 |
+
std: List of float values used for data standardization. Default: ImageNet std.
|
| 92 |
+
If there is no need to normalize data, use [1., 1., 1.].
|
| 93 |
+
normalize_features: If true, unit-normalize each feature in channel dimension before scaling
|
| 94 |
+
and computing distance. See references for details.
|
| 95 |
+
|
| 96 |
+
Examples:
|
| 97 |
+
>>> loss = ContentLoss()
|
| 98 |
+
>>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
|
| 99 |
+
>>> y = torch.rand(3, 3, 256, 256)
|
| 100 |
+
>>> output = loss(x, y)
|
| 101 |
+
>>> output.backward()
|
| 102 |
+
|
| 103 |
+
References:
|
| 104 |
+
Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
|
| 105 |
+
A Neural Algorithm of Artistic Style
|
| 106 |
+
Association for Research in Vision and Ophthalmology (ARVO)
|
| 107 |
+
https://arxiv.org/abs/1508.06576
|
| 108 |
+
|
| 109 |
+
Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
|
| 110 |
+
The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
|
| 111 |
+
IEEE/CVF Conference on Computer Vision and Pattern Recognition
|
| 112 |
+
https://arxiv.org/abs/1801.03924
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
def __init__(self, feature_extractor: Union[str, torch.nn.Module] = "vgg16", layers: Collection[str] = ("relu3_3",),
|
| 116 |
+
weights: List[Union[float, torch.Tensor]] = [1.], replace_pooling: bool = False,
|
| 117 |
+
distance: str = "mse", reduction: str = "mean", mean: List[float] = IMAGENET_MEAN,
|
| 118 |
+
std: List[float] = IMAGENET_STD, normalize_features: bool = False,
|
| 119 |
+
allow_layers_weights_mismatch: bool = False) -> None:
|
| 120 |
+
|
| 121 |
+
assert allow_layers_weights_mismatch or len(layers) == len(weights), \
|
| 122 |
+
f'Lengths of provided layers and weights mismatch ({len(weights)} weights and {len(layers)} layers), ' \
|
| 123 |
+
f'which will cause incorrect results. Please provide weight for each layer.'
|
| 124 |
+
|
| 125 |
+
super().__init__()
|
| 126 |
+
|
| 127 |
+
if callable(feature_extractor):
|
| 128 |
+
self.model = feature_extractor
|
| 129 |
+
self.layers = layers
|
| 130 |
+
else:
|
| 131 |
+
if feature_extractor == "vgg16":
|
| 132 |
+
# self.model = vgg16(pretrained=True, progress=False).features
|
| 133 |
+
self.model = vgg16(weights=VGG16_Weights.DEFAULT, progress=False).features
|
| 134 |
+
self.layers = [VGG16_LAYERS[l] for l in layers]
|
| 135 |
+
elif feature_extractor == "vgg19":
|
| 136 |
+
# self.model = vgg19(pretrained=True, progress=False).features
|
| 137 |
+
self.model = vgg19(weights=VGG19_Weights.DEFAULT, progress=False).features
|
| 138 |
+
self.layers = [VGG19_LAYERS[l] for l in layers]
|
| 139 |
+
else:
|
| 140 |
+
raise ValueError("Unknown feature extractor")
|
| 141 |
+
|
| 142 |
+
if replace_pooling:
|
| 143 |
+
self.model = self.replace_pooling(self.model)
|
| 144 |
+
|
| 145 |
+
# Disable gradients
|
| 146 |
+
for param in self.model.parameters():
|
| 147 |
+
param.requires_grad_(False)
|
| 148 |
+
|
| 149 |
+
self.distance = {
|
| 150 |
+
"mse": nn.MSELoss,
|
| 151 |
+
"mae": nn.L1Loss,
|
| 152 |
+
}[distance](reduction='none')
|
| 153 |
+
|
| 154 |
+
self.weights = [torch.tensor(w) if not isinstance(w, torch.Tensor) else w for w in weights]
|
| 155 |
+
|
| 156 |
+
mean = torch.tensor(mean)
|
| 157 |
+
std = torch.tensor(std)
|
| 158 |
+
self.mean = mean.view(1, -1, 1, 1)
|
| 159 |
+
self.std = std.view(1, -1, 1, 1)
|
| 160 |
+
|
| 161 |
+
self.normalize_features = normalize_features
|
| 162 |
+
self.reduction = reduction
|
| 163 |
+
|
| 164 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
| 165 |
+
r"""Computation of Content loss between feature representations of prediction :math:`x` and
|
| 166 |
+
target :math:`y` tensors.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
x: An input tensor. Shape :math:`(N, C, H, W)`.
|
| 170 |
+
y: A target tensor. Shape :math:`(N, C, H, W)`.
|
| 171 |
+
|
| 172 |
+
Returns:
|
| 173 |
+
Content loss between feature representations
|
| 174 |
+
"""
|
| 175 |
+
_validate_input([x, y], dim_range=(4, 4), data_range=(0, -1))
|
| 176 |
+
|
| 177 |
+
self.model.to(x)
|
| 178 |
+
x_features = self.get_features(x)
|
| 179 |
+
y_features = self.get_features(y)
|
| 180 |
+
|
| 181 |
+
distances = self.compute_distance(x_features, y_features)
|
| 182 |
+
|
| 183 |
+
# Scale distances, then average in spatial dimensions, then stack and sum in channels dimension
|
| 184 |
+
loss = torch.cat([(d * w.to(d)).mean(dim=[2, 3]) for d, w in zip(distances, self.weights)], dim=1).sum(dim=1)
|
| 185 |
+
|
| 186 |
+
return _reduce(loss, self.reduction)
|
| 187 |
+
|
| 188 |
+
def compute_distance(self, x_features: List[torch.Tensor], y_features: List[torch.Tensor]) -> List[torch.Tensor]:
|
| 189 |
+
r"""Take L2 or L1 distance between feature maps depending on ``distance``.
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
x_features: Features of the input tensor.
|
| 193 |
+
y_features: Features of the target tensor.
|
| 194 |
+
|
| 195 |
+
Returns:
|
| 196 |
+
Distance between feature maps
|
| 197 |
+
"""
|
| 198 |
+
return [self.distance(x, y) for x, y in zip(x_features, y_features)]
|
| 199 |
+
|
| 200 |
+
def get_features(self, x: torch.Tensor) -> List[torch.Tensor]:
|
| 201 |
+
r"""
|
| 202 |
+
Args:
|
| 203 |
+
x: Tensor. Shape :math:`(N, C, H, W)`.
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
List of features extracted from intermediate layers
|
| 207 |
+
"""
|
| 208 |
+
# Normalize input
|
| 209 |
+
x = (x - self.mean.to(x)) / self.std.to(x)
|
| 210 |
+
|
| 211 |
+
features = []
|
| 212 |
+
for name, module in self.model._modules.items():
|
| 213 |
+
x = module(x)
|
| 214 |
+
if name in self.layers:
|
| 215 |
+
features.append(self.normalize(x) if self.normalize_features else x)
|
| 216 |
+
|
| 217 |
+
return features
|
| 218 |
+
|
| 219 |
+
@staticmethod
|
| 220 |
+
def normalize(x: torch.Tensor) -> torch.Tensor:
|
| 221 |
+
r"""Normalize feature maps in channel direction to unit length.
|
| 222 |
+
|
| 223 |
+
Args:
|
| 224 |
+
x: Tensor. Shape :math:`(N, C, H, W)`.
|
| 225 |
+
|
| 226 |
+
Returns:
|
| 227 |
+
Normalized input
|
| 228 |
+
"""
|
| 229 |
+
norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
|
| 230 |
+
return x / (norm_factor + EPS)
|
| 231 |
+
|
| 232 |
+
def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module:
|
| 233 |
+
r"""Turn All MaxPool layers into AveragePool
|
| 234 |
+
|
| 235 |
+
Args:
|
| 236 |
+
module: Module in which to change MaxPool into AveragePool
|
| 237 |
+
|
| 238 |
+
Returns:
|
| 239 |
+
Module with AveragePool instead of MaxPool
|
| 240 |
+
|
| 241 |
+
"""
|
| 242 |
+
module_output = module
|
| 243 |
+
if isinstance(module, torch.nn.MaxPool2d):
|
| 244 |
+
module_output = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
| 245 |
+
|
| 246 |
+
for name, child in module.named_children():
|
| 247 |
+
module_output.add_module(name, self.replace_pooling(child))
|
| 248 |
+
return module_output
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class StyleLoss(ContentLoss):
|
| 252 |
+
r"""Creates Style loss that can be used for image style transfer or as a measure in
|
| 253 |
+
image to image tasks. Computes distance between Gram matrices of feature maps.
|
| 254 |
+
Uses pretrained VGG models from torchvision.
|
| 255 |
+
|
| 256 |
+
By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1].
|
| 257 |
+
If no normalisation is required, change `mean` and `std` values accordingly.
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``.
|
| 261 |
+
layers: List of strings with layer names. Default: ``'relu3_3'``
|
| 262 |
+
weights: List of float weight to balance different layers
|
| 263 |
+
replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details.
|
| 264 |
+
distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
|
| 265 |
+
reduction: Specifies the reduction type:
|
| 266 |
+
``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
|
| 267 |
+
mean: List of float values used for data standardization. Default: ImageNet mean.
|
| 268 |
+
If there is no need to normalize data, use [0., 0., 0.].
|
| 269 |
+
std: List of float values used for data standardization. Default: ImageNet std.
|
| 270 |
+
If there is no need to normalize data, use [1., 1., 1.].
|
| 271 |
+
normalize_features: If true, unit-normalize each feature in channel dimension before scaling
|
| 272 |
+
and computing distance. See references for details.
|
| 273 |
+
|
| 274 |
+
Examples:
|
| 275 |
+
>>> loss = StyleLoss()
|
| 276 |
+
>>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
|
| 277 |
+
>>> y = torch.rand(3, 3, 256, 256)
|
| 278 |
+
>>> output = loss(x, y)
|
| 279 |
+
>>> output.backward()
|
| 280 |
+
|
| 281 |
+
References:
|
| 282 |
+
Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
|
| 283 |
+
A Neural Algorithm of Artistic Style
|
| 284 |
+
Association for Research in Vision and Ophthalmology (ARVO)
|
| 285 |
+
https://arxiv.org/abs/1508.06576
|
| 286 |
+
|
| 287 |
+
Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
|
| 288 |
+
The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
|
| 289 |
+
IEEE/CVF Conference on Computer Vision and Pattern Recognition
|
| 290 |
+
https://arxiv.org/abs/1801.03924
|
| 291 |
+
"""
|
| 292 |
+
|
| 293 |
+
def compute_distance(self, x_features: torch.Tensor, y_features: torch.Tensor):
|
| 294 |
+
r"""Take L2 or L1 distance between Gram matrices of feature maps depending on ``distance``.
|
| 295 |
+
|
| 296 |
+
Args:
|
| 297 |
+
x_features: Features of the input tensor.
|
| 298 |
+
y_features: Features of the target tensor.
|
| 299 |
+
|
| 300 |
+
Returns:
|
| 301 |
+
Distance between Gram matrices
|
| 302 |
+
"""
|
| 303 |
+
x_gram = [self.gram_matrix(x) for x in x_features]
|
| 304 |
+
y_gram = [self.gram_matrix(x) for x in y_features]
|
| 305 |
+
return [self.distance(x, y) for x, y in zip(x_gram, y_gram)]
|
| 306 |
+
|
| 307 |
+
@staticmethod
|
| 308 |
+
def gram_matrix(x: torch.Tensor) -> torch.Tensor:
|
| 309 |
+
r"""Compute Gram matrix for batch of features.
|
| 310 |
+
|
| 311 |
+
Args:
|
| 312 |
+
x: Tensor. Shape :math:`(N, C, H, W)`.
|
| 313 |
+
|
| 314 |
+
Returns:
|
| 315 |
+
Gram matrix for given input
|
| 316 |
+
"""
|
| 317 |
+
B, C, H, W = x.size()
|
| 318 |
+
gram = []
|
| 319 |
+
for i in range(B):
|
| 320 |
+
features = x[i].view(C, H * W)
|
| 321 |
+
|
| 322 |
+
# Add fake channel dimension
|
| 323 |
+
gram.append(torch.mm(features, features.t()).unsqueeze(0))
|
| 324 |
+
|
| 325 |
+
return torch.stack(gram)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
class LPIPS(ContentLoss):
|
| 329 |
+
r"""Learned Perceptual Image Patch Similarity metric. Only VGG16 learned weights are supported.
|
| 330 |
+
|
| 331 |
+
By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1].
|
| 332 |
+
If no normalisation is required, change `mean` and `std` values accordingly.
|
| 333 |
+
|
| 334 |
+
Args:
|
| 335 |
+
replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details.
|
| 336 |
+
distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
|
| 337 |
+
reduction: Specifies the reduction type:
|
| 338 |
+
``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
|
| 339 |
+
mean: List of float values used for data standardization. Default: ImageNet mean.
|
| 340 |
+
If there is no need to normalize data, use [0., 0., 0.].
|
| 341 |
+
std: List of float values used for data standardization. Default: ImageNet std.
|
| 342 |
+
If there is no need to normalize data, use [1., 1., 1.].
|
| 343 |
+
|
| 344 |
+
Examples:
|
| 345 |
+
>>> loss = LPIPS()
|
| 346 |
+
>>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
|
| 347 |
+
>>> y = torch.rand(3, 3, 256, 256)
|
| 348 |
+
>>> output = loss(x, y)
|
| 349 |
+
>>> output.backward()
|
| 350 |
+
|
| 351 |
+
References:
|
| 352 |
+
Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
|
| 353 |
+
A Neural Algorithm of Artistic Style
|
| 354 |
+
Association for Research in Vision and Ophthalmology (ARVO)
|
| 355 |
+
https://arxiv.org/abs/1508.06576
|
| 356 |
+
|
| 357 |
+
Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
|
| 358 |
+
The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
|
| 359 |
+
IEEE/CVF Conference on Computer Vision and Pattern Recognition
|
| 360 |
+
https://arxiv.org/abs/1801.03924
|
| 361 |
+
https://github.com/richzhang/PerceptualSimilarity
|
| 362 |
+
"""
|
| 363 |
+
_weights_url = "https://github.com/photosynthesis-team/" + \
|
| 364 |
+
"photosynthesis.metrics/releases/download/v0.4.0/lpips_weights.pt"
|
| 365 |
+
|
| 366 |
+
def __init__(self, replace_pooling: bool = False, distance: str = "mse", reduction: str = "mean",
|
| 367 |
+
mean: List[float] = IMAGENET_MEAN, std: List[float] = IMAGENET_STD, ) -> None:
|
| 368 |
+
lpips_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
|
| 369 |
+
lpips_weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False)
|
| 370 |
+
super().__init__("vgg16", layers=lpips_layers, weights=lpips_weights,
|
| 371 |
+
replace_pooling=replace_pooling, distance=distance,
|
| 372 |
+
reduction=reduction, mean=mean, std=std,
|
| 373 |
+
normalize_features=True)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class DISTS(ContentLoss):
|
| 377 |
+
r"""Deep Image Structure and Texture Similarity metric.
|
| 378 |
+
|
| 379 |
+
By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1].
|
| 380 |
+
If no normalisation is required, change `mean` and `std` values accordingly.
|
| 381 |
+
|
| 382 |
+
Args:
|
| 383 |
+
reduction: Specifies the reduction type:
|
| 384 |
+
``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
|
| 385 |
+
mean: List of float values used for data standardization. Default: ImageNet mean.
|
| 386 |
+
If there is no need to normalize data, use [0., 0., 0.].
|
| 387 |
+
std: List of float values used for data standardization. Default: ImageNet std.
|
| 388 |
+
If there is no need to normalize data, use [1., 1., 1.].
|
| 389 |
+
|
| 390 |
+
Examples:
|
| 391 |
+
>>> loss = DISTS()
|
| 392 |
+
>>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
|
| 393 |
+
>>> y = torch.rand(3, 3, 256, 256)
|
| 394 |
+
>>> output = loss(x, y)
|
| 395 |
+
>>> output.backward()
|
| 396 |
+
|
| 397 |
+
References:
|
| 398 |
+
Keyan Ding, Kede Ma, Shiqi Wang, Eero P. Simoncelli (2020).
|
| 399 |
+
Image Quality Assessment: Unifying Structure and Texture Similarity.
|
| 400 |
+
https://arxiv.org/abs/2004.07728
|
| 401 |
+
https://github.com/dingkeyan93/DISTS
|
| 402 |
+
"""
|
| 403 |
+
_weights_url = "https://github.com/photosynthesis-team/piq/releases/download/v0.4.1/dists_weights.pt"
|
| 404 |
+
|
| 405 |
+
def __init__(self, reduction: str = "mean", mean: List[float] = IMAGENET_MEAN,
|
| 406 |
+
std: List[float] = IMAGENET_STD) -> None:
|
| 407 |
+
dists_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
|
| 408 |
+
channels = [3, 64, 128, 256, 512, 512]
|
| 409 |
+
|
| 410 |
+
weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False)
|
| 411 |
+
dists_weights = list(torch.split(weights['alpha'], channels, dim=1))
|
| 412 |
+
dists_weights.extend(torch.split(weights['beta'], channels, dim=1))
|
| 413 |
+
|
| 414 |
+
super().__init__("vgg16", layers=dists_layers, weights=dists_weights,
|
| 415 |
+
replace_pooling=True, reduction=reduction, mean=mean, std=std,
|
| 416 |
+
normalize_features=False, allow_layers_weights_mismatch=True)
|
| 417 |
+
|
| 418 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
| 419 |
+
r"""
|
| 420 |
+
|
| 421 |
+
Args:
|
| 422 |
+
x: An input tensor. Shape :math:`(N, C, H, W)`.
|
| 423 |
+
y: A target tensor. Shape :math:`(N, C, H, W)`.
|
| 424 |
+
|
| 425 |
+
Returns:
|
| 426 |
+
Deep Image Structure and Texture Similarity loss, i.e. ``1-DISTS`` in range [0, 1].
|
| 427 |
+
"""
|
| 428 |
+
_, _, H, W = x.shape
|
| 429 |
+
|
| 430 |
+
if min(H, W) > 256:
|
| 431 |
+
x = torch.nn.functional.interpolate(
|
| 432 |
+
x, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear')
|
| 433 |
+
y = torch.nn.functional.interpolate(
|
| 434 |
+
y, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear')
|
| 435 |
+
|
| 436 |
+
loss = super().forward(x, y)
|
| 437 |
+
return 1 - loss
|
| 438 |
+
|
| 439 |
+
def compute_distance(self, x_features: torch.Tensor, y_features: torch.Tensor) -> List[torch.Tensor]:
|
| 440 |
+
r"""Compute structure similarity between feature maps
|
| 441 |
+
|
| 442 |
+
Args:
|
| 443 |
+
x_features: Features of the input tensor.
|
| 444 |
+
y_features: Features of the target tensor.
|
| 445 |
+
|
| 446 |
+
Returns:
|
| 447 |
+
Structural similarity distance between feature maps
|
| 448 |
+
"""
|
| 449 |
+
structure_distance, texture_distance = [], []
|
| 450 |
+
# Small constant for numerical stability
|
| 451 |
+
EPS = 1e-6
|
| 452 |
+
|
| 453 |
+
for x, y in zip(x_features, y_features):
|
| 454 |
+
x_mean = x.mean([2, 3], keepdim=True)
|
| 455 |
+
y_mean = y.mean([2, 3], keepdim=True)
|
| 456 |
+
structure_distance.append(similarity_map(x_mean, y_mean, constant=EPS))
|
| 457 |
+
|
| 458 |
+
x_var = ((x - x_mean) ** 2).mean([2, 3], keepdim=True)
|
| 459 |
+
y_var = ((y - y_mean) ** 2).mean([2, 3], keepdim=True)
|
| 460 |
+
xy_cov = (x * y).mean([2, 3], keepdim=True) - x_mean * y_mean
|
| 461 |
+
texture_distance.append((2 * xy_cov + EPS) / (x_var + y_var + EPS))
|
| 462 |
+
|
| 463 |
+
return structure_distance + texture_distance
|
| 464 |
+
|
| 465 |
+
def get_features(self, x: torch.Tensor) -> List[torch.Tensor]:
|
| 466 |
+
r"""
|
| 467 |
+
|
| 468 |
+
Args:
|
| 469 |
+
x: Input tensor
|
| 470 |
+
|
| 471 |
+
Returns:
|
| 472 |
+
List of features extracted from input tensor
|
| 473 |
+
"""
|
| 474 |
+
features = super().get_features(x)
|
| 475 |
+
|
| 476 |
+
# Add input tensor as an additional feature
|
| 477 |
+
features.insert(0, x)
|
| 478 |
+
return features
|
| 479 |
+
|
| 480 |
+
def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module:
|
| 481 |
+
r"""Turn All MaxPool layers into L2Pool
|
| 482 |
+
|
| 483 |
+
Args:
|
| 484 |
+
module: Module to change MaxPool into L2Pool
|
| 485 |
+
|
| 486 |
+
Returns:
|
| 487 |
+
Module with L2Pool instead of MaxPool
|
| 488 |
+
"""
|
| 489 |
+
module_output = module
|
| 490 |
+
if isinstance(module, torch.nn.MaxPool2d):
|
| 491 |
+
module_output = L2Pool2d(kernel_size=3, stride=2, padding=1)
|
| 492 |
+
|
| 493 |
+
for name, child in module.named_children():
|
| 494 |
+
module_output.add_module(name, self.replace_pooling(child))
|
| 495 |
+
|
| 496 |
+
return module_output
|
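Usage sketch (not from the uploaded files): the four losses share the `ContentLoss` machinery; `LPIPS` and `DISTS` fetch their linear weights from GitHub on first use, so a network connection (and the torchvision VGG16 weights) is assumed. Inputs are expected in [0, 1]:

```python
import torch
from libs.metric.piq.perceptual import ContentLoss, LPIPS, DISTS

x = torch.rand(2, 3, 256, 256, requires_grad=True)  # prediction
y = torch.rand(2, 3, 256, 256)                      # target

content = ContentLoss(feature_extractor="vgg16", layers=("relu3_3",))(x, y)
lpips = LPIPS()(x, y)   # downloads lpips_weights.pt on first call
dists = DISTS()(x, y)   # downloads dists_weights.pt on first call
content.backward()      # all three are differentiable w.r.t. x
print(lpips.item(), dists.item())
```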
libs/metric/piq/utils/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
| 1 |
+
from .common import _validate_input, _reduce, _parse_version
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
"_validate_input",
|
| 5 |
+
"_reduce",
|
| 6 |
+
'_parse_version'
|
| 7 |
+
]
|
libs/metric/piq/utils/common.py
ADDED
|
@@ -0,0 +1,158 @@
| 1 |
+
import torch
|
| 2 |
+
import re
|
| 3 |
+
import warnings
|
| 4 |
+
|
| 5 |
+
from typing import Tuple, List, Optional, Union, Dict, Any
|
| 6 |
+
|
| 7 |
+
SEMVER_VERSION_PATTERN = re.compile(
|
| 8 |
+
r"""
|
| 9 |
+
^
|
| 10 |
+
(?P<major>0|[1-9]\d*)
|
| 11 |
+
\.
|
| 12 |
+
(?P<minor>0|[1-9]\d*)
|
| 13 |
+
\.
|
| 14 |
+
(?P<patch>0|[1-9]\d*)
|
| 15 |
+
(?:-(?P<prerelease>
|
| 16 |
+
(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)
|
| 17 |
+
(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*
|
| 18 |
+
))?
|
| 19 |
+
(?:\+(?P<build>
|
| 20 |
+
[0-9a-zA-Z-]+
|
| 21 |
+
(?:\.[0-9a-zA-Z-]+)*
|
| 22 |
+
))?
|
| 23 |
+
$
|
| 24 |
+
""",
|
| 25 |
+
re.VERBOSE,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
PEP_440_VERSION_PATTERN = r"""
|
| 30 |
+
v?
|
| 31 |
+
(?:
|
| 32 |
+
(?:(?P<epoch>[0-9]+)!)? # epoch
|
| 33 |
+
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
|
| 34 |
+
(?P<pre> # pre-release
|
| 35 |
+
[-_\.]?
|
| 36 |
+
(?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
|
| 37 |
+
[-_\.]?
|
| 38 |
+
(?P<pre_n>[0-9]+)?
|
| 39 |
+
)?
|
| 40 |
+
(?P<post> # post release
|
| 41 |
+
(?:-(?P<post_n1>[0-9]+))
|
| 42 |
+
|
|
| 43 |
+
(?:
|
| 44 |
+
[-_\.]?
|
| 45 |
+
(?P<post_l>post|rev|r)
|
| 46 |
+
[-_\.]?
|
| 47 |
+
(?P<post_n2>[0-9]+)?
|
| 48 |
+
)
|
| 49 |
+
)?
|
| 50 |
+
(?P<dev> # dev release
|
| 51 |
+
[-_\.]?
|
| 52 |
+
(?P<dev_l>dev)
|
| 53 |
+
[-_\.]?
|
| 54 |
+
(?P<dev_n>[0-9]+)?
|
| 55 |
+
)?
|
| 56 |
+
)
|
| 57 |
+
(?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _validate_input(
|
| 62 |
+
tensors: List[torch.Tensor],
|
| 63 |
+
dim_range: Tuple[int, int] = (0, -1),
|
| 64 |
+
data_range: Tuple[float, float] = (0., -1.),
|
| 65 |
+
# size_dim_range: Tuple[float, float] = (0., -1.),
|
| 66 |
+
size_range: Optional[Tuple[int, int]] = None,
|
| 67 |
+
) -> None:
|
| 68 |
+
r"""Check that input(-s) satisfies the requirements
|
| 69 |
+
Args:
|
| 70 |
+
tensors: Tensors to check
|
| 71 |
+
dim_range: Allowed number of dimensions. (min, max)
|
| 72 |
+
data_range: Allowed range of values in tensors. (min, max)
|
| 73 |
+
size_range: Dimensions to include in size comparison. (start_dim, end_dim + 1)
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
if not __debug__:
|
| 77 |
+
return
|
| 78 |
+
|
| 79 |
+
x = tensors[0]
|
| 80 |
+
|
| 81 |
+
for t in tensors:
|
| 82 |
+
assert torch.is_tensor(t), f'Expected torch.Tensor, got {type(t)}'
|
| 83 |
+
assert t.device == x.device, f'Expected tensors to be on {x.device}, got {t.device}'
|
| 84 |
+
|
| 85 |
+
if size_range is None:
|
| 86 |
+
assert t.size() == x.size(), f'Expected tensors with same size, got {t.size()} and {x.size()}'
|
| 87 |
+
else:
|
| 88 |
+
assert t.size()[size_range[0]: size_range[1]] == x.size()[size_range[0]: size_range[1]], \
|
| 89 |
+
f'Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}'
|
| 90 |
+
|
| 91 |
+
if dim_range[0] == dim_range[1]:
|
| 92 |
+
assert t.dim() == dim_range[0], f'Expected number of dimensions to be {dim_range[0]}, got {t.dim()}'
|
| 93 |
+
elif dim_range[0] < dim_range[1]:
|
| 94 |
+
assert dim_range[0] <= t.dim() <= dim_range[1], \
|
| 95 |
+
f'Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}'
|
| 96 |
+
|
| 97 |
+
if data_range[0] < data_range[1]:
|
| 98 |
+
assert data_range[0] <= t.min(), \
|
| 99 |
+
f'Expected values to be greater or equal to {data_range[0]}, got {t.min()}'
|
| 100 |
+
assert t.max() <= data_range[1], \
|
| 101 |
+
f'Expected values to be lower or equal to {data_range[1]}, got {t.max()}'
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
|
| 105 |
+
r"""Reduce input in batch dimension if needed.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
x: Tensor with shape (N, *).
|
| 109 |
+
reduction: Specifies the reduction type:
|
| 110 |
+
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
|
| 111 |
+
"""
|
| 112 |
+
if reduction == 'none':
|
| 113 |
+
return x
|
| 114 |
+
elif reduction == 'mean':
|
| 115 |
+
return x.mean(dim=0)
|
| 116 |
+
elif reduction == 'sum':
|
| 117 |
+
return x.sum(dim=0)
|
| 118 |
+
else:
|
| 119 |
+
raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _parse_version(version: Union[str, bytes]) -> Tuple[int, ...]:
|
| 123 |
+
""" Parses valid Python versions according to Semver and PEP 440 specifications.
|
| 124 |
+
For more on Semver check: https://semver.org/
|
| 125 |
+
For more on PEP 440 check: https://www.python.org/dev/peps/pep-0440/.
|
| 126 |
+
|
| 127 |
+
Implementation is inspired by:
|
| 128 |
+
- https://github.com/python-semver
|
| 129 |
+
- https://github.com/pypa/packaging
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
version: unparsed information about the library of interest.
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
parsed information about the library of interest.
|
| 136 |
+
"""
|
| 137 |
+
if isinstance(version, bytes):
|
| 138 |
+
version = version.decode("UTF-8")
|
| 139 |
+
elif not isinstance(version, str) and not isinstance(version, bytes):
|
| 140 |
+
raise TypeError(f"not expecting type {type(version)}")
|
| 141 |
+
|
| 142 |
+
# Semver processing
|
| 143 |
+
match = SEMVER_VERSION_PATTERN.match(version)
|
| 144 |
+
if match:
|
| 145 |
+
matched_version_parts: Dict[str, Any] = match.groupdict()
|
| 146 |
+
release = tuple([int(matched_version_parts[k]) for k in ['major', 'minor', 'patch']])
|
| 147 |
+
return release
|
| 148 |
+
|
| 149 |
+
# PEP 440 processing
|
| 150 |
+
regex = re.compile(r"^\s*" + PEP_440_VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
|
| 151 |
+
match = regex.search(version)
|
| 152 |
+
|
| 153 |
+
if match is None:
|
| 154 |
+
warnings.warn(f"{version} is not a valid SemVer or PEP 440 string")
|
| 155 |
+
return tuple()
|
| 156 |
+
|
| 157 |
+
release = tuple(int(i) for i in match.group("release").split("."))
|
| 158 |
+
return release
|
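Usage sketch (not from the uploaded files): `_parse_version` accepts both SemVer and PEP 440 strings and returns only the release tuple, which makes it a small helper for comparing installed library versions. A quick sketch, assuming the repository root is importable:

```python
from libs.metric.piq.utils.common import _parse_version

print(_parse_version("0.15.2"))       # (0, 15, 2)
print(_parse_version("2.0.1+cu118"))  # (2, 0, 1) -- build metadata is ignored
```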
libs/metric/pytorch_fid/__init__.py
ADDED
|
@@ -0,0 +1,54 @@
| 1 |
+
__version__ = '0.3.0'
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from einops import rearrange, repeat
|
| 5 |
+
|
| 6 |
+
from .inception import InceptionV3
|
| 7 |
+
from .fid_score import calculate_frechet_distance
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class PytorchFIDFactory(torch.nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
channels:
|
| 15 |
+
inception_block_idx:
|
| 16 |
+
|
| 17 |
+
Examples:
|
| 18 |
+
>>> fid_factory = PytorchFIDFactory()
|
| 19 |
+
>>> fid_score = fid_factory.score(real_samples=data, fake_samples=all_images)
|
| 20 |
+
>>> print(fid_score)
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, channels: int = 3, inception_block_idx: int = 2048):
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.channels = channels
|
| 26 |
+
|
| 27 |
+
# load models
|
| 28 |
+
assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM
|
| 29 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
|
| 30 |
+
self.inception_v3 = InceptionV3([block_idx])
|
| 31 |
+
|
| 32 |
+
@torch.no_grad()
|
| 33 |
+
def calculate_activation_statistics(self, samples):
|
| 34 |
+
features = self.inception_v3(samples)[0]
|
| 35 |
+
features = rearrange(features, '... 1 1 -> ...')
|
| 36 |
+
|
| 37 |
+
mu = torch.mean(features, dim=0).cpu()
|
| 38 |
+
sigma = torch.cov(features).cpu()
|
| 39 |
+
return mu, sigma
|
| 40 |
+
|
| 41 |
+
def score(self, real_samples, fake_samples):
|
| 42 |
+
if self.channels == 1:
|
| 43 |
+
real_samples, fake_samples = map(
|
| 44 |
+
lambda t: repeat(t, 'b 1 ... -> b c ...', c=3), (real_samples, fake_samples)
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
min_batch = min(real_samples.shape[0], fake_samples.shape[0])
|
| 48 |
+
real_samples, fake_samples = map(lambda t: t[:min_batch], (real_samples, fake_samples))
|
| 49 |
+
|
| 50 |
+
m1, s1 = self.calculate_activation_statistics(real_samples)
|
| 51 |
+
m2, s2 = self.calculate_activation_statistics(fake_samples)
|
| 52 |
+
|
| 53 |
+
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
|
| 54 |
+
return fid_value
|
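Usage sketch (not part of the uploaded file): scoring two image batches with `PytorchFIDFactory`. The random tensors only illustrate the expected shapes and call signature; meaningful FID values need thousands of samples.

```python
import torch
from libs.metric.pytorch_fid import PytorchFIDFactory  # run from the repository root

# Random stand-ins for real/generated batches, (B, 3, H, W) with values in [0, 1].
real = torch.rand(8, 3, 256, 256)
fake = torch.rand(8, 3, 256, 256)

fid_factory = PytorchFIDFactory(channels=3, inception_block_idx=2048)
# Tiny batches like these may hit the singular-covariance code path inside
# calculate_frechet_distance; they are only here to show the API.
fid = fid_factory.score(real_samples=real, fake_samples=fake)
print(fid)
```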
libs/metric/pytorch_fid/fid_score.py
ADDED
|
@@ -0,0 +1,322 @@
| 1 |
+
"""Calculates the Frechet Inception Distance (FID) to evalulate GANs
|
| 2 |
+
|
| 3 |
+
The FID metric calculates the distance between two distributions of images.
|
| 4 |
+
Typically, we have summary statistics (mean & covariance matrix) of one
|
| 5 |
+
of these distributions, while the 2nd distribution is given by a GAN.
|
| 6 |
+
|
| 7 |
+
When run as a stand-alone program, it compares the distribution of
|
| 8 |
+
images that are stored as PNG/JPEG at a specified location with a
|
| 9 |
+
distribution given by summary statistics (in pickle format).
|
| 10 |
+
|
| 11 |
+
The FID is calculated by assuming that X_1 and X_2 are the activations of
|
| 12 |
+
the pool_3 layer of the inception net for generated samples and real world
|
| 13 |
+
samples respectively.
|
| 14 |
+
|
| 15 |
+
See --help to see further details.
|
| 16 |
+
|
| 17 |
+
Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
|
| 18 |
+
of Tensorflow
|
| 19 |
+
|
| 20 |
+
Copyright 2018 Institute of Bioinformatics, JKU Linz
|
| 21 |
+
|
| 22 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 23 |
+
you may not use this file except in compliance with the License.
|
| 24 |
+
You may obtain a copy of the License at
|
| 25 |
+
|
| 26 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 27 |
+
|
| 28 |
+
Unless required by applicable law or agreed to in writing, software
|
| 29 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 30 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 31 |
+
See the License for the specific language governing permissions and
|
| 32 |
+
limitations under the License.
|
| 33 |
+
"""
|
| 34 |
+
import os
|
| 35 |
+
import pathlib
|
| 36 |
+
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
| 37 |
+
|
| 38 |
+
import numpy as np
|
| 39 |
+
import torch
|
| 40 |
+
import torchvision.transforms as TF
|
| 41 |
+
from PIL import Image
|
| 42 |
+
from scipy import linalg
|
| 43 |
+
from torch.nn.functional import adaptive_avg_pool2d
|
| 44 |
+
|
| 45 |
+
try:
|
| 46 |
+
from tqdm import tqdm
|
| 47 |
+
except ImportError:
|
| 48 |
+
# If tqdm is not available, provide a mock version of it
|
| 49 |
+
def tqdm(x):
|
| 50 |
+
return x
|
| 51 |
+
|
| 52 |
+
from .inception import InceptionV3
|
| 53 |
+
|
| 54 |
+
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
| 55 |
+
parser.add_argument('--batch-size', type=int, default=50,
|
| 56 |
+
help='Batch size to use')
|
| 57 |
+
parser.add_argument('--num-workers', type=int,
|
| 58 |
+
help=('Number of processes to use for data loading. '
|
| 59 |
+
'Defaults to `min(8, num_cpus)`'))
|
| 60 |
+
parser.add_argument('--device', type=str, default=None,
|
| 61 |
+
help='Device to use. Like cuda, cuda:0 or cpu')
|
| 62 |
+
parser.add_argument('--dims', type=int, default=2048,
|
| 63 |
+
choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
|
| 64 |
+
help=('Dimensionality of Inception features to use. '
|
| 65 |
+
'By default, uses pool3 features'))
|
| 66 |
+
parser.add_argument('--save-stats', action='store_true',
|
| 67 |
+
help=('Generate an npz archive from a directory of samples. '
|
| 68 |
+
'The first path is used as input and the second as output.'))
|
| 69 |
+
parser.add_argument('path', type=str, nargs=2,
|
| 70 |
+
help=('Paths to the generated images or '
|
| 71 |
+
'to .npz statistic files'))
|
| 72 |
+
|
| 73 |
+
IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
|
| 74 |
+
'tif', 'tiff', 'webp'}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class ImagePathDataset(torch.utils.data.Dataset):
|
| 78 |
+
def __init__(self, files, transforms=None):
|
| 79 |
+
self.files = files
|
| 80 |
+
self.transforms = transforms
|
| 81 |
+
|
| 82 |
+
def __len__(self):
|
| 83 |
+
return len(self.files)
|
| 84 |
+
|
| 85 |
+
def __getitem__(self, i):
|
| 86 |
+
path = self.files[i]
|
| 87 |
+
img = Image.open(path).convert('RGB')
|
| 88 |
+
if self.transforms is not None:
|
| 89 |
+
img = self.transforms(img)
|
| 90 |
+
return img
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
|
| 94 |
+
num_workers=1):
|
| 95 |
+
"""Calculates the activations of the pool_3 layer for all images.
|
| 96 |
+
|
| 97 |
+
Params:
|
| 98 |
+
-- files : List of image files paths
|
| 99 |
+
-- model : Instance of inception model
|
| 100 |
+
-- batch_size : Batch size of images for the model to process at once.
|
| 101 |
+
Make sure that the number of samples is a multiple of
|
| 102 |
+
the batch size, otherwise some samples are ignored. This
|
| 103 |
+
behavior is retained to match the original FID score
|
| 104 |
+
implementation.
|
| 105 |
+
-- dims : Dimensionality of features returned by Inception
|
| 106 |
+
-- device : Device to run calculations
|
| 107 |
+
-- num_workers : Number of parallel dataloader workers
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
-- A numpy array of dimension (num images, dims) that contains the
|
| 111 |
+
activations of the given tensor when feeding inception with the
|
| 112 |
+
query tensor.
|
| 113 |
+
"""
|
| 114 |
+
model.eval()
|
| 115 |
+
|
| 116 |
+
if batch_size > len(files):
|
| 117 |
+
print(('Warning: batch size is bigger than the data size. '
|
| 118 |
+
'Setting batch size to data size'))
|
| 119 |
+
batch_size = len(files)
|
| 120 |
+
|
| 121 |
+
dataset = ImagePathDataset(files, transforms=TF.ToTensor())
|
| 122 |
+
dataloader = torch.utils.data.DataLoader(dataset,
|
| 123 |
+
batch_size=batch_size,
|
| 124 |
+
shuffle=False,
|
| 125 |
+
drop_last=False,
|
| 126 |
+
num_workers=num_workers)
|
| 127 |
+
|
| 128 |
+
pred_arr = np.empty((len(files), dims))
|
| 129 |
+
|
| 130 |
+
start_idx = 0
|
| 131 |
+
|
| 132 |
+
for batch in tqdm(dataloader):
|
| 133 |
+
batch = batch.to(device)
|
| 134 |
+
|
| 135 |
+
with torch.no_grad():
|
| 136 |
+
pred = model(batch)[0]
|
| 137 |
+
|
| 138 |
+
# If model output is not scalar, apply global spatial average pooling.
|
| 139 |
+
# This happens if you choose a dimensionality not equal 2048.
|
| 140 |
+
if pred.size(2) != 1 or pred.size(3) != 1:
|
| 141 |
+
pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
|
| 142 |
+
|
| 143 |
+
pred = pred.squeeze(3).squeeze(2).cpu().numpy()
|
| 144 |
+
|
| 145 |
+
pred_arr[start_idx:start_idx + pred.shape[0]] = pred
|
| 146 |
+
|
| 147 |
+
start_idx = start_idx + pred.shape[0]
|
| 148 |
+
|
| 149 |
+
return pred_arr
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
| 153 |
+
"""Numpy implementation of the Frechet Distance.
|
| 154 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
| 155 |
+
and X_2 ~ N(mu_2, C_2) is
|
| 156 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
| 157 |
+
|
| 158 |
+
Stable version by Dougal J. Sutherland.
|
| 159 |
+
|
| 160 |
+
Params:
|
| 161 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
| 162 |
+
inception net (like returned by the function 'get_predictions')
|
| 163 |
+
for generated samples.
|
| 164 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
| 165 |
+
representative data set.
|
| 166 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
| 167 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
| 168 |
+
representative data set.
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
-- : The Frechet Distance.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
mu1 = np.atleast_1d(mu1)
|
| 175 |
+
mu2 = np.atleast_1d(mu2)
|
| 176 |
+
|
| 177 |
+
sigma1 = np.atleast_2d(sigma1)
|
| 178 |
+
sigma2 = np.atleast_2d(sigma2)
|
| 179 |
+
|
| 180 |
+
assert mu1.shape == mu2.shape, \
|
| 181 |
+
'Training and test mean vectors have different lengths'
|
| 182 |
+
assert sigma1.shape == sigma2.shape, \
|
| 183 |
+
'Training and test covariances have different dimensions'
|
| 184 |
+
|
| 185 |
+
diff = mu1 - mu2
|
| 186 |
+
|
| 187 |
+
# Product might be almost singular
|
| 188 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
| 189 |
+
if not np.isfinite(covmean).all():
|
| 190 |
+
msg = ('fid calculation produces singular product; '
|
| 191 |
+
'adding %s to diagonal of cov estimates') % eps
|
| 192 |
+
print(msg)
|
| 193 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
| 194 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
| 195 |
+
|
| 196 |
+
# Numerical error might give slight imaginary component
|
| 197 |
+
if np.iscomplexobj(covmean):
|
| 198 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
| 199 |
+
m = np.max(np.abs(covmean.imag))
|
| 200 |
+
raise ValueError('Imaginary component {}'.format(m))
|
| 201 |
+
covmean = covmean.real
|
| 202 |
+
|
| 203 |
+
tr_covmean = np.trace(covmean)
|
| 204 |
+
|
| 205 |
+
return (diff.dot(diff) + np.trace(sigma1)
|
| 206 |
+
+ np.trace(sigma2) - 2 * tr_covmean)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
|
| 210 |
+
device='cpu', num_workers=1):
|
| 211 |
+
"""Calculation of the statistics used by the FID.
|
| 212 |
+
Params:
|
| 213 |
+
-- files : List of image files paths
|
| 214 |
+
-- model : Instance of inception model
|
| 215 |
+
-- batch_size : The images numpy array is split into batches with
|
| 216 |
+
batch size batch_size. A reasonable batch size
|
| 217 |
+
depends on the hardware.
|
| 218 |
+
-- dims : Dimensionality of features returned by Inception
|
| 219 |
+
-- device : Device to run calculations
|
| 220 |
+
-- num_workers : Number of parallel dataloader workers
|
| 221 |
+
|
| 222 |
+
Returns:
|
| 223 |
+
-- mu : The mean over samples of the activations of the pool_3 layer of
|
| 224 |
+
the inception model.
|
| 225 |
+
-- sigma : The covariance matrix of the activations of the pool_3 layer of
|
| 226 |
+
the inception model.
|
| 227 |
+
"""
|
| 228 |
+
act = get_activations(files, model, batch_size, dims, device, num_workers)
|
| 229 |
+
mu = np.mean(act, axis=0)
|
| 230 |
+
sigma = np.cov(act, rowvar=False)
|
| 231 |
+
return mu, sigma
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def compute_statistics_of_path(path, model, batch_size, dims, device,
|
| 235 |
+
num_workers=1):
|
| 236 |
+
if path.endswith('.npz'):
|
| 237 |
+
with np.load(path) as f:
|
| 238 |
+
m, s = f['mu'][:], f['sigma'][:]
|
| 239 |
+
else:
|
| 240 |
+
path = pathlib.Path(path)
|
| 241 |
+
files = sorted([file for ext in IMAGE_EXTENSIONS
|
| 242 |
+
for file in path.glob('*.{}'.format(ext))])
|
| 243 |
+
m, s = calculate_activation_statistics(files, model, batch_size,
|
| 244 |
+
dims, device, num_workers)
|
| 245 |
+
|
| 246 |
+
return m, s
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
|
| 250 |
+
"""Calculates the FID of two paths"""
|
| 251 |
+
for p in paths:
|
| 252 |
+
if not os.path.exists(p):
|
| 253 |
+
raise RuntimeError('Invalid path: %s' % p)
|
| 254 |
+
|
| 255 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
| 256 |
+
|
| 257 |
+
model = InceptionV3([block_idx]).to(device)
|
| 258 |
+
|
| 259 |
+
m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
|
| 260 |
+
dims, device, num_workers)
|
| 261 |
+
m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
|
| 262 |
+
dims, device, num_workers)
|
| 263 |
+
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
|
| 264 |
+
|
| 265 |
+
return fid_value
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def save_fid_stats(paths, batch_size, device, dims, num_workers=1):
|
| 269 |
+
"""Calculates the FID of two paths"""
|
| 270 |
+
if not os.path.exists(paths[0]):
|
| 271 |
+
raise RuntimeError('Invalid path: %s' % paths[0])
|
| 272 |
+
|
| 273 |
+
if os.path.exists(paths[1]):
|
| 274 |
+
raise RuntimeError('Existing output file: %s' % paths[1])
|
| 275 |
+
|
| 276 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
| 277 |
+
|
| 278 |
+
model = InceptionV3([block_idx]).to(device)
|
| 279 |
+
|
| 280 |
+
print(f"Saving statistics for {paths[0]}")
|
| 281 |
+
|
| 282 |
+
m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
|
| 283 |
+
dims, device, num_workers)
|
| 284 |
+
|
| 285 |
+
np.savez_compressed(paths[1], mu=m1, sigma=s1)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def main():
|
| 289 |
+
args = parser.parse_args()
|
| 290 |
+
|
| 291 |
+
if args.device is None:
|
| 292 |
+
device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
|
| 293 |
+
else:
|
| 294 |
+
device = torch.device(args.device)
|
| 295 |
+
|
| 296 |
+
if args.num_workers is None:
|
| 297 |
+
try:
|
| 298 |
+
num_cpus = len(os.sched_getaffinity(0))
|
| 299 |
+
except AttributeError:
|
| 300 |
+
# os.sched_getaffinity is not available under Windows, use
|
| 301 |
+
# os.cpu_count instead (which may not return the *available* number
|
| 302 |
+
# of CPUs).
|
| 303 |
+
num_cpus = os.cpu_count()
|
| 304 |
+
|
| 305 |
+
num_workers = min(num_cpus, 8) if num_cpus is not None else 0
|
| 306 |
+
else:
|
| 307 |
+
num_workers = args.num_workers
|
| 308 |
+
|
| 309 |
+
if args.save_stats:
|
| 310 |
+
save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers)
|
| 311 |
+
return
|
| 312 |
+
|
| 313 |
+
fid_value = calculate_fid_given_paths(args.path,
|
| 314 |
+
args.batch_size,
|
| 315 |
+
device,
|
| 316 |
+
args.dims,
|
| 317 |
+
num_workers)
|
| 318 |
+
print('FID: ', fid_value)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
if __name__ == '__main__':
|
| 322 |
+
main()
|
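Usage sketch (not part of the uploaded file): calling the path-based API above directly. The two directories are placeholders; the module can also be run as a script through its `main()` entry point.

```python
import torch
from libs.metric.pytorch_fid.fid_score import calculate_fid_given_paths

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fid = calculate_fid_given_paths(
    paths=['/path/to/real_images', '/path/to/generated_images'],  # image folders or .npz stats
    batch_size=50,
    device=device,
    dims=2048,       # pool3 features, the standard choice for FID
    num_workers=1,
)
print('FID:', fid)
```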
libs/metric/pytorch_fid/inception.py
ADDED
|
@@ -0,0 +1,341 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torchvision
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
from torchvision.models.utils import load_state_dict_from_url
|
| 8 |
+
except ImportError:
|
| 9 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
| 10 |
+
|
| 11 |
+
# Inception weights ported to Pytorch from
|
| 12 |
+
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
| 13 |
+
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class InceptionV3(nn.Module):
|
| 17 |
+
"""Pretrained InceptionV3 network returning feature maps"""
|
| 18 |
+
|
| 19 |
+
# Index of default block of inception to return,
|
| 20 |
+
# corresponds to output of final average pooling
|
| 21 |
+
DEFAULT_BLOCK_INDEX = 3
|
| 22 |
+
|
| 23 |
+
# Maps feature dimensionality to their output blocks indices
|
| 24 |
+
BLOCK_INDEX_BY_DIM = {
|
| 25 |
+
64: 0, # First max pooling features
|
| 26 |
+
192: 1, # Second max pooling featurs
|
| 27 |
+
768: 2, # Pre-aux classifier features
|
| 28 |
+
2048: 3 # Final average pooling features
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
def __init__(self,
|
| 32 |
+
output_blocks=(DEFAULT_BLOCK_INDEX,),
|
| 33 |
+
resize_input=True,
|
| 34 |
+
normalize_input=True,
|
| 35 |
+
requires_grad=False,
|
| 36 |
+
use_fid_inception=True):
|
| 37 |
+
"""Build pretrained InceptionV3
|
| 38 |
+
|
| 39 |
+
Parameters
|
| 40 |
+
----------
|
| 41 |
+
output_blocks : list of int
|
| 42 |
+
Indices of blocks to return features of. Possible values are:
|
| 43 |
+
- 0: corresponds to output of first max pooling
|
| 44 |
+
- 1: corresponds to output of second max pooling
|
| 45 |
+
- 2: corresponds to output which is fed to aux classifier
|
| 46 |
+
- 3: corresponds to output of final average pooling
|
| 47 |
+
resize_input : bool
|
| 48 |
+
If true, bilinearly resizes input to width and height 299 before
|
| 49 |
+
feeding input to model. As the network without fully connected
|
| 50 |
+
layers is fully convolutional, it should be able to handle inputs
|
| 51 |
+
of arbitrary size, so resizing might not be strictly needed
|
| 52 |
+
normalize_input : bool
|
| 53 |
+
If true, scales the input from range (0, 1) to the range the
|
| 54 |
+
pretrained Inception network expects, namely (-1, 1)
|
| 55 |
+
requires_grad : bool
|
| 56 |
+
If true, parameters of the model require gradients. Possibly useful
|
| 57 |
+
for finetuning the network
|
| 58 |
+
use_fid_inception : bool
|
| 59 |
+
If true, uses the pretrained Inception model used in Tensorflow's
|
| 60 |
+
FID implementation. If false, uses the pretrained Inception model
|
| 61 |
+
available in torchvision. The FID Inception model has different
|
| 62 |
+
weights and a slightly different structure from torchvision's
|
| 63 |
+
Inception model. If you want to compute FID scores, you are
|
| 64 |
+
strongly advised to set this parameter to true to get comparable
|
| 65 |
+
results.
|
| 66 |
+
"""
|
| 67 |
+
super(InceptionV3, self).__init__()
|
| 68 |
+
|
| 69 |
+
self.resize_input = resize_input
|
| 70 |
+
self.normalize_input = normalize_input
|
| 71 |
+
self.output_blocks = sorted(output_blocks)
|
| 72 |
+
self.last_needed_block = max(output_blocks)
|
| 73 |
+
|
| 74 |
+
assert self.last_needed_block <= 3, \
|
| 75 |
+
'Last possible output block index is 3'
|
| 76 |
+
|
| 77 |
+
self.blocks = nn.ModuleList()
|
| 78 |
+
|
| 79 |
+
if use_fid_inception:
|
| 80 |
+
inception = fid_inception_v3()
|
| 81 |
+
else:
|
| 82 |
+
inception = _inception_v3(weights='DEFAULT')
|
| 83 |
+
|
| 84 |
+
# Block 0: input to maxpool1
|
| 85 |
+
block0 = [
|
| 86 |
+
inception.Conv2d_1a_3x3,
|
| 87 |
+
inception.Conv2d_2a_3x3,
|
| 88 |
+
inception.Conv2d_2b_3x3,
|
| 89 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
| 90 |
+
]
|
| 91 |
+
self.blocks.append(nn.Sequential(*block0))
|
| 92 |
+
|
| 93 |
+
# Block 1: maxpool1 to maxpool2
|
| 94 |
+
if self.last_needed_block >= 1:
|
| 95 |
+
block1 = [
|
| 96 |
+
inception.Conv2d_3b_1x1,
|
| 97 |
+
inception.Conv2d_4a_3x3,
|
| 98 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
| 99 |
+
]
|
| 100 |
+
self.blocks.append(nn.Sequential(*block1))
|
| 101 |
+
|
| 102 |
+
# Block 2: maxpool2 to aux classifier
|
| 103 |
+
if self.last_needed_block >= 2:
|
| 104 |
+
block2 = [
|
| 105 |
+
inception.Mixed_5b,
|
| 106 |
+
inception.Mixed_5c,
|
| 107 |
+
inception.Mixed_5d,
|
| 108 |
+
inception.Mixed_6a,
|
| 109 |
+
inception.Mixed_6b,
|
| 110 |
+
inception.Mixed_6c,
|
| 111 |
+
inception.Mixed_6d,
|
| 112 |
+
inception.Mixed_6e,
|
| 113 |
+
]
|
| 114 |
+
self.blocks.append(nn.Sequential(*block2))
|
| 115 |
+
|
| 116 |
+
# Block 3: aux classifier to final avgpool
|
| 117 |
+
if self.last_needed_block >= 3:
|
| 118 |
+
block3 = [
|
| 119 |
+
inception.Mixed_7a,
|
| 120 |
+
inception.Mixed_7b,
|
| 121 |
+
inception.Mixed_7c,
|
| 122 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1))
|
| 123 |
+
]
|
| 124 |
+
self.blocks.append(nn.Sequential(*block3))
|
| 125 |
+
|
| 126 |
+
for param in self.parameters():
|
| 127 |
+
param.requires_grad = requires_grad
|
| 128 |
+
|
| 129 |
+
def forward(self, inp):
|
| 130 |
+
"""Get Inception feature maps
|
| 131 |
+
|
| 132 |
+
Parameters
|
| 133 |
+
----------
|
| 134 |
+
inp : torch.autograd.Variable
|
| 135 |
+
Input tensor of shape Bx3xHxW. Values are expected to be in
|
| 136 |
+
range (0, 1)
|
| 137 |
+
|
| 138 |
+
Returns
|
| 139 |
+
-------
|
| 140 |
+
List of torch.autograd.Variable, corresponding to the selected output
|
| 141 |
+
block, sorted ascending by index
|
| 142 |
+
"""
|
| 143 |
+
outp = []
|
| 144 |
+
x = inp
|
| 145 |
+
|
| 146 |
+
if self.resize_input:
|
| 147 |
+
x = F.interpolate(x,
|
| 148 |
+
size=(299, 299),
|
| 149 |
+
mode='bilinear',
|
| 150 |
+
align_corners=False)
|
| 151 |
+
|
| 152 |
+
if self.normalize_input:
|
| 153 |
+
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
|
| 154 |
+
|
| 155 |
+
for idx, block in enumerate(self.blocks):
|
| 156 |
+
x = block(x)
|
| 157 |
+
if idx in self.output_blocks:
|
| 158 |
+
outp.append(x)
|
| 159 |
+
|
| 160 |
+
if idx == self.last_needed_block:
|
| 161 |
+
break
|
| 162 |
+
|
| 163 |
+
return outp
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def _inception_v3(*args, **kwargs):
|
| 167 |
+
"""Wraps `torchvision.models.inception_v3`"""
|
| 168 |
+
try:
|
| 169 |
+
version = tuple(map(int, torchvision.__version__.split('.')[:2]))
|
| 170 |
+
except ValueError:
|
| 171 |
+
# Just a caution against weird version strings
|
| 172 |
+
version = (0,)
|
| 173 |
+
|
| 174 |
+
# Skips default weight inititialization if supported by torchvision
|
| 175 |
+
# version. See https://github.com/mseitzer/pytorch-fid/issues/28.
|
| 176 |
+
if version >= (0, 6):
|
| 177 |
+
kwargs['init_weights'] = False
|
| 178 |
+
|
| 179 |
+
# Backwards compatibility: `weights` argument was handled by `pretrained`
|
| 180 |
+
# argument prior to version 0.13.
|
| 181 |
+
if version < (0, 13) and 'weights' in kwargs:
|
| 182 |
+
if kwargs['weights'] == 'DEFAULT':
|
| 183 |
+
kwargs['pretrained'] = True
|
| 184 |
+
elif kwargs['weights'] is None:
|
| 185 |
+
kwargs['pretrained'] = False
|
| 186 |
+
else:
|
| 187 |
+
raise ValueError(
|
| 188 |
+
'weights=={} not supported in torchvision {}'.format(
|
| 189 |
+
kwargs['weights'], torchvision.__version__
|
| 190 |
+
)
|
| 191 |
+
)
|
| 192 |
+
del kwargs['weights']
|
| 193 |
+
|
| 194 |
+
return torchvision.models.inception_v3(*args, **kwargs)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def fid_inception_v3():
|
| 198 |
+
"""Build pretrained Inception model for FID computation
|
| 199 |
+
|
| 200 |
+
The Inception model for FID computation uses a different set of weights
|
| 201 |
+
and has a slightly different structure than torchvision's Inception.
|
| 202 |
+
|
| 203 |
+
This method first constructs torchvision's Inception and then patches the
|
| 204 |
+
necessary parts that are different in the FID Inception model.
|
| 205 |
+
"""
|
| 206 |
+
inception = _inception_v3(num_classes=1008,
|
| 207 |
+
aux_logits=False,
|
| 208 |
+
weights=None)
|
| 209 |
+
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
|
| 210 |
+
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
|
| 211 |
+
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
|
| 212 |
+
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
|
| 213 |
+
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
|
| 214 |
+
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
|
| 215 |
+
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
|
| 216 |
+
inception.Mixed_7b = FIDInceptionE_1(1280)
|
| 217 |
+
inception.Mixed_7c = FIDInceptionE_2(2048)
|
| 218 |
+
|
| 219 |
+
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
|
| 220 |
+
inception.load_state_dict(state_dict)
|
| 221 |
+
return inception
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class FIDInceptionA(torchvision.models.inception.InceptionA):
|
| 225 |
+
"""InceptionA block patched for FID computation"""
|
| 226 |
+
def __init__(self, in_channels, pool_features):
|
| 227 |
+
super(FIDInceptionA, self).__init__(in_channels, pool_features)
|
| 228 |
+
|
| 229 |
+
def forward(self, x):
|
| 230 |
+
branch1x1 = self.branch1x1(x)
|
| 231 |
+
|
| 232 |
+
branch5x5 = self.branch5x5_1(x)
|
| 233 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
| 234 |
+
|
| 235 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
| 236 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
| 237 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
| 238 |
+
|
| 239 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
| 240 |
+
# its average calculation
|
| 241 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
| 242 |
+
count_include_pad=False)
|
| 243 |
+
branch_pool = self.branch_pool(branch_pool)
|
| 244 |
+
|
| 245 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
| 246 |
+
return torch.cat(outputs, 1)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class FIDInceptionC(torchvision.models.inception.InceptionC):
|
| 250 |
+
"""InceptionC block patched for FID computation"""
|
| 251 |
+
def __init__(self, in_channels, channels_7x7):
|
| 252 |
+
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
|
| 253 |
+
|
| 254 |
+
def forward(self, x):
|
| 255 |
+
branch1x1 = self.branch1x1(x)
|
| 256 |
+
|
| 257 |
+
branch7x7 = self.branch7x7_1(x)
|
| 258 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
| 259 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
| 260 |
+
|
| 261 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
| 262 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
| 263 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
| 264 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
| 265 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
| 266 |
+
|
| 267 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
| 268 |
+
# its average calculation
|
| 269 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
| 270 |
+
count_include_pad=False)
|
| 271 |
+
branch_pool = self.branch_pool(branch_pool)
|
| 272 |
+
|
| 273 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
| 274 |
+
return torch.cat(outputs, 1)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class FIDInceptionE_1(torchvision.models.inception.InceptionE):
|
| 278 |
+
"""First InceptionE block patched for FID computation"""
|
| 279 |
+
def __init__(self, in_channels):
|
| 280 |
+
super(FIDInceptionE_1, self).__init__(in_channels)
|
| 281 |
+
|
| 282 |
+
def forward(self, x):
|
| 283 |
+
branch1x1 = self.branch1x1(x)
|
| 284 |
+
|
| 285 |
+
branch3x3 = self.branch3x3_1(x)
|
| 286 |
+
branch3x3 = [
|
| 287 |
+
self.branch3x3_2a(branch3x3),
|
| 288 |
+
self.branch3x3_2b(branch3x3),
|
| 289 |
+
]
|
| 290 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
| 291 |
+
|
| 292 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
| 293 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
| 294 |
+
branch3x3dbl = [
|
| 295 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
| 296 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
| 297 |
+
]
|
| 298 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
| 299 |
+
|
| 300 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
| 301 |
+
# its average calculation
|
| 302 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
| 303 |
+
count_include_pad=False)
|
| 304 |
+
branch_pool = self.branch_pool(branch_pool)
|
| 305 |
+
|
| 306 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
| 307 |
+
return torch.cat(outputs, 1)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class FIDInceptionE_2(torchvision.models.inception.InceptionE):
|
| 311 |
+
"""Second InceptionE block patched for FID computation"""
|
| 312 |
+
def __init__(self, in_channels):
|
| 313 |
+
super(FIDInceptionE_2, self).__init__(in_channels)
|
| 314 |
+
|
| 315 |
+
def forward(self, x):
|
| 316 |
+
branch1x1 = self.branch1x1(x)
|
| 317 |
+
|
| 318 |
+
branch3x3 = self.branch3x3_1(x)
|
| 319 |
+
branch3x3 = [
|
| 320 |
+
self.branch3x3_2a(branch3x3),
|
| 321 |
+
self.branch3x3_2b(branch3x3),
|
| 322 |
+
]
|
| 323 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
| 324 |
+
|
| 325 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
| 326 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
| 327 |
+
branch3x3dbl = [
|
| 328 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
| 329 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
| 330 |
+
]
|
| 331 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
| 332 |
+
|
| 333 |
+
# Patch: The FID Inception model uses max pooling instead of average
|
| 334 |
+
# pooling. This is likely an error in this specific Inception
|
| 335 |
+
# implementation, as other Inception models use average pooling here
|
| 336 |
+
# (which matches the description in the paper).
|
| 337 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
| 338 |
+
branch_pool = self.branch_pool(branch_pool)
|
| 339 |
+
|
| 340 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
| 341 |
+
return torch.cat(outputs, 1)
|
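Usage sketch (not part of the uploaded file): extracting the pool3 activations that the FID statistics are built from. Instantiating the model downloads the FID Inception weights on first use.

```python
import torch
from libs.metric.pytorch_fid.inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]   # final average-pool block
model = InceptionV3([block_idx]).eval()

images = torch.rand(4, 3, 299, 299)                # values expected in [0, 1]
with torch.no_grad():
    features = model(images)[0]                    # (4, 2048, 1, 1)
features = features.squeeze(-1).squeeze(-1)        # (4, 2048)
```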
libs/modules/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
libs/modules/edge_map/DoG/XDoG.py
ADDED
|
@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:

import numpy as np
import cv2
from scipy import ndimage as ndi
from skimage import filters


class XDoG:

    def __init__(self,
                 gamma=0.98,
                 phi=200,
                 eps=-0.1,
                 sigma=0.8,
                 k=10,
                 binarize: bool = True):
        """
        XDoG algorithm.

        Args:
            gamma: Control the size of the Gaussian filter
            phi: Control changes in edge strength
            eps: Threshold for controlling edge strength
            sigma: The standard deviation of the Gaussian filter controls the degree of smoothness
            k: Control the size ratio of Gaussian filter, (k=10 or k=1.6)
            binarize(bool): Whether to binarize the output
        """

        super(XDoG, self).__init__()

        self.gamma = gamma
        assert 0 <= self.gamma <= 1

        self.phi = phi
        assert 0 <= self.phi <= 1500

        self.eps = eps
        assert -1 <= self.eps <= 1

        self.sigma = sigma
        assert 0.1 <= self.sigma <= 10

        self.k = k
        assert 1 <= self.k <= 100

        self.binarize = binarize

    def __call__(self, img):
        # to gray if image is not already grayscale
        if len(img.shape) == 3 and img.shape[2] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        elif len(img.shape) == 3 and img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY)

        if np.isnan(img).any():
            img[np.isnan(img)] = np.mean(img[~np.isnan(img)])

        # gaussian filter
        imf1 = ndi.gaussian_filter(img, self.sigma)
        imf2 = ndi.gaussian_filter(img, self.sigma * self.k)
        imdiff = imf1 - self.gamma * imf2

        # XDoG
        imdiff = (imdiff < self.eps) * 1.0 + (imdiff >= self.eps) * (1.0 + np.tanh(self.phi * imdiff))

        # normalize
        imdiff -= imdiff.min()
        imdiff /= imdiff.max()

        if self.binarize:
            th = filters.threshold_otsu(imdiff)
            imdiff = (imdiff >= th).astype('float32')

        return imdiff
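Usage sketch (not part of the uploaded file): running XDoG on a single image. The file path is a placeholder, and scaling intensities to [0, 1] is an assumption that matches the default `eps` and `sigma` values.

```python
import cv2
from libs.modules.edge_map.DoG import XDoG

xdog = XDoG()                                   # defaults: gamma=0.98, phi=200, eps=-0.1, sigma=0.8, k=10
img = cv2.imread('example.png')                 # placeholder path, BGR uint8
edge_map = xdog(img.astype('float32') / 255.0)  # intensities assumed in [0, 1]
cv2.imwrite('example_xdog.png', (edge_map * 255).astype('uint8'))
```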
libs/modules/edge_map/DoG/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:

from .XDoG import XDoG

__all__ = ['XDoG']
libs/modules/edge_map/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
libs/modules/edge_map/canny/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:

import cv2


class CannyDetector:

    def __call__(self, img, low_threshold, high_threshold, L2gradient=False):
        # Pass L2gradient by keyword so it is not mistaken for cv2.Canny's
        # optional `edges` output argument, which occupies the next positional slot.
        return cv2.Canny(img, low_threshold, high_threshold, L2gradient=L2gradient)


__all__ = ['CannyDetector']
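Usage sketch (not part of the uploaded file): the Canny wrapper with illustrative thresholds; the image path is a placeholder.

```python
import cv2
from libs.modules.edge_map.canny import CannyDetector

canny = CannyDetector()
img = cv2.imread('example.png', cv2.IMREAD_GRAYSCALE)   # placeholder path, uint8
edges = canny(img, low_threshold=100, high_threshold=200, L2gradient=True)
cv2.imwrite('example_canny.png', edges)
```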
libs/modules/edge_map/image_grads/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:

from .laplacian import LaplacianDetector

__all__ = ['LaplacianDetector']
libs/modules/edge_map/image_grads/laplacian.py
ADDED
|
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:


import cv2


class LaplacianDetector:

    def __call__(self, img):
        return cv2.Laplacian(img, cv2.CV_64F)
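Usage sketch (not part of the uploaded file): the Laplacian wrapper; converting the CV_64F response back to uint8 is only for saving and the path is a placeholder.

```python
import cv2
from libs.modules.edge_map.image_grads import LaplacianDetector

laplacian = LaplacianDetector()
img = cv2.imread('example.png', cv2.IMREAD_GRAYSCALE)   # placeholder path
grad = laplacian(img)                                    # float64 Laplacian response
cv2.imwrite('example_laplacian.png', cv2.convertScaleAbs(grad))
```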
libs/modules/ema.py
ADDED
|
@@ -0,0 +1,198 @@
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) XiMing Xing. All rights reserved.
|
| 3 |
+
# Description: EMA model
|
| 4 |
+
|
| 5 |
+
import copy
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
__all__ = ['EMA']
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class EMA(nn.Module):
|
| 14 |
+
"""
|
| 15 |
+
Implements exponential moving average shadowing for your model.
|
| 16 |
+
Utilizes an inverse decay schedule to manage longer term training runs.
|
| 17 |
+
By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
|
| 18 |
+
@crowsonkb's notes on EMA Warmup:
|
| 19 |
+
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
|
| 20 |
+
good values for models you plan to train for a million or more steps (reaches decay
|
| 21 |
+
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
|
| 22 |
+
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
|
| 23 |
+
215.4k steps).
|
| 24 |
+
Args:
|
| 25 |
+
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
| 26 |
+
power (float): Exponential factor of EMA warmup. Default: 1.
|
| 27 |
+
min_value (float): The minimum EMA decay rate. Default: 0.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
model,
|
| 33 |
+
# if your model has lazylinears or other types of non-deepcopyable modules,
|
| 34 |
+
# you can pass in your own ema model
|
| 35 |
+
ema_model=None,
|
| 36 |
+
beta=0.9999,
|
| 37 |
+
update_after_step=100,
|
| 38 |
+
update_every=10,
|
| 39 |
+
inv_gamma=1.0,
|
| 40 |
+
power=2 / 3,
|
| 41 |
+
min_value=0.0,
|
| 42 |
+
param_or_buffer_names_no_ema=set(),
|
| 43 |
+
ignore_names=set(),
|
| 44 |
+
ignore_startswith_names=set(),
|
| 45 |
+
# set this to False if you do not wish for the online model to be
|
| 46 |
+
# saved along with the ema model (managed externally)
|
| 47 |
+
include_online_model=True
|
| 48 |
+
):
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.beta = beta
|
| 51 |
+
|
| 52 |
+
# whether to include the online model within the module tree, so that state_dict also saves it
|
| 53 |
+
self.include_online_model = include_online_model
|
| 54 |
+
|
| 55 |
+
if include_online_model:
|
| 56 |
+
self.online_model = model
|
| 57 |
+
else:
|
| 58 |
+
self.online_model = [model] # hack
|
| 59 |
+
|
| 60 |
+
# ema model
|
| 61 |
+
self.ema_model = ema_model
|
| 62 |
+
|
| 63 |
+
if not exists(self.ema_model):
|
| 64 |
+
try:
|
| 65 |
+
self.ema_model = copy.deepcopy(model)
|
| 66 |
+
except:
|
| 67 |
+
print('Your model was not copyable. Please make sure you are not using any LazyLinear')
|
| 68 |
+
exit()
|
| 69 |
+
|
| 70 |
+
self.ema_model.requires_grad_(False)
|
| 71 |
+
|
| 72 |
+
self.parameter_names = {name for name, param in self.ema_model.named_parameters() if param.dtype == torch.float}
|
| 73 |
+
self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if buffer.dtype == torch.float}
|
| 74 |
+
|
| 75 |
+
self.update_every = update_every
|
| 76 |
+
self.update_after_step = update_after_step
|
| 77 |
+
|
| 78 |
+
self.inv_gamma = inv_gamma
|
| 79 |
+
self.power = power
|
| 80 |
+
self.min_value = min_value
|
| 81 |
+
|
| 82 |
+
assert isinstance(param_or_buffer_names_no_ema, (set, list))
|
| 83 |
+
self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer
|
| 84 |
+
|
| 85 |
+
self.ignore_names = ignore_names
|
| 86 |
+
self.ignore_startswith_names = ignore_startswith_names
|
| 87 |
+
|
| 88 |
+
self.register_buffer('initted', torch.Tensor([False]))
|
| 89 |
+
self.register_buffer('step', torch.tensor([0]))
|
| 90 |
+
|
| 91 |
+
@property
|
| 92 |
+
def model(self):
|
| 93 |
+
return self.online_model if self.include_online_model else self.online_model[0]
|
| 94 |
+
|
| 95 |
+
def restore_ema_model_device(self):
|
| 96 |
+
device = self.initted.device
|
| 97 |
+
self.ema_model.to(device)
|
| 98 |
+
|
| 99 |
+
def get_params_iter(self, model):
|
| 100 |
+
for name, param in model.named_parameters():
|
| 101 |
+
if name not in self.parameter_names:
|
| 102 |
+
continue
|
| 103 |
+
yield name, param
|
| 104 |
+
|
| 105 |
+
def get_buffers_iter(self, model):
|
| 106 |
+
for name, buffer in model.named_buffers():
|
| 107 |
+
if name not in self.buffer_names:
|
| 108 |
+
continue
|
| 109 |
+
yield name, buffer
|
| 110 |
+
|
| 111 |
+
def copy_params_from_model_to_ema(self):
|
| 112 |
+
for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model),
|
| 113 |
+
self.get_params_iter(self.model)):
|
| 114 |
+
ma_params.data.copy_(current_params.data)
|
| 115 |
+
|
| 116 |
+
for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model),
|
| 117 |
+
self.get_buffers_iter(self.model)):
|
| 118 |
+
ma_buffers.data.copy_(current_buffers.data)
|
| 119 |
+
|
| 120 |
+
def get_current_decay(self):
|
| 121 |
+
epoch = clamp(self.step.item() - self.update_after_step - 1, min_value=0.)
|
| 122 |
+
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
|
| 123 |
+
|
| 124 |
+
if epoch <= 0:
|
| 125 |
+
return 0.
|
| 126 |
+
|
| 127 |
+
return clamp(value, min_value=self.min_value, max_value=self.beta)
|
| 128 |
+
|
| 129 |
+
def update(self):
|
| 130 |
+
step = self.step.item()
|
| 131 |
+
self.step += 1
|
| 132 |
+
|
| 133 |
+
if (step % self.update_every) != 0:
|
| 134 |
+
return
|
| 135 |
+
|
| 136 |
+
if step <= self.update_after_step:
|
| 137 |
+
self.copy_params_from_model_to_ema()
|
| 138 |
+
return
|
| 139 |
+
|
| 140 |
+
if not self.initted.item():
|
| 141 |
+
self.copy_params_from_model_to_ema()
|
| 142 |
+
self.initted.data.copy_(torch.Tensor([True]))
|
| 143 |
+
|
| 144 |
+
self.update_moving_average(self.ema_model, self.model)
|
| 145 |
+
|
| 146 |
+
@torch.no_grad()
|
| 147 |
+
def update_moving_average(self, ma_model, current_model):
|
| 148 |
+
current_decay = self.get_current_decay()
|
| 149 |
+
|
| 150 |
+
for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model),
|
| 151 |
+
self.get_params_iter(ma_model)):
|
| 152 |
+
if name in self.ignore_names:
|
| 153 |
+
continue
|
| 154 |
+
|
| 155 |
+
if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
|
| 156 |
+
continue
|
| 157 |
+
|
| 158 |
+
if name in self.param_or_buffer_names_no_ema:
|
| 159 |
+
ma_params.data.copy_(current_params.data)
|
| 160 |
+
continue
|
| 161 |
+
|
| 162 |
+
ma_params.data.lerp_(current_params.data, 1. - current_decay)
|
| 163 |
+
|
| 164 |
+
for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model),
|
| 165 |
+
self.get_buffers_iter(ma_model)):
|
| 166 |
+
if name in self.ignore_names:
|
| 167 |
+
continue
|
| 168 |
+
|
| 169 |
+
if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
|
| 170 |
+
continue
|
| 171 |
+
|
| 172 |
+
if name in self.param_or_buffer_names_no_ema:
|
| 173 |
+
ma_buffer.data.copy_(current_buffer.data)
|
| 174 |
+
continue
|
| 175 |
+
|
| 176 |
+
ma_buffer.data.lerp_(current_buffer.data, 1. - current_decay)
|
| 177 |
+
|
| 178 |
+
def __call__(self, *args, **kwargs):
|
| 179 |
+
return self.ema_model(*args, **kwargs)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def exists(val):
|
| 183 |
+
return val is not None
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def is_float_dtype(dtype):
|
| 187 |
+
return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def clamp(value, min_value=None, max_value=None):
|
| 191 |
+
assert exists(min_value) or exists(max_value)
|
| 192 |
+
if exists(min_value):
|
| 193 |
+
value = max(value, min_value)
|
| 194 |
+
|
| 195 |
+
if exists(max_value):
|
| 196 |
+
value = min(value, max_value)
|
| 197 |
+
|
| 198 |
+
return value
|
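Usage sketch (not part of the uploaded file): wrapping a network with the EMA helper above inside a toy training loop; the model, data, and loss are placeholders.

```python
import torch
import torch.nn as nn
from libs.modules.ema import EMA

model = nn.Linear(16, 4)
ema = EMA(model, beta=0.9999, update_after_step=100, update_every=10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(1000):
    x, y = torch.randn(32, 16), torch.randn(32, 4)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update()                      # shadow weights track the online model

with torch.no_grad():
    preds = ema(torch.randn(8, 16))   # __call__ dispatches to the EMA copy
```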
libs/modules/vision/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:

from .inception import inception_v3
from .vgg import VGG

__all__ = [
    'inception_v3',
    'VGG'
]
libs/modules/vision/inception.py
ADDED
|
@@ -0,0 +1,482 @@
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) XiMing Xing. All rights reserved.
|
| 3 |
+
# Author: XiMing Xing
|
| 4 |
+
# Description:
|
| 5 |
+
|
| 6 |
+
from collections import namedtuple
|
| 7 |
+
import warnings
|
| 8 |
+
from typing import Callable, Any, Optional, Tuple, List
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn, Tensor
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
| 14 |
+
|
| 15 |
+
__all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs']
|
| 16 |
+
|
| 17 |
+
model_urls = {
|
| 18 |
+
# Inception v3 ported from TensorFlow
|
| 19 |
+
'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth',
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
|
| 23 |
+
InceptionOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]}
|
| 24 |
+
|
| 25 |
+
# Script annotations failed with _GoogleNetOutputs = namedtuple ...
|
| 26 |
+
# _InceptionOutputs set here for backwards compat
|
| 27 |
+
_InceptionOutputs = InceptionOutputs
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
|
| 31 |
+
r"""Inception v3 model architecture from
|
| 32 |
+
`"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
|
| 33 |
+
|
| 34 |
+
.. note::
|
| 35 |
+
**Important**: In contrast to the other models the inception_v3 expects tensors with a size of
|
| 36 |
+
N x 3 x 299 x 299, so ensure your images are sized accordingly.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 40 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 41 |
+
aux_logits (bool): If True, add an auxiliary branch that can improve training.
|
| 42 |
+
Default: *True*
|
| 43 |
+
transform_input (bool): If True, preprocesses the input according to the method with which it
|
| 44 |
+
was trained on ImageNet. Default: *False*
|
| 45 |
+
"""
|
| 46 |
+
if pretrained:
|
| 47 |
+
if 'transform_input' not in kwargs:
|
| 48 |
+
kwargs['transform_input'] = True
|
| 49 |
+
if 'aux_logits' in kwargs:
|
| 50 |
+
original_aux_logits = kwargs['aux_logits']
|
| 51 |
+
kwargs['aux_logits'] = True
|
| 52 |
+
else:
|
| 53 |
+
original_aux_logits = True
|
| 54 |
+
kwargs['init_weights'] = False # we are loading weights from a pretrained model
|
| 55 |
+
model = Inception3(**kwargs)
|
| 56 |
+
state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
|
| 57 |
+
progress=progress)
|
| 58 |
+
model.load_state_dict(state_dict)
|
| 59 |
+
if not original_aux_logits:
|
| 60 |
+
model.aux_logits = False
|
| 61 |
+
model.AuxLogits = None
|
| 62 |
+
return model
|
| 63 |
+
|
| 64 |
+
return Inception3(**kwargs)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class Inception3(nn.Module):
|
| 68 |
+
|
| 69 |
+
def __init__(
|
| 70 |
+
self,
|
| 71 |
+
num_classes: int = 1000,
|
| 72 |
+
aux_logits: bool = True,
|
| 73 |
+
transform_input: bool = False,
|
| 74 |
+
inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
|
| 75 |
+
init_weights: Optional[bool] = None
|
| 76 |
+
) -> None:
|
| 77 |
+
super(Inception3, self).__init__()
|
| 78 |
+
if inception_blocks is None:
|
| 79 |
+
inception_blocks = [
|
| 80 |
+
BasicConv2d, InceptionA, InceptionB, InceptionC,
|
| 81 |
+
InceptionD, InceptionE, InceptionAux
|
| 82 |
+
]
|
| 83 |
+
if init_weights is None:
|
| 84 |
+
            warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of '
                          'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
                          ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
            init_weights = True
        assert len(inception_blocks) == 7
        conv_block = inception_blocks[0]
        inception_a = inception_blocks[1]
        inception_b = inception_blocks[2]
        inception_c = inception_blocks[3]
        inception_d = inception_blocks[4]
        inception_e = inception_blocks[5]
        inception_aux = inception_blocks[6]

        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)
        self.AuxLogits: Optional[nn.Module] = None
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(2048, num_classes)
        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                    import scipy.stats as stats
                    stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                    X = stats.truncnorm(-2, 2, scale=stddev)
                    values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
                    values = values.view(m.weight.size())
                    with torch.no_grad():
                        m.weight.copy_(values)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        feat = self.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(feat)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.maxpool2(x)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        aux: Optional[Tensor] = None
        if self.AuxLogits is not None:
            if self.training:
                aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = self.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        return feat, x, aux

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
        if self.training and self.aux_logits:
            return InceptionOutputs(x, aux)
        else:
            return x  # type: ignore[return-value]

    def forward(self, x: Tensor) -> InceptionOutputs:
        x = self._transform_input(x)
        feat, x, aux = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return feat, InceptionOutputs(x, aux)
        else:
            return feat, self.eager_outputs(x, aux)


class InceptionA(nn.Module):

    def __init__(
            self,
            in_channels: int,
            pool_features: int,
            conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionA, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)

        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionB(nn.Module):

    def __init__(
            self,
            in_channels: int,
            conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionB, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionC(nn.Module):

    def __init__(
            self,
            in_channels: int,
            channels_7x7: int,
            conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionC, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionD(nn.Module):

    def __init__(
            self,
            in_channels: int,
            conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionD, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        outputs = [branch3x3, branch7x7x3, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionE(nn.Module):

    def __init__(
            self,
            in_channels: int,
            conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionE, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):

    def __init__(
            self,
            in_channels: int,
            num_classes: int,
            conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionAux, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
        self.conv1 = conv_block(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01  # type: ignore[assignment]
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001  # type: ignore[assignment]

    def forward(self, x: Tensor) -> Tensor:
        # N x 768 x 17 x 17
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        x = self.fc(x)
        # N x 1000
        return x


class BasicConv2d(nn.Module):

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            **kwargs: Any
    ) -> None:
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)
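The `Inception3` variant above departs from stock torchvision in one respect: `_forward` also returns the early `maxpool1` feature map, so `forward` yields a `(feature map, logits)` pair for the metric code to consume. A minimal usage sketch follows; the import path and the constructor arguments (assumed to mirror the stock torchvision signature) are assumptions, not something this diff confirms:

```python
import torch

# Hypothetical usage sketch for the modified Inception3.
# Assumes the class is importable from libs.modules.vision.inception (not verified here).
from libs.modules.vision.inception import Inception3

model = Inception3(num_classes=1000, aux_logits=True, init_weights=False).eval()

x = torch.randn(2, 3, 299, 299)   # Inception expects 299x299 RGB inputs
with torch.no_grad():
    feat, logits = model(x)        # feat: early feature map, logits: class scores

print(feat.shape)    # torch.Size([2, 64, 73, 73]) -- output of maxpool1
print(logits.shape)  # torch.Size([2, 1000])
```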
libs/modules/vision/vgg.py
ADDED
@@ -0,0 +1,194 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:

from typing import Union, List, Dict, Any, cast

import torch
import torch.nn as nn
from torch.utils.model_zoo import load_url as load_state_dict_from_url

__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]

model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-8a719046.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-19584684.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}


class VGG(nn.Module):

    def __init__(
            self,
            features: nn.Module,
            num_classes: int = 1000,
            init_weights: bool = True
    ) -> None:
        super(VGG, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x: torch.Tensor):
        feat = self.features(x)
        x = self.avgpool(feat)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return feat, x

    def _initialize_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
    layers: List[nn.Module] = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            v = cast(int, v)
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


cfgs: Dict[str, List[Union[str, int]]] = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 11-layer model (configuration "A") from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)


def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 11-layer model (configuration "A") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)


def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 13-layer model (configuration "B")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)


def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 13-layer model (configuration "B") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)


def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 16-layer model (configuration "D")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)


def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 16-layer model (configuration "D") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)


def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 19-layer model (configuration "E")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)


def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 19-layer model (configuration 'E') with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
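As with the Inception variant, this VGG wrapper's `forward` returns `(feat, logits)`, where `feat` is the raw output of the convolutional `features` trunk. A minimal sketch, assuming the module is importable as `libs.modules.vision.vgg` (the import path is inferred from the file location, not confirmed by the diff):

```python
import torch

# Hypothetical usage sketch: note that forward() returns (feature map, logits),
# not logits alone as in stock torchvision.
from libs.modules.vision.vgg import vgg16  # import path assumed

model = vgg16(pretrained=False).eval()     # pretrained=True would download ImageNet weights

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feat, logits = model(x)

print(feat.shape)    # torch.Size([1, 512, 7, 7]) -- last conv feature map for a 224x224 input
print(logits.shape)  # torch.Size([1, 1000])
```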
libs/modules/visual/__init__.py
ADDED
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
libs/modules/visual/imshow.py
ADDED
@@ -0,0 +1,177 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:

import pathlib
from typing import Union, List, Text, BinaryIO, AnyStr

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from torchvision.utils import make_grid

__all__ = [
    'sample2pil_transforms',
    'pt2numpy_transforms',
    'plt_pt_img',
    'save_grid_images_and_labels',
    'save_grid_images_and_captions',
]

# generate sample to PIL images
sample2pil_transforms = transforms.Compose([
    # unnormalizing to [0,1]
    transforms.Lambda(lambda t: torch.clamp((t + 1) / 2, min=0.0, max=1.0)),
    # Add 0.5 after unnormalizing to [0, 255]
    transforms.Lambda(lambda t: torch.clamp(t * 255. + 0.5, min=0, max=255)),
    # CHW to HWC
    transforms.Lambda(lambda t: t.permute(1, 2, 0)),
    # to numpy ndarray, dtype int8
    transforms.Lambda(lambda t: t.to('cpu', torch.uint8).numpy()),
    # Converts a numpy ndarray of shape H x W x C to a PIL Image
    transforms.ToPILImage(),
])

# generate sample to PIL images
pt2numpy_transforms = transforms.Compose([
    # Add 0.5 after unnormalizing to [0, 255]
    transforms.Lambda(lambda t: torch.clamp(t * 255. + 0.5, min=0, max=255)),
    # CHW to HWC
    transforms.Lambda(lambda t: t.permute(1, 2, 0)),
    # to numpy ndarray, dtype int8
    transforms.Lambda(lambda t: t.to('cpu', torch.uint8).numpy()),
])


def plt_pt_img(
        pt_img: torch.Tensor,
        save_path: AnyStr = None,
        title: AnyStr = None,
        dpi: int = 300
):
    grid = make_grid(pt_img, normalize=True, pad_value=2)
    ndarr = pt2numpy_transforms(grid)
    plt.imshow(ndarr)
    plt.axis("off")
    plt.tight_layout()
    if title is not None:
        plt.title(f"{title}")

    plt.show()
    if save_path is not None:
        plt.savefig(save_path, dpi=dpi)

    plt.close()


@torch.no_grad()
def save_grid_images_and_labels(
        images: Union[torch.Tensor, List[torch.Tensor]],
        probs: Union[torch.Tensor, List[torch.Tensor]],
        labels: Union[torch.Tensor, List[torch.Tensor]],
        classes: Union[torch.Tensor, List[torch.Tensor]],
        fp: Union[Text, pathlib.Path, BinaryIO],
        nrow: int = 4,
        normalize: bool = True
) -> None:
    """Save a given Tensor into an image file.
    """
    num_images = len(images)
    num_rows, num_cols = _get_subplot_shape(num_images, nrow)

    fig = plt.figure(figsize=(25, 20))

    for i in range(num_images):
        ax = fig.add_subplot(num_rows, num_cols, i + 1)

        image, true_label, prob = images[i], labels[i], probs[i]

        true_prob = prob[true_label]
        incorrect_prob, incorrect_label = torch.max(prob, dim=0)
        true_class = classes[true_label]

        incorrect_class = classes[incorrect_label]

        if normalize:
            image = sample2pil_transforms(image)

        ax.imshow(image)
        title = f'true label: {true_class} ({true_prob:.3f})\n ' \
                f'pred label: {incorrect_class} ({incorrect_prob:.3f})'
        ax.set_title(title, fontsize=20)
        ax.axis('off')

    fig.subplots_adjust(hspace=0.3)

    plt.savefig(fp)
    plt.close()


@torch.no_grad()
def save_grid_images_and_captions(
        images: Union[torch.Tensor, List[torch.Tensor]],
        captions: List,
        fp: Union[Text, pathlib.Path, BinaryIO],
        nrow: int = 4,
        normalize: bool = True
) -> None:
    """
    Save a grid of images and their captions into an image file.

    Args:
        images (Union[torch.Tensor, List[torch.Tensor]]): A list of images to display.
        captions (List): A list of captions for each image.
        fp (Union[Text, pathlib.Path, BinaryIO]): The file path to save the image to.
        nrow (int, optional): The number of images to display in each row. Defaults to 4.
        normalize (bool, optional): Whether to normalize the image or not. Defaults to False.
    """
    num_images = len(images)
    num_rows, num_cols = _get_subplot_shape(num_images, nrow)

    fig = plt.figure(figsize=(25, 20))

    for i in range(num_images):
        ax = fig.add_subplot(num_rows, num_cols, i + 1)
        image, caption = images[i], captions[i]

        if normalize:
            image = sample2pil_transforms(image)

        ax.imshow(image)
        title = f'"{caption}"' if num_images > 1 else f'"{captions}"'
        title = _insert_newline(title)
        ax.set_title(title, fontsize=20)
        ax.axis('off')

    fig.subplots_adjust(hspace=0.3)

    plt.savefig(fp)
    plt.close()


def _get_subplot_shape(num_images, nrow):
    """
    Calculate the number of rows and columns required to display images in a grid.

    Args:
        num_images (int): The total number of images to display.
        nrow (int): The maximum number of images to display in each row.

    Returns:
        Tuple[int, int]: The number of rows and columns required to display images in a grid.
    """
    num_cols = min(num_images, nrow)
    num_rows = (num_images + num_cols - 1) // num_cols
    return num_rows, num_cols


def _insert_newline(string, point=9):
    # split by blank
    words = string.split()
    if len(words) <= point:
        return string

    word_chunks = [words[i:i + point] for i in range(0, len(words), point)]
    new_string = "\n".join(" ".join(chunk) for chunk in word_chunks)
    return new_string
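A short sketch of how the captioned-grid helper above might be called; the tensors and captions are made up for illustration, and the import path is assumed from the file location:

```python
import torch

# Hypothetical usage sketch: save a captioned grid of generated samples.
from libs.modules.visual.imshow import save_grid_images_and_captions  # path assumed

# Four fake samples in [-1, 1], as a diffusion sampler might produce.
samples = [torch.rand(3, 256, 256) * 2 - 1 for _ in range(4)]
captions = [f"a sketch of a cat, seed {i}" for i in range(4)]

# normalize=True routes each tensor through sample2pil_transforms before plotting.
save_grid_images_and_captions(samples, captions, fp="caption_grid.png", nrow=2)
```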
libs/modules/visual/video.py
ADDED
@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
from typing import Any, Union
import pathlib

import cv2


def create_video(num_iter: int,
                 save_dir: Union[Any, pathlib.Path],
                 video_frame_freq: int = 1,
                 fname: str = "rendering_process",
                 verbose: bool = True):
    if not isinstance(save_dir, pathlib.Path):
        save_dir = pathlib.Path(save_dir)

    img_array = []
    for i in range(0, num_iter):
        if i % video_frame_freq == 0 or i == num_iter - 1:
            filename = save_dir / f"iter{i}.png"
            img = cv2.imread(filename.as_posix())
            img_array.append(img)

    video_name = save_dir / f"{fname}.mp4"
    out = cv2.VideoWriter(
        video_name.as_posix(),
        cv2.VideoWriter_fourcc(*'mp4v'),
        30.0,  # fps
        (600, 600)  # video size
    )
    for iii in range(len(img_array)):
        out.write(img_array[iii])
    out.release()

    if verbose:
        print(f"video saved in '{video_name}'.")
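`create_video` simply collects `iter{i}.png` frames from `save_dir` and writes a 600x600, 30 fps MP4. A hedged usage sketch, assuming such frames already exist on disk (the directory name and iteration counts are illustrative only):

```python
# Hypothetical usage sketch: stitch per-iteration renders into an MP4.
# Assumes frames were saved as iter0.png, iter10.png, ... in ./workdir/frames,
# and that they are 600x600 pixels (the size hard-coded in cv2.VideoWriter above).
from libs.modules.visual.video import create_video  # path assumed

create_video(
    num_iter=2000,             # total optimization iterations that produced frames
    save_dir="./workdir/frames",
    video_frame_freq=10,       # only every 10th iteration was written to disk
    fname="rendering_process",
)
```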
libs/solver/__init__.py
ADDED
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
libs/solver/lr_scheduler.py
ADDED
@@ -0,0 +1,350 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for diffusion models."""

import math
from enum import Enum
from typing import Optional, Union

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


class SchedulerType(Enum):
    LINEAR = "linear"
    COSINE = "cosine"
    COSINE_WITH_RESTARTS = "cosine_with_restarts"
    POLYNOMIAL = "polynomial"
    CONSTANT = "constant"
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
    PIECEWISE_CONSTANT = "piecewise_constant"


def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
    """
    Create a schedule with a constant learning rate, using the learning rate set in optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)


def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
    """
    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
    increases linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1.0, num_warmup_steps))
        return 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
    """
    Create a schedule with a piecewise constant learning rate, using the learning rate set in optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        step_rules (`string`):
            The rules for the learning rate. For example, `step_rules="1:10,0.1:20,0.01:30,0.005"` means the learning
            rate is multiplied by 1 for the first 10 steps, by 0.1 for the next 20 steps, by 0.01 for the next 30
            steps, and by 0.005 for all remaining steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    rules_dict = {}
    rule_list = step_rules.split(",")
    for rule_str in rule_list[:-1]:
        value_str, steps_str = rule_str.split(":")
        steps = int(steps_str)
        value = float(value_str)
        rules_dict[steps] = value
    last_lr_multiple = float(rule_list[-1])

    def create_rules_function(rules_dict, last_lr_multiple):
        def rule_func(steps: int) -> float:
            sorted_steps = sorted(rules_dict.keys())
            for i, sorted_step in enumerate(sorted_steps):
                if steps < sorted_step:
                    return rules_dict[sorted_steps[i]]
            return last_lr_multiple

        return rule_func

    rules_func = create_rules_function(rules_dict, last_lr_multiple)

    return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_cosine_schedule_with_warmup(
        optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5,
        last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of periods of the cosine function in a schedule (the default is to just decrease from the max
            value to 0 following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
    linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        if progress >= 1.0:
            return 0.0
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_polynomial_decay_schedule_with_warmup(
        optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
):
    """
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    """

    lr_init = optimizer.defaults["lr"]
    if not (lr_init > lr_end):
        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        elif current_step > num_training_steps:
            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
        else:
            lr_range = lr_init - lr_end
            decay_steps = num_training_steps - num_warmup_steps
            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
            decay = lr_range * pct_remaining ** power + lr_end
            return decay / lr_init  # as LambdaLR multiplies by lr_init

    return LambdaLR(optimizer, lr_lambda, last_epoch)


TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
    SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
}


def get_scheduler(
        name: Union[str, SchedulerType],
        optimizer: Optimizer,
        step_rules: Optional[str] = None,
        num_warmup_steps: Optional[int] = None,
        num_training_steps: Optional[int] = None,
        num_cycles: int = 1,
        power: float = 1.0,
        last_epoch: int = -1,
):
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        step_rules (`str`, *optional*):
            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional); the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional); the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles (`int`, *optional*):
            The number of hard restarts used in the `COSINE_WITH_RESTARTS` scheduler.
        power (`float`, *optional*, defaults to 1.0):
            Power factor. See the `POLYNOMIAL` scheduler.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer, last_epoch=last_epoch)

    if name == SchedulerType.PIECEWISE_CONSTANT:
        return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=num_cycles,
            last_epoch=last_epoch,
        )

    if name == SchedulerType.POLYNOMIAL:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            power=power,
            last_epoch=last_epoch,
        )

    return schedule_func(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch
    )
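`get_scheduler` dispatches on the schedule name and forwards only the arguments that schedule needs. A minimal sketch of pairing it with an optimizer, assuming the module is importable as `libs.solver.lr_scheduler`; the parameter tensor and step counts are illustrative only:

```python
import torch

# Hypothetical usage sketch: a cosine schedule with linear warmup.
from libs.solver.lr_scheduler import get_scheduler  # path assumed

params = [torch.nn.Parameter(torch.randn(10, 2))]    # stand-in for SVG path parameters
optimizer = torch.optim.Adam(params, lr=1e-2)

lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=50,
    num_training_steps=1000,
)

for step in range(1000):
    optimizer.step()       # (loss.backward() would normally precede this)
    lr_scheduler.step()    # warm up for 50 steps, then decay along a half-cosine
```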