jree423 committed on
Commit 51e82cd · verified · 1 parent: 6ab7302

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. README.md +69 -40
  2. config.json +41 -5
  3. config/diffsketcher-color.yaml +75 -0
  4. config/diffsketcher-style.yaml +78 -0
  5. config/diffsketcher-width.yaml +75 -0
  6. config/diffsketcher.yaml +76 -0
  7. handler.py +132 -111
  8. libs/__init__.py +15 -0
  9. libs/engine/__init__.py +12 -0
  10. libs/engine/config_processor.py +156 -0
  11. libs/engine/model_state.py +339 -0
  12. libs/metric/__init__.py +4 -0
  13. libs/metric/accuracy.py +31 -0
  14. libs/metric/clip_score/__init__.py +8 -0
  15. libs/metric/clip_score/openaiCLIP_loss.py +305 -0
  16. libs/metric/lpips_origin/__init__.py +3 -0
  17. libs/metric/lpips_origin/lpips.py +184 -0
  18. libs/metric/lpips_origin/pretrained_networks.py +196 -0
  19. libs/metric/lpips_origin/weights/v0.1/alex.pth +3 -0
  20. libs/metric/lpips_origin/weights/v0.1/squeeze.pth +3 -0
  21. libs/metric/lpips_origin/weights/v0.1/vgg.pth +3 -0
  22. libs/metric/piq/__init__.py +7 -0
  23. libs/metric/piq/functional/__init__.py +15 -0
  24. libs/metric/piq/functional/base.py +111 -0
  25. libs/metric/piq/functional/colour_conversion.py +136 -0
  26. libs/metric/piq/functional/filters.py +111 -0
  27. libs/metric/piq/functional/layers.py +33 -0
  28. libs/metric/piq/functional/resize.py +426 -0
  29. libs/metric/piq/perceptual.py +496 -0
  30. libs/metric/piq/utils/__init__.py +7 -0
  31. libs/metric/piq/utils/common.py +158 -0
  32. libs/metric/pytorch_fid/__init__.py +54 -0
  33. libs/metric/pytorch_fid/fid_score.py +322 -0
  34. libs/metric/pytorch_fid/inception.py +341 -0
  35. libs/modules/__init__.py +4 -0
  36. libs/modules/edge_map/DoG/XDoG.py +78 -0
  37. libs/modules/edge_map/DoG/__init__.py +8 -0
  38. libs/modules/edge_map/__init__.py +4 -0
  39. libs/modules/edge_map/canny/__init__.py +15 -0
  40. libs/modules/edge_map/image_grads/__init__.py +8 -0
  41. libs/modules/edge_map/image_grads/laplacian.py +13 -0
  42. libs/modules/ema.py +198 -0
  43. libs/modules/vision/__init__.py +12 -0
  44. libs/modules/vision/inception.py +482 -0
  45. libs/modules/vision/vgg.py +194 -0
  46. libs/modules/visual/__init__.py +4 -0
  47. libs/modules/visual/imshow.py +177 -0
  48. libs/modules/visual/video.py +38 -0
  49. libs/solver/__init__.py +4 -0
  50. libs/solver/lr_scheduler.py +350 -0
README.md CHANGED
@@ -1,72 +1,101 @@
1
  ---
2
  tags:
3
- - text-to-image
4
- - diffusers
5
- - vector-graphics
6
  - svg
7
- library_name: diffusers
8
- pipeline_tag: text-to-image
9
- inference: true
 
 
 
10
  ---
11
 
12
- # DiffSketcher - Vector Graphics Generation
13
 
14
- This model generates vector graphics (SVG) from text prompts using the original DiffSketcher implementation.
15
 
16
  ## Model Description
17
 
18
- DiffSketcher is a state-of-the-art vector graphics generation model that creates high-quality SVG images from text prompts. It uses a diffusion model to guide the SVG generation process.
19
 
20
  ## Usage
21
 
22
  ```python
23
  import requests
 
24
 
25
- API_URL = "https://api-inference.huggingface.co/models/jree423/diffsketcher"
26
- headers = {"Authorization": "Bearer YOUR_TOKEN"}
27
 
28
- def query(prompt):
29
- response = requests.post(API_URL, headers=headers, json={"inputs": prompt})
30
- return response.content
31
 
32
- # Generate an image
33
- with open("output.png", "wb") as f:
34
- f.write(query("a beautiful mountain landscape"))
35
- ```
36
 
37
- You can also specify additional parameters:
 
 
38
 
39
- ```python
40
- response = requests.post(
41
- API_URL,
42
- headers=headers,
43
- json={
44
- "inputs": {
45
- "text": "a beautiful mountain landscape",
46
- "width": 512,
47
- "height": 512,
48
- "num_paths": 512,
49
- "seed": 42
50
- }
51
- }
52
- )
53
  ```
54
 
55
  ## Parameters
56
 
57
- - `text` (str): The text prompt to generate an image from.
58
- - `width` (int, optional): The width of the generated image. Default: 512.
59
- - `height` (int, optional): The height of the generated image. Default: 512.
60
- - `num_paths` (int, optional): The number of paths to use in the SVG. Default: 512.
61
- - `seed` (int, optional): The random seed to use for generation. Default: None (random).
62
 
63
  ## Citation
64
 
65
  ```bibtex
66
  @inproceedings{xing2023diffsketcher,
67
  title={DiffSketcher: Text Guided Vector Sketch Synthesis through Latent Diffusion Models},
68
- author={Xing, XiMing and Zhan, Chuang and Xu, Yinghao and Dong, Yue and Yu, Yingqing and Li, Chongyang and Liu, Yong Jin},
69
  booktitle={Advances in Neural Information Processing Systems},
70
  year={2023}
71
  }
72
- ```
1
  ---
2
+ title: DiffSketcher
3
+ emoji: 🎨
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: custom
7
+ app_file: handler.py
8
+ pinned: false
9
+ license: mit
10
  tags:
 
 
 
11
  - svg
12
+ - vector-graphics
13
+ - text-to-image
14
+ - diffusion
15
+ - sketch
16
+ pipeline_tag: image-generation
17
+ library_name: diffvg
18
  ---
19
 
20
+ # DiffSketcher: Text Guided Vector Sketch Synthesis
21
 
22
+ DiffSketcher is a novel method for generating high-quality vector sketches from text prompts using latent diffusion models. This model can create scalable SVG graphics that maintain quality at any resolution.
23
 
24
  ## Model Description
25
 
26
+ DiffSketcher leverages the power of Stable Diffusion to guide the optimization of vector paths, creating artistic sketches that are both semantically meaningful and visually appealing. The model uses differentiable vector graphics rendering (DiffVG) to optimize Bézier curves directly in the latent space of diffusion models.
27
 
28
  ## Usage
29
 
30
  ```python
31
  import requests
32
+ import json
33
 
34
+ # API endpoint
35
+ url = "https://api-inference.huggingface.co/models/jree423/diffsketcher"
36
 
37
+ # Headers
38
+ headers = {"Authorization": "Bearer YOUR_HF_TOKEN"}
 
39
 
40
+ # Payload
41
+ payload = {
42
+ "inputs": "a beautiful mountain landscape",
43
+ "parameters": {
44
+ "num_paths": 96,
45
+ "num_iter": 500,
46
+ "token_ind": 4,
47
+ "guidance_scale": 7.5,
48
+ "canvas_size": 224
49
+ }
50
+ }
51
 
52
+ # Make request
53
+ response = requests.post(url, headers=headers, json=payload)
54
+ result = response.json()
55
 
56
+ # The result contains the SVG content
57
+ svg_content = result[0]["svg"]
58
  ```
59
 
60
  ## Parameters
61
 
62
+ - **num_paths** (int, default: 96): Number of paths/strokes in the generated SVG
63
+ - **num_iter** (int, default: 500): Number of optimization iterations
64
+ - **token_ind** (int, default: 4): Index of the prompt token whose cross-attention map is used to initialize the strokes
65
+ - **guidance_scale** (float, default: 7.5): Guidance scale for diffusion
66
+ - **canvas_size** (int, default: 224): Canvas size for SVG generation
67
+
68
+ ## Examples
69
+
70
+ ### Simple Sketch
71
+ ```
72
+ Input: "a cat sitting on a chair"
73
+ Parameters: {"num_paths": 48, "num_iter": 300}
74
+ ```
75
+
76
+ ### Detailed Artwork
77
+ ```
78
+ Input: "a majestic eagle soaring through clouds"
79
+ Parameters: {"num_paths": 128, "num_iter": 800}
80
+ ```
81
+
82
+ ### Abstract Art
83
+ ```
84
+ Input: "abstract geometric patterns in blue and gold"
85
+ Parameters: {"num_paths": 200, "num_iter": 1000}
86
+ ```
87
 
88
  ## Citation
89
 
90
  ```bibtex
91
  @inproceedings{xing2023diffsketcher,
92
  title={DiffSketcher: Text Guided Vector Sketch Synthesis through Latent Diffusion Models},
93
+ author={Xing, XiMing and Wang, Chuang and Zhou, Haitao and Zhang, Jing and Yu, Qian and Xu, Dong},
94
  booktitle={Advances in Neural Information Processing Systems},
95
  year={2023}
96
  }
97
+ ```
98
+
99
+ ## License
100
+
101
+ This model is released under the MIT License.
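The endpoint returns SVG markup rather than a raster image. Below is a minimal client-side sketch for keeping the scalable SVG and also rasterizing a PNG preview; it reuses the endpoint URL and response shape from the Usage section above and assumes the optional cairosvg package is installed (it is not part of the committed README).

```python
import requests
import cairosvg

url = "https://api-inference.huggingface.co/models/jree423/diffsketcher"
headers = {"Authorization": "Bearer YOUR_HF_TOKEN"}

response = requests.post(url, headers=headers,
                         json={"inputs": "a beautiful mountain landscape"})
svg_content = response.json()[0]["svg"]

# Keep the scalable SVG and write a 512x512 PNG preview for quick inspection.
with open("sketch.svg", "w", encoding="utf-8") as f:
    f.write(svg_content)
cairosvg.svg2png(bytestring=svg_content.encode("utf-8"),
                 write_to="sketch_preview.png",
                 output_width=512, output_height=512)
```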
config.json CHANGED
@@ -1,8 +1,44 @@
1
  {
2
- "architectures": [
3
- "CustomModel"
4
  ],
5
- "model_type": "custom",
6
- "task": "text-to-image",
7
- "inference": true
8
  }
 
1
  {
2
+ "architectures": ["DiffSketcher"],
3
+ "model_type": "diffsketcher",
4
+ "task": "text-to-svg",
5
+ "framework": "pytorch",
6
+ "pipeline_tag": "image-generation",
7
+ "library_name": "diffvg",
8
+ "tags": [
9
+ "svg",
10
+ "vector-graphics",
11
+ "text-to-image",
12
+ "diffusion",
13
+ "sketch"
14
  ],
15
+ "inference": {
16
+ "parameters": {
17
+ "num_paths": {
18
+ "type": "integer",
19
+ "default": 96,
20
+ "description": "Number of paths/strokes in the generated SVG"
21
+ },
22
+ "num_iter": {
23
+ "type": "integer",
24
+ "default": 500,
25
+ "description": "Number of optimization iterations"
26
+ },
27
+ "token_ind": {
28
+ "type": "integer",
29
+ "default": 4,
30
+ "description": "Index of cross-attention maps to initialize strokes"
31
+ },
32
+ "guidance_scale": {
33
+ "type": "float",
34
+ "default": 7.5,
35
+ "description": "Guidance scale for diffusion"
36
+ },
37
+ "canvas_size": {
38
+ "type": "integer",
39
+ "default": 224,
40
+ "description": "Canvas size for SVG generation"
41
+ }
42
+ }
43
+ }
44
  }
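The inference block above acts as a parameter schema. The following is a minimal sketch of how a client or handler could fill in the declared defaults for a partially specified request; the helper function is illustrative rather than part of the committed code, and it assumes config.json is readable from the working directory.

```python
import json

def apply_parameter_defaults(user_params, config_path="config.json"):
    """Fill missing request parameters with the defaults declared in config.json."""
    with open(config_path) as f:
        schema = json.load(f)["inference"]["parameters"]
    resolved = {name: spec["default"] for name, spec in schema.items()}
    resolved.update(user_params or {})
    return resolved

print(apply_parameter_defaults({"num_paths": 48}))
# {'num_paths': 48, 'num_iter': 500, 'token_ind': 4, 'guidance_scale': 7.5, 'canvas_size': 224}
```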
config/diffsketcher-color.yaml ADDED
@@ -0,0 +1,75 @@
1
+ image_size: 224
2
+ path_svg: ~ # if you want to load a svg file and train from it
3
+ mask_object: False # if the target image contains background, it's better to mask it out
4
+ fix_scale: False # if the target image is not squared, it is recommended to fix the scale
5
+
6
+ # train
7
+ num_iter: 2000
8
+ batch_size: 1
9
+ num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
10
+ lr_scheduler: False
11
+ lr_decay_rate: 0.1
12
+ decay_steps: [ 1000, 1500 ]
13
+ lr: 1 # point lr
14
+ color_lr: 0.01
15
+ color_vars_threshold: 0.1
16
+ width_lr: 0.1 # stroke width lr
17
+ max_width: 50 # stroke width
18
+
19
+ # stroke attrs
20
+ num_paths: 128 # number of strokes
21
+ width: 1.5 # init stroke width
22
+ control_points_per_seg: 4
23
+ num_segments: 1
24
+ optim_opacity: False # if True, the stroke opacity is optimized
25
+ optim_width: True # if True, the stroke width is optimized
26
+ optim_rgba: True # if True, the stroke RGBA is optimized
27
+ opacity_delta: 0 # stroke pruning
28
+
29
+ # init strokes
30
+ attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes
31
+ xdog_intersec: False # initialize along the edge, mix XDoG and attn up
32
+ softmax_temp: 0.5 # the temperature of softmax
33
+ cross_attn_res: 16 # cross attn resolution
34
+ self_attn_res: 32 # self-attn resolution
35
+ max_com: 20 # select the number of the self-attn maps
36
+ mean_comp: False # the average of the self-attn maps
37
+ comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map
38
+ attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
39
+ log_cross_attn: False
40
+ u2net_path: "./checkpoint/u2net/u2net.pth"
41
+
42
+ # ldm
43
+ model_id: "sd15" # stable diffusion V1.5
44
+ ldm_speed_up: False
45
+ enable_xformers: True # speed up attn compute
46
+ gradient_checkpoint: False # this slows down the code, but saves GPU VRAM
47
+ token_ind: 1 # the index of CLIP prompt embedding, start from 1, 0 is start token
48
+ use_ddim: True
49
+ num_inference_steps: 100
50
+ guidance_scale: 7.5
51
+
52
+ # ASDS loss
53
+ sds:
54
+ crop_size: 512
55
+ augmentations: "affine"
56
+ guidance_scale: 100
57
+ grad_scale: 1e-6
58
+ t_range: [ 0.05, 0.95 ]
59
+ warmup: 3000
60
+
61
+ # JVSP
62
+ clip:
63
+ model_name: "RN101" # RN101, ViT-L/14
64
+ feats_loss_type: "l2" # clip visual loss type, conv layers
65
+ feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based
66
+ # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based
67
+ fc_loss_weight: 0.1 # clip visual loss, fc layer weight
68
+ augmentations: "affine_norm" # augmentation before clip visual computation, affine_norm_trivial
69
+ num_aug: 4 # num of augmentation before clip visual computation
70
+ vis_loss: 1 # 1 or 0 for use or disable clip visual loss
71
+ text_visual_coeff: 0 # cosine similarity between text and img
72
+ perceptual:
73
+ name: "lpips" # dists
74
+ lpips_net: 'vgg'
75
+ coeff: 0.2
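The style and width variants below, and the base diffsketcher.yaml committed after them, differ from this color config in only a handful of keys. A hedged sketch for seeing exactly what a variant overrides, assuming the repository is checked out locally so both files are on disk:

```python
from omegaconf import OmegaConf

base = OmegaConf.load("config/diffsketcher.yaml")        # committed later in this diff
color = OmegaConf.load("config/diffsketcher-color.yaml")

# Print top-level keys where the color variant differs from the base config.
for key in color:
    if key not in base or base[key] != color[key]:
        print(f"{key}: {base.get(key)} -> {color[key]}")
```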
config/diffsketcher-style.yaml ADDED
@@ -0,0 +1,78 @@
1
+ image_size: 224
2
+ path_svg: ~ # if you want to load a svg file and train from it
3
+ mask_object: False # if the target image contains background, it's better to mask it out
4
+ fix_scale: False # if the target image is not squared, it is recommended to fix the scale
5
+
6
+ # train
7
+ num_iter: 2000
8
+ batch_size: 1
9
+ num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
10
+ lr_scheduler: False
11
+ lr_decay_rate: 0.1
12
+ decay_steps: [ 1000, 1500 ]
13
+ lr: 1 # point lr
14
+ color_lr: 0.01
15
+ color_vars_threshold: 0.0 # uncomment the code
16
+ width_lr: 0.1 # stroke width lr
17
+ max_width: 50 # stroke width
18
+
19
+ # stroke attrs
20
+ num_paths: 2500 # number of strokes
21
+ width: 1.5 # init stroke width
22
+ control_points_per_seg: 4
23
+ num_segments: 1
24
+ optim_opacity: True # if True, the stroke opacity is optimized
25
+ optim_width: True # if True, the stroke width is optimized
26
+ optim_rgba: True # if True, the stroke RGBA is optimized
27
+ opacity_delta: 0 # stroke pruning
28
+
29
+ # init strokes
30
+ attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes
31
+ xdog_intersec: False # initialize along the edge, mix XDoG and attn up
32
+ softmax_temp: 0.5 # the temperature of softmax
33
+ cross_attn_res: 16 # cross attn resolution
34
+ self_attn_res: 32 # self-attn resolution
35
+ max_com: 20 # select the number of the self-attn maps
36
+ mean_comp: False # the average of the self-attn maps
37
+ comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map
38
+ attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
39
+ log_cross_attn: False
40
+ u2net_path: "./checkpoint/u2net/u2net.pth"
41
+
42
+ # ldm
43
+ model_id: "sd15" # stable diffusion V1.5
44
+ ldm_speed_up: False
45
+ enable_xformers: True # speed up attn compute
46
+ gradient_checkpoint: False # this slows down the code, but saves GPU VRAM
47
+ token_ind: 1 # the index of CLIP prompt embedding, start from 1, 0 is start token
48
+ use_ddim: True
49
+ num_inference_steps: 100
50
+ guidance_scale: 7.5
51
+
52
+ # ASDS loss
53
+ sds:
54
+ crop_size: 512
55
+ augmentations: "affine"
56
+ guidance_scale: 100
57
+ grad_scale: 1e-6
58
+ t_range: [ 0.05, 0.95 ]
59
+ warmup: 3000
60
+
61
+ # JVSP
62
+ clip:
63
+ model_name: "RN101" # RN101, ViT-L/14
64
+ feats_loss_type: "l2" # clip visual loss type, conv layers
65
+ feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based
66
+ # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based
67
+ fc_loss_weight: 0.1 # clip visual loss, fc layer weight
68
+ augmentations: "affine_norm" # augmentation before clip visual computation, affine_norm_trivial
69
+ num_aug: 4 # num of augmentation before clip visual computation
70
+ vis_loss: 1 # 1 or 0 for use or disable clip visual loss
71
+ text_visual_coeff: 0 # cosine similarity between text and img
72
+ perceptual:
73
+ name: "lpips" # dists
74
+ lpips_net: 'vgg'
75
+ coeff: 0.2
76
+
77
+ style_warmup: 1000 # add style loss after `style_warmup` step
78
+ style_strength: 1 # How strong the style should be. 100 (max) is a lot. 0 (min) is no style.
config/diffsketcher-width.yaml ADDED
@@ -0,0 +1,75 @@
1
+ image_size: 224
2
+ path_svg: ~ # if you want to load a svg file and train from it
3
+ mask_object: False # if the target image contains background, it's better to mask it out
4
+ fix_scale: False # if the target image is not squared, it is recommended to fix the scale
5
+
6
+ # train
7
+ num_iter: 500
8
+ batch_size: 1
9
+ num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
10
+ lr_scheduler: False
11
+ lr_decay_rate: 0.1
12
+ decay_steps: [ 1000, 1500 ]
13
+ lr: 1 # point lr
14
+ color_lr: 0.01
15
+ color_vars_threshold: 0.1
16
+ width_lr: 0.1 # stroke width lr
17
+ max_width: 50 # stroke width
18
+
19
+ # stroke attrs
20
+ num_paths: 128 # number of strokes
21
+ width: 3 # init stroke width
22
+ control_points_per_seg: 4
23
+ num_segments: 1
24
+ optim_opacity: True # if True, the stroke opacity is optimized
25
+ optim_width: True # if True, the stroke width is optimized
26
+ optim_rgba: False # if True, the stroke RGBA is optimized
27
+ opacity_delta: 0 # stroke pruning
28
+
29
+ # init strokes
30
+ attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes
31
+ xdog_intersec: True # initialize along the edge, mix XDoG and attn up
32
+ softmax_temp: 0.5 # the temperature of softmax
33
+ cross_attn_res: 16 # cross attn resolution
34
+ self_attn_res: 32 # self-attn resolution
35
+ max_com: 20 # select the number of the self-attn maps
36
+ mean_comp: False # the average of the self-attn maps
37
+ comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map
38
+ attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
39
+ log_cross_attn: False
40
+ u2net_path: "./checkpoint/u2net/u2net.pth"
41
+
42
+ # ldm
43
+ model_id: "sd15" # stable diffusion V1.5
44
+ ldm_speed_up: False
45
+ enable_xformers: True # speed up attn compute
46
+ gradient_checkpoint: False # this slows down the code, but saves GPU VRAM
47
+ token_ind: 1 # the index of CLIP prompt embedding, start from 1, 0 is start token
48
+ use_ddim: True
49
+ num_inference_steps: 100
50
+ guidance_scale: 7.5
51
+
52
+ # ASDS loss
53
+ sds:
54
+ crop_size: 512
55
+ augmentations: "affine"
56
+ guidance_scale: 100
57
+ grad_scale: 1e-5
58
+ t_range: [ 0.05, 0.95 ]
59
+ warmup: 2000
60
+
61
+ # JVSP
62
+ clip:
63
+ model_name: "RN101" # RN101, ViT-L/14
64
+ feats_loss_type: "l2" # clip visual loss type, conv layers
65
+ feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based
66
+ # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based
67
+ fc_loss_weight: 0.1 # clip visual loss, fc layer weight
68
+ augmentations: "affine_norm" # augmentation before clip visual computation, affine_norm_trivial
69
+ num_aug: 4 # num of augmentation before clip visual computation
70
+ vis_loss: 1 # 1 or 0 for use or disable clip visual loss
71
+ text_visual_coeff: 0 # cosine similarity between text and img
72
+ perceptual:
73
+ name: "lpips" # dists
74
+ lpips_net: 'vgg'
75
+ coeff: 0.2
config/diffsketcher.yaml ADDED
@@ -0,0 +1,76 @@
1
+ image_size: 224
2
+ path_svg: ~ # if you want to load a svg file and train from it
3
+ mask_object: False # if the target image contains background, it's better to mask it out
4
+ fix_scale: False # if the target image is not squared, it is recommended to fix the scale
5
+
6
+ # train
7
+ num_iter: 2000
8
+ batch_size: 1
9
+ num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
10
+ lr_scheduler: False
11
+ lr_decay_rate: 0.1
12
+ decay_steps: [ 1000, 1500 ]
13
+ lr: 1 # point lr
14
+ color_lr: 0.01
15
+ pruning_freq: 50
16
+ color_vars_threshold: 0.1
17
+ width_lr: 0.1
18
+ max_width: 50 # stroke width
19
+
20
+ # stroke attrs
21
+ num_paths: 128 # number of strokes
22
+ width: 1.5 # stroke width
23
+ control_points_per_seg: 4
24
+ num_segments: 1
25
+ optim_opacity: True # if True, the stroke opacity is optimized
26
+ optim_width: False # if True, the stroke width is optimized
27
+ optim_rgba: False # if True, the stroke RGBA is optimized
28
+ opacity_delta: 0 # stroke pruning
29
+
30
+ # init strokes
31
+ attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes
32
+ xdog_intersec: True # initialize along the edge, mix XDoG and attn up
33
+ softmax_temp: 0.5 # the temperature of softmax
34
+ cross_attn_res: 16 # cross attn resolution
35
+ self_attn_res: 32 # self-attn resolution
36
+ max_com: 20 # select the number of the self-attn maps
37
+ mean_comp: False # the average of the self-attn maps
38
+ comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map
39
+ attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
40
+ log_cross_attn: False # True if cross attn every step
41
+ u2net_path: "./checkpoint/u2net/u2net.pth"
42
+
43
+ # ldm
44
+ model_id: "sd15" # stable diffusion V1.5
45
+ ldm_speed_up: False
46
+ enable_xformers: True # speed up attn compute
47
+ gradient_checkpoint: False # this slows down the code, but saves GPU VRAM
48
+ token_ind: 1 # the index of CLIP prompt embedding, start from 1, 0 is start token
49
+ use_ddim: True
50
+ num_inference_steps: 100
51
+ guidance_scale: 7.5 # sdxl default 5.0
52
+
53
+ # ASDS loss
54
+ sds:
55
+ crop_size: 512
56
+ augmentations: "affine"
57
+ guidance_scale: 100
58
+ grad_scale: 1e-6
59
+ t_range: [ 0.05, 0.95 ]
60
+ warmup: 2000
61
+
62
+ # JVSP
63
+ clip:
64
+ model_name: "RN101" # RN101, ViT-L/14
65
+ feats_loss_type: "l2" # clip visual loss type, conv layers
66
+ feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based
67
+ # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based
68
+ fc_loss_weight: 0.1 # clip visual loss, fc layer weight
69
+ augmentations: "affine" # augmentation before clip visual computation
70
+ num_aug: 4 # num of augmentation before clip visual computation
71
+ vis_loss: 1 # 1 or 0 for use or disable clip visual loss
72
+ text_visual_coeff: 0 # cosine similarity between text and img
73
+ perceptual:
74
+ name: "lpips" # dists, lpips
75
+ lpips_net: 'vgg'
76
+ coeff: 0.2
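This base config (like the variants above) is consumed through OmegaConf. As a rough sketch mirroring what merge_and_update_config in libs/engine/config_processor.py (shown later in this diff) does with its dotlist updates, the file can be loaded and individual keys overridden like so; the override values are illustrative only.

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/diffsketcher.yaml")

# Override a few keys the same way a dotlist update would.
overrides = OmegaConf.from_dotlist(["num_paths=48", "num_iter=500", "token_ind=4"])
cfg = OmegaConf.merge(cfg, overrides)

print(cfg.num_paths, cfg.num_iter, cfg.sds.guidance_scale)  # 48 500 100
```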
handler.py CHANGED
@@ -1,137 +1,158 @@
1
  import os
2
- import io
3
  import sys
4
  import torch
5
- import numpy as np
 
6
  from PIL import Image
7
- import traceback
 
 
8
  import json
9
- import logging
10
- import base64
11
 
12
- # Configure logging
13
- logging.basicConfig(level=logging.INFO,
14
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
15
- logger = logging.getLogger(__name__)
16
 
17
- # Safely import cairosvg with fallback
18
  try:
19
- import cairosvg
20
- logger.info("Successfully imported cairosvg")
21
- except ImportError:
22
- logger.warning("cairosvg not found. Installing...")
23
- import subprocess
24
- subprocess.check_call(["pip", "install", "cairosvg"])
25
- import cairosvg
26
- logger.info("Successfully installed and imported cairosvg")
27
 
28
  class EndpointHandler:
29
- def __init__(self, model_dir):
30
- """Initialize the handler with model directory"""
31
- logger.info(f"Initializing handler with model_dir: {model_dir}")
32
- self.model_dir = model_dir
33
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
- logger.info(f"Using device: {self.device}")
35
 
36
- # Initialize the model
37
- logger.info("Initializing DiffSketcher model...")
38
- self._initialize_model()
39
- logger.info("DiffSketcher model initialized")
40
-
41
- def _initialize_model(self):
42
- """Initialize the DiffSketcher model"""
43
- # This is a simplified initialization that doesn't rely on external imports
44
- logger.info("Using simplified model initialization")
45
 
46
- # Add the current directory to the path
47
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
48
 
49
- # Try to import CLIP
50
  try:
51
- import clip
52
- logger.info("Successfully imported CLIP")
53
- except ImportError:
54
- logger.warning("CLIP not found. Installing...")
55
- subprocess.check_call(["pip", "install", "git+https://github.com/openai/CLIP.git"])
56
- import clip
57
- logger.info("Successfully installed and imported CLIP")
 
 
58
 
59
- # Try to import diffvg
60
  try:
61
- import diffvg
62
- logger.info("Successfully imported diffvg")
63
- except ImportError:
64
- logger.warning("diffvg not found. Using placeholder implementation")
65
-
66
- def generate_svg(self, prompt, width=512, height=512, num_paths=512, seed=None):
67
- """Generate an SVG from a text prompt"""
68
- logger.info(f"Generating SVG for prompt: {prompt}")
69
-
70
- # Set a seed for reproducibility
71
- if seed is not None:
72
- torch.manual_seed(seed)
73
- np.random.seed(seed)
74
 
75
- # Create a simple SVG with the prompt text
76
- # In a real implementation, this would use the DiffSketcher model
77
- svg_content = f'''<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
78
- <rect width="100%" height="100%" fill="#f0f0f0"/>
79
- <text x="50%" y="50%" dominant-baseline="middle" text-anchor="middle" font-size="20" fill="#333">{prompt}</text>
80
- <text x="50%" y="70%" dominant-baseline="middle" text-anchor="middle" font-size="14" fill="#666">DiffSketcher placeholder output</text>
81
- </svg>'''
82
 
83
- return svg_content
84
-
85
- def __call__(self, data):
86
- """Handle a request to the model"""
87
  try:
88
- logger.info(f"Handling request with data: {data}")
89
-
90
- # Extract the prompt and parameters
91
- if isinstance(data, dict):
92
- if "inputs" in data:
93
- if isinstance(data["inputs"], str):
94
- prompt = data["inputs"]
95
- params = {}
96
- elif isinstance(data["inputs"], dict):
97
- prompt = data["inputs"].get("text", "No prompt provided")
98
- params = {k: v for k, v in data["inputs"].items() if k != "text"}
99
- else:
100
- prompt = "No prompt provided"
101
- params = {}
102
- else:
103
- prompt = "No prompt provided"
104
- params = {}
105
- else:
106
- prompt = "No prompt provided"
107
- params = {}
108
 
109
- logger.info(f"Extracted prompt: {prompt}")
110
- logger.info(f"Extracted parameters: {params}")
111
 
112
  # Extract parameters
113
- width = int(params.get("width", 512))
114
- height = int(params.get("height", 512))
115
- num_paths = int(params.get("num_paths", 512))
116
- seed = params.get("seed", None)
117
- if seed is not None:
118
- seed = int(seed)
119
 
120
- # Generate SVG
121
- svg_content = self.generate_svg(prompt, width, height, num_paths, seed)
122
- logger.info("SVG content generated")
123
 
124
- # Convert SVG to PNG
125
- logger.info("Converting SVG to PNG")
126
- png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
127
- image = Image.open(io.BytesIO(png_data))
128
- logger.info(f"Converted to PNG with size: {image.size}")
129
-
130
- # Return the image
131
- return image
132
  except Exception as e:
133
- logger.error(f"Error in handler: {e}")
134
- logger.error(traceback.format_exc())
135
- # Return an error image
136
- error_image = Image.new('RGB', (512, 512), color='red')
137
- return error_image
1
  import os
 
2
  import sys
3
  import torch
4
+ import base64
5
+ import io
6
  from PIL import Image
7
+ import tempfile
8
+ import shutil
9
+ from typing import Dict, Any, List
10
  import json
 
 
11
 
12
+ # Add current directory to path for imports
13
+ current_dir = os.path.dirname(os.path.abspath(__file__))
14
+ sys.path.insert(0, current_dir)
 
15
 
 
16
  try:
17
+ import pydiffvg
18
+ from diffusers import StableDiffusionPipeline
19
+ from omegaconf import OmegaConf
20
+ DEPENDENCIES_AVAILABLE = True
21
+ except ImportError as e:
22
+ print(f"Warning: Some dependencies not available: {e}")
23
+ DEPENDENCIES_AVAILABLE = False
24
+
25
 
26
  class EndpointHandler:
27
+ def __init__(self, path=""):
28
+ """
29
+ Initialize the handler for DiffSketcher model.
30
+ """
31
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
32
 
33
+ if not DEPENDENCIES_AVAILABLE:
34
+ print("Warning: Dependencies not available, handler will return mock responses")
35
+ return
 
 
 
 
 
 
36
 
37
+ # Create a minimal config
38
+ self.cfg = OmegaConf.create({
39
+ 'method': 'diffsketcher',
40
+ 'num_paths': 96,
41
+ 'num_iter': 500,
42
+ 'token_ind': 4,
43
+ 'guidance_scale': 7.5,
44
+ 'diffuser': {
45
+ 'model_id': 'stabilityai/stable-diffusion-2-1-base',
46
+ 'download': True
47
+ },
48
+ 'painter': {
49
+ 'canvas_size': 224,
50
+ 'lr_scheduler': True,
51
+ 'lr': 0.01
52
+ }
53
+ })
54
 
55
+ # Initialize the diffusion pipeline
56
  try:
57
+ self.pipe = StableDiffusionPipeline.from_pretrained(
58
+ self.cfg.diffuser.model_id,
59
+ torch_dtype=torch.float32,
60
+ safety_checker=None,
61
+ requires_safety_checker=False
62
+ ).to(self.device)
63
+ except Exception as e:
64
+ print(f"Warning: Could not load diffusion model: {e}")
65
+ self.pipe = None
66
 
67
+ # Set up pydiffvg
68
  try:
69
+ pydiffvg.set_print_timing(False)
70
+ pydiffvg.set_device(self.device)
71
+ except Exception as e:
72
+ print(f"Warning: Could not initialize pydiffvg: {e}")
73
+
74
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
75
+ """
76
+ Process the input data and return the generated SVG.
 
 
 
 
 
77
 
78
+ Args:
79
+ data: Dictionary containing:
80
+ - inputs: Text prompt for SVG generation
81
+ - parameters: Optional parameters like num_paths, num_iter, etc.
 
 
 
82
 
83
+ Returns:
84
+ List containing the generated SVG as base64 encoded string
85
+ """
 
86
  try:
87
+ # Extract inputs
88
+ prompt = data.get("inputs", "")
89
+ if not prompt:
90
+ return [{"error": "No prompt provided"}]
91
 
92
+ # If dependencies aren't available, return a mock response
93
+ if not DEPENDENCIES_AVAILABLE:
94
+ mock_svg = f'''<svg width="224" height="224" xmlns="http://www.w3.org/2000/svg">
95
+ <rect width="224" height="224" fill="white"/>
96
+ <text x="112" y="112" text-anchor="middle" font-family="Arial" font-size="12" fill="black">
97
+ Mock SVG for: {prompt}
98
+ </text>
99
+ </svg>'''
100
+ return [{
101
+ "svg": mock_svg,
102
+ "svg_base64": base64.b64encode(mock_svg.encode()).decode(),
103
+ "prompt": prompt,
104
+ "status": "mock_response",
105
+ "message": "This is a mock response. Full model not available."
106
+ }]
107
 
108
  # Extract parameters
109
+ parameters = data.get("parameters", {})
110
+ num_paths = parameters.get("num_paths", self.cfg.num_paths)
111
+ num_iter = parameters.get("num_iter", self.cfg.num_iter)
112
+ token_ind = parameters.get("token_ind", self.cfg.token_ind)
113
+ guidance_scale = parameters.get("guidance_scale", self.cfg.guidance_scale)
114
+ canvas_size = parameters.get("canvas_size", self.cfg.painter.canvas_size)
115
 
116
+ # For now, return a simple SVG since the full implementation requires
117
+ # the complete DiffSketcher pipeline which is complex to set up
118
+ simple_svg = f'''<svg width="{canvas_size}" height="{canvas_size}" xmlns="http://www.w3.org/2000/svg">
119
+ <rect width="{canvas_size}" height="{canvas_size}" fill="white"/>
120
+ <circle cx="{canvas_size//2}" cy="{canvas_size//2}" r="{canvas_size//4}"
121
+ fill="none" stroke="black" stroke-width="2"/>
122
+ <text x="{canvas_size//2}" y="{canvas_size//2}" text-anchor="middle"
123
+ font-family="Arial" font-size="14" fill="black">
124
+ {prompt[:20]}...
125
+ </text>
126
+ </svg>'''
127
 
128
+ return [{
129
+ "svg": simple_svg,
130
+ "svg_base64": base64.b64encode(simple_svg.encode()).decode(),
131
+ "prompt": prompt,
132
+ "parameters": {
133
+ "num_paths": num_paths,
134
+ "num_iter": num_iter,
135
+ "token_ind": token_ind,
136
+ "guidance_scale": guidance_scale,
137
+ "canvas_size": canvas_size
138
+ },
139
+ "status": "simplified_response",
140
+ "message": "Simplified SVG generated. Full DiffSketcher pipeline requires additional setup."
141
+ }]
142
+
143
  except Exception as e:
144
+ return [{"error": f"Error during SVG generation: {str(e)}"}]
145
+
146
+
147
+ # For testing
148
+ if __name__ == "__main__":
149
+ handler = EndpointHandler()
150
+ test_data = {
151
+ "inputs": "a beautiful mountain landscape",
152
+ "parameters": {
153
+ "num_paths": 48,
154
+ "num_iter": 100
155
+ }
156
+ }
157
+ result = handler(test_data)
158
+ print(result)
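Complementing the __main__ smoke test above, here is a short hedged sketch of how a caller might persist what the handler returns; the file names are arbitrary and the import assumes the script runs from the repository root where handler.py lives.

```python
import base64
from handler import EndpointHandler  # the module committed above

handler = EndpointHandler()
result = handler({"inputs": "a beautiful mountain landscape",
                  "parameters": {"num_paths": 48, "num_iter": 100}})
entry = result[0]

if "error" in entry:
    raise RuntimeError(entry["error"])

# The response carries the SVG both as raw markup and as a base64 copy.
with open("output.svg", "w", encoding="utf-8") as f:
    f.write(entry["svg"])
with open("output_b64.svg", "wb") as f:
    f.write(base64.b64decode(entry["svg_base64"]))

print(entry["status"], entry.get("message", ""))
```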
libs/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description: a self-consistent system,
5
+ # including runner, trainer, loss function, EMA, optimizer, lr scheduler, and common utils.
6
+
7
+ from .utils import lazy
8
+
9
+ __getattr__, __dir__, __all__ = lazy.attach(
10
+ __name__,
11
+ submodules={'engine', 'metric', 'modules', 'solver', 'utils'},
12
+ submod_attrs={}
13
+ )
14
+
15
+ __version__ = '0.0.1'
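libs.utils.lazy is not among the 50 files shown, but the attach(...) call above matches the usual lazy-import convention (as in the scientific-python lazy_loader package), so the intended effect is presumably that submodules are imported only on first attribute access. A hedged illustration of that assumed behavior:

```python
import libs

print(libs.__version__)   # '0.0.1'; no heavy submodule has been imported yet

# Accessing an attached submodule triggers its import through the module-level
# __getattr__ installed by lazy.attach (assumed behavior, not shown in this diff).
engine = libs.engine
print(engine.ModelState)
```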
libs/engine/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+ from .model_state import ModelState
7
+ from .config_processor import merge_and_update_config
8
+
9
+ __all__ = [
10
+ 'ModelState',
11
+ 'merge_and_update_config'
12
+ ]
libs/engine/config_processor.py ADDED
@@ -0,0 +1,156 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+ import os
7
+ from typing import Tuple
8
+ from functools import reduce
9
+
10
+ from argparse import Namespace
11
+ from omegaconf import DictConfig, OmegaConf
12
+
13
+
14
+ #################################################################################
15
+ # merge yaml and argparse #
16
+ #################################################################################
17
+
18
+ def register_resolver():
19
+ OmegaConf.register_new_resolver(
20
+ "add", lambda *numbers: sum(numbers)
21
+ )
22
+ OmegaConf.register_new_resolver(
23
+ "multiply", lambda *numbers: reduce(lambda x, y: x * y, numbers)
24
+ )
25
+ OmegaConf.register_new_resolver(
26
+ "sub", lambda n1, n2: n1 - n2
27
+ )
28
+
29
+
30
+ def _merge_args_and_config(
31
+ cmd_args: Namespace,
32
+ yaml_config: DictConfig,
33
+ read_only: bool = False
34
+ ) -> Tuple[DictConfig, DictConfig, DictConfig]:
35
+ # convert cmd line args to OmegaConf
36
+ cmd_args_dict = vars(cmd_args)
37
+ cmd_args_list = []
38
+ for k, v in cmd_args_dict.items():
39
+ cmd_args_list.append(f"{k}={v}")
40
+ cmd_args_conf = OmegaConf.from_cli(cmd_args_list)
41
+
42
+ # The following overrides the previous configuration
43
+ # cmd_args_list > configs
44
+ args_ = OmegaConf.merge(yaml_config, cmd_args_conf)
45
+
46
+ if read_only:
47
+ OmegaConf.set_readonly(args_, True)
48
+
49
+ return args_, cmd_args_conf, yaml_config
50
+
51
+
52
+ def merge_configs(args, method_cfg_path):
53
+ """merge command line args (argparse) and config file (OmegaConf)"""
54
+ yaml_config_path = os.path.join("./", "config", method_cfg_path)
55
+ try:
56
+ yaml_config = OmegaConf.load(yaml_config_path)
57
+ except FileNotFoundError as e:
58
+ print(f"error: {e}")
59
+ print(f"input file path: `{method_cfg_path}`")
60
+ print(f"config path: `{yaml_config_path}` not found.")
61
+ raise FileNotFoundError(e)
62
+ return _merge_args_and_config(args, yaml_config, read_only=False)
63
+
64
+
65
+ def update_configs(source_args, update_nodes, strict=True, remove_update_nodes=True):
66
+ """update config file (OmegaConf) with dotlist"""
67
+ if update_nodes is None:
68
+ return source_args
69
+
70
+ update_args_list = str(update_nodes).split()
71
+ if len(update_args_list) < 1:
72
+ return source_args
73
+
74
+ # check update_args
75
+ for item in update_args_list:
76
+ item_key_ = str(item).split('=')[0] # get key
77
+ # item_val_ = str(item).split('=')[1] # get value
78
+
79
+ if strict:
80
+ # Tests if a key is existing
81
+ # assert OmegaConf.select(source_args, item_key_) is not None, f"{item_key_} is not existing."
82
+
83
+ # Tests if a value is missing
84
+ assert not OmegaConf.is_missing(source_args, item_key_), f"the value of {item_key_} is missing."
85
+
86
+ # if keys is None, then add key and set the value
87
+ if OmegaConf.select(source_args, item_key_) is None:
88
+ source_args.item_key_ = item_key_
89
+
90
+ # update original yaml params
91
+ update_nodes = OmegaConf.from_dotlist(update_args_list)
92
+ merged_args = OmegaConf.merge(source_args, update_nodes)
93
+
94
+ # remove update_args
95
+ if remove_update_nodes:
96
+ OmegaConf.update(merged_args, 'update', '')
97
+ return merged_args
98
+
99
+
100
+ def update_if_exist(source_args, update_nodes):
101
+ """update config file (OmegaConf) with dotlist"""
102
+ if update_nodes is None:
103
+ return source_args
104
+
105
+ upd_args_list = str(update_nodes).split()
106
+ if len(upd_args_list) < 1:
107
+ return source_args
108
+
109
+ update_args_list = []
110
+ for item in upd_args_list:
111
+ item_key_ = str(item).split('=')[0] # get key
112
+
113
+ # if a key is existing
114
+ # if OmegaConf.select(source_args, item_key_) is not None:
115
+ # update_args_list.append(item)
116
+
117
+ update_args_list.append(item)
118
+
119
+ # update source_args if key be selected
120
+ if len(update_args_list) < 1:
121
+ merged_args = source_args
122
+ else:
123
+ update_nodes = OmegaConf.from_dotlist(update_args_list)
124
+ merged_args = OmegaConf.merge(source_args, update_nodes)
125
+
126
+ return merged_args
127
+
128
+
129
+ def merge_and_update_config(args):
130
+ register_resolver()
131
+
132
+ # if yaml_config is existing, then merge command line args and yaml_config
133
+ # if os.path.isfile(args.config) and args.config is not None:
134
+ if args.config is not None and str(args.config).endswith('.yaml'):
135
+ merged_args, cmd_args, yaml_config = merge_configs(args, args.config)
136
+ else:
137
+ merged_args, cmd_args, yaml_config = args, args, None
138
+
139
+ # update the yaml_config with the cmd '-update' flag
140
+ update_nodes = args.update
141
+ final_args = update_configs(merged_args, update_nodes)
142
+
143
+ # to simplify log output, we empty this
144
+ yaml_config_update = update_if_exist(yaml_config, update_nodes)
145
+ cmd_args_update = update_if_exist(cmd_args, update_nodes)
146
+ cmd_args_update.update = "" # clear update params
147
+
148
+ final_args.yaml_config = yaml_config_update
149
+ final_args.cmd_args = cmd_args_update
150
+
151
+ # update seed
152
+ if final_args.seed < 0:
153
+ import random
154
+ final_args.seed = random.randint(0, 65535)
155
+
156
+ return final_args
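A hedged sketch of how merge_and_update_config expects to be driven from argparse. The project's real CLI is not part of this diff, so the Namespace fields below (config, update, seed) are inferred from the attributes the function reads, and the snippet assumes the full libs package (including libs.utils, beyond the 50 files shown) is importable.

```python
from argparse import Namespace
from libs.engine import merge_and_update_config

# Fields inferred from the code above: `config` is a YAML name resolved under
# ./config, `update` is a space-separated dotlist, and a negative `seed`
# requests a random seed. Values are illustrative.
args = Namespace(
    config="diffsketcher.yaml",
    update="num_paths=48 token_ind=4",
    seed=-1,
)

cfg = merge_and_update_config(args)
print(cfg.num_paths, cfg.seed)   # 48, plus a randomly drawn seed in [0, 65535]
```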
libs/engine/model_state.py ADDED
@@ -0,0 +1,339 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+ import gc
6
+ from functools import partial
7
+ from typing import Union, List
8
+ from pathlib import Path
9
+ from datetime import datetime, timedelta
10
+
11
+ from omegaconf import DictConfig
12
+ from pprint import pprint
13
+ import torch
14
+ from accelerate.utils import LoggerType
15
+ from accelerate import (
16
+ Accelerator,
17
+ GradScalerKwargs,
18
+ DistributedDataParallelKwargs,
19
+ InitProcessGroupKwargs
20
+ )
21
+
22
+ from ..modules.ema import EMA
23
+ from ..utils.logging import get_logger
24
+
25
+
26
+ class ModelState:
27
+ """
28
+ Handles logging and Hugging Face Accelerate training
29
+
30
+ features:
31
+ - Mixed Precision
32
+ - Gradient Scaler
33
+ - Gradient Accumulation
34
+ - Optimizer
35
+ - EMA
36
+ - Logger (default: python print)
37
+ - Monitor (default: wandb, tensorboard)
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ args,
43
+ log_path_suffix: str = None,
44
+ ignore_log=False, # whether to create log file or not
45
+ ) -> None:
46
+ self.args: DictConfig = args
47
+
48
+ """check valid"""
49
+ mixed_precision = self.args.get("mixed_precision")
50
+ # Bug: omegaconf convert 'no' to false
51
+ mixed_precision = "no" if type(mixed_precision) == bool else mixed_precision
52
+ split_batches = self.args.get("split_batches", False)
53
+ gradient_accumulate_step = self.args.get("gradient_accumulate_step", 1)
54
+ assert gradient_accumulate_step >= 1, f"expected gradient_accumulate_step >= 1, got {gradient_accumulate_step}"
55
+
56
+ """create working space"""
57
+ # rule: ['./config'. 'method_name', 'exp_name.yaml']
58
+ # -> results_path: ./runs/{method_name}-{exp_name}, as a base folder
59
+ # config_prefix, config_name = str(self.args.get("config")).split('/')
60
+ # config_name_only = str(config_name).split(".")[0]
61
+
62
+ config_name_only = str(self.args.get("config")).split(".")[0]
63
+ results_folder = self.args.get("results_path", None)
64
+ if results_folder is None:
65
+ # self.results_path = Path("./workdir") / f"{config_prefix}-{config_name_only}"
66
+ self.results_path = Path("./workdir") / f"{config_name_only}"
67
+ else:
68
+ # self.results_path = Path(results_folder) / f"{config_prefix}-{config_name_only}"
69
+ self.results_path = Path(results_folder) / f"{config_name_only}"
70
+
71
+ # update results_path: ./runs/{method_name}-{exp_name}/{log_path_suffix}
72
+ # noting: can be understood as "results dir / methods / ablation study / your result"
73
+ if log_path_suffix is not None:
74
+ self.results_path = self.results_path / log_path_suffix
75
+
76
+ kwargs_handlers = []
77
+ """mixed precision training"""
78
+ if args.mixed_precision == "no":
79
+ scaler_handler = GradScalerKwargs(
80
+ init_scale=args.init_scale,
81
+ growth_factor=args.growth_factor,
82
+ backoff_factor=args.backoff_factor,
83
+ growth_interval=args.growth_interval,
84
+ enabled=True
85
+ )
86
+ kwargs_handlers.append(scaler_handler)
87
+
88
+ """distributed training"""
89
+ ddp_handler = DistributedDataParallelKwargs(
90
+ dim=0,
91
+ broadcast_buffers=True,
92
+ static_graph=False,
93
+ bucket_cap_mb=25,
94
+ find_unused_parameters=False,
95
+ check_reduction=False,
96
+ gradient_as_bucket_view=False
97
+ )
98
+ kwargs_handlers.append(ddp_handler)
99
+
100
+ init_handler = InitProcessGroupKwargs(timeout=timedelta(seconds=1200))
101
+ kwargs_handlers.append(init_handler)
102
+
103
+ """init visualized tracker"""
104
+ log_with = []
105
+ self.args.visual = False
106
+ if args.use_wandb:
107
+ log_with.append(LoggerType.WANDB)
108
+ if args.tensorboard:
109
+ log_with.append(LoggerType.TENSORBOARD)
110
+
111
+ """hugging face Accelerator"""
112
+ self.accelerator = Accelerator(
113
+ device_placement=True,
114
+ split_batches=split_batches,
115
+ mixed_precision=mixed_precision,
116
+ gradient_accumulation_steps=args.gradient_accumulate_step,
117
+ cpu=True if args.use_cpu else False,
118
+ log_with=None if len(log_with) == 0 else log_with,
119
+ project_dir=self.results_path / "vis",
120
+ kwargs_handlers=kwargs_handlers,
121
+ )
122
+
123
+ """logs"""
124
+ if self.accelerator.is_local_main_process:
125
+ # for logging results in a folder periodically
126
+ self.results_path.mkdir(parents=True, exist_ok=True)
127
+ if not ignore_log:
128
+ now_time = datetime.now().strftime('%Y-%m-%d-%H-%M')
129
+ self.logger = get_logger(
130
+ logs_dir=self.results_path.as_posix(),
131
+ file_name=f"{now_time}-log-{args.seed}.txt"
132
+ )
133
+
134
+ print("==> command line args: ")
135
+ print(args.cmd_args)
136
+ print("==> yaml config args: ")
137
+ print(args.yaml_config)
138
+
139
+ print("\n***** Model State *****")
140
+ if self.accelerator.distributed_type != "NO":
141
+ print(f"-> Distributed Type: {self.accelerator.distributed_type}")
142
+ # print(f"-> Split Batch Size: {split_batches}, Total Batch Size: {self.actual_batch_size}")
143
+ print(f"-> Mixed Precision: {mixed_precision}, AMP: {self.accelerator.native_amp},"
144
+ f" Gradient Accumulate Step: {gradient_accumulate_step}")
145
+ print(f"-> Weight dtype: {self.weight_dtype}")
146
+
147
+ if self.accelerator.scaler_handler is not None and self.accelerator.scaler_handler.enabled:
148
+ print(f"-> Enabled GradScaler: {self.accelerator.scaler_handler.to_kwargs()}")
149
+
150
+ if args.use_wandb:
151
+ print(f"-> Init trackers: 'wandb' ")
152
+ self.args.visual = True
153
+ self.__init_tracker(project_name="my_project", tags=None, entity="")
154
+
155
+ print(f"-> Working Space: '{self.results_path}'")
156
+
157
+ """EMA"""
158
+ self.use_ema = args.get('ema', False)
159
+ self.ema_wrapper = self.__build_ema_wrapper()
160
+
161
+ """glob step"""
162
+ self.step = 0
163
+
164
+ """log process"""
165
+ self.accelerator.wait_for_everyone()
166
+ print(f'Process {self.accelerator.process_index} using device: {self.accelerator.device}')
167
+
168
+ self.print("-> state initialization complete \n")
169
+
170
+ def __init_tracker(self, project_name, tags, entity):
171
+ self.accelerator.init_trackers(
172
+ project_name=project_name,
173
+ config=dict(self.args),
174
+ init_kwargs={
175
+ "wandb": {
176
+ "notes": "accelerate trainer pipeline",
177
+ "tags": [
178
+ f"total batch_size: {self.actual_batch_size}"
179
+ ],
180
+ "entity": entity,
181
+ }}
182
+ )
183
+
184
+ def __build_ema_wrapper(self):
185
+ if self.use_ema:
186
+ self.print(f"-> EMA: {self.use_ema}, decay: {self.args.ema_decay}, "
187
+ f"update_after_step: {self.args.ema_update_after_step}, "
188
+ f"update_every: {self.args.ema_update_every}")
189
+ ema_wrapper = partial(
190
+ EMA, beta=self.args.ema_decay,
191
+ update_after_step=self.args.ema_update_after_step,
192
+ update_every=self.args.ema_update_every
193
+ )
194
+ else:
195
+ ema_wrapper = None
196
+
197
+ return ema_wrapper
198
+
199
+ @property
200
+ def device(self):
201
+ return self.accelerator.device
202
+
203
+ @property
204
+ def weight_dtype(self):
205
+ weight_dtype = torch.float32
206
+ if self.accelerator.mixed_precision == "fp16":
207
+ weight_dtype = torch.float16
208
+ elif self.accelerator.mixed_precision == "bf16":
209
+ weight_dtype = torch.bfloat16
210
+ return weight_dtype
211
+
212
+ @property
213
+ def actual_batch_size(self):
214
+ if self.accelerator.split_batches is False:
215
+ actual_batch_size = self.args.batch_size * self.accelerator.num_processes * self.accelerator.gradient_accumulation_steps
216
+ else:
217
+ assert self.actual_batch_size % self.accelerator.num_processes == 0
218
+ actual_batch_size = self.args.batch_size
219
+ return actual_batch_size
220
+
221
+ @property
222
+ def n_gpus(self):
223
+ return self.accelerator.num_processes
224
+
225
+ @property
226
+ def no_decay_params_names(self):
227
+ no_decay = [
228
+ "bn", "LayerNorm", "GroupNorm",
229
+ ]
230
+ return no_decay
231
+
232
+ def no_decay_params(self, model, weight_decay):
233
+ """optimization tricks"""
234
+ optimizer_grouped_parameters = [
235
+ {
236
+ "params": [
237
+ p for n, p in model.named_parameters()
238
+ if not any(nd in n for nd in self.no_decay_params_names)
239
+ ],
240
+ "weight_decay": weight_decay,
241
+ },
242
+ {
243
+ "params": [
244
+ p for n, p in model.named_parameters()
245
+ if any(nd in n for nd in self.no_decay_params_names)
246
+ ],
247
+ "weight_decay": 0.0,
248
+ },
249
+ ]
250
+ return optimizer_grouped_parameters
251
+
252
+ def optimized_params(self, model: torch.nn.Module, verbose=True) -> List:
253
+ """return parameters if `requires_grad` is True
254
+
255
+ Args:
256
+ model: pytorch models
257
+ verbose: log optimized parameters
258
+
259
+ Examples:
260
+ >>> self.params_optimized = self.optimized_params(uvit, verbose=True)
261
+ >>> optimizer = torch.optim.AdamW(self.params_optimized, lr=args.lr)
262
+
263
+ Returns:
264
+ a list of parameters
265
+ """
266
+ params_optimized = []
267
+ for key, value in model.named_parameters():
268
+ if value.requires_grad:
269
+ params_optimized.append(value)
270
+ if verbose:
271
+ self.print("\t {}, {}, {}".format(key, value.numel(), value.shape))
272
+ return params_optimized
273
+
274
+ def save_everything(self, fpath: str):
275
+ """Saving and loading the model, optimizer, RNG generators, and the GradScaler."""
276
+ if not self.accelerator.is_main_process:
277
+ return
278
+ self.accelerator.save_state(fpath)
279
+
280
+ def load_save_everything(self, fpath: str):
281
+ """Loading the model, optimizer, RNG generators, and the GradScaler."""
282
+ self.accelerator.load_state(fpath)
283
+
284
+ def save(self, milestone: Union[str, float, int], checkpoint: object) -> None:
285
+ if not self.accelerator.is_main_process:
286
+ return
287
+
288
+ torch.save(checkpoint, self.results_path / f'model-{milestone}.pt')
289
+
290
+ def save_in(self, root: Union[str, Path], checkpoint: object) -> None:
291
+ if not self.accelerator.is_main_process:
292
+ return
293
+
294
+ torch.save(checkpoint, root)
295
+
296
+ def load_ckpt_model_only(self, model: torch.nn.Module, path: Union[str, Path], rm_module_prefix: bool = False):
297
+ ckpt = torch.load(path, map_location=self.accelerator.device)
298
+
299
+ unwrapped_model = self.accelerator.unwrap_model(model)
300
+ if rm_module_prefix:
301
+ unwrapped_model.load_state_dict({k.replace('module.', ''): v for k, v in ckpt.items()})
302
+ else:
303
+ unwrapped_model.load_state_dict(ckpt)
304
+ return unwrapped_model
305
+
306
+ def load_shared_weights(self, model: torch.nn.Module, path: Union[str, Path]):
307
+ ckpt = torch.load(path, map_location=self.accelerator.device)
308
+ self.print(f"pretrained_dict len: {len(ckpt)}")
309
+ unwrapped_model = self.accelerator.unwrap_model(model)
310
+ model_dict = unwrapped_model.state_dict()
311
+ pretrained_dict = {k: v for k, v in ckpt.items() if k in model_dict}
312
+ model_dict.update(pretrained_dict)
313
+ unwrapped_model.load_state_dict(model_dict, strict=False)
314
+ self.print(f"selected pretrained_dict: {len(model_dict)}")
315
+ return unwrapped_model
316
+
317
+ def print(self, *args, **kwargs):
318
+ """Use in replacement of `print()` to only print once per server."""
319
+ self.accelerator.print(*args, **kwargs)
320
+
321
+ def pretty_print(self, msg):
322
+ if self.accelerator.is_local_main_process:
323
+ pprint(dict(msg))
324
+
325
+ def close_tracker(self):
326
+ self.accelerator.end_training()
327
+
328
+ def free_memory(self):
329
+ self.accelerator.clear()
330
+
331
+ def close(self, msg: str = "Training complete."):
332
+ """Use in end of training."""
333
+ self.free_memory()
334
+
335
+ if torch.cuda.is_available():
336
+ self.print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')
337
+ if self.args.visual:
338
+ self.close_tracker()
339
+ self.print(msg)
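ModelState reads a fairly large set of fields off the args object; those normally come from the project's CLI plus the merged YAML config, neither of which is fully shown here. The snippet below is a minimal sketch of instantiating it directly, with every field name inferred from the attribute accesses in __init__ above, purely illustrative values, and the assumption that libs.utils.logging (outside the 50 files listed) is available.

```python
from omegaconf import OmegaConf
from libs.engine import ModelState

# Minimal field set inferred from ModelState.__init__; values are illustrative.
args = OmegaConf.create({
    "config": "diffsketcher.yaml",
    "results_path": "./workdir",
    "seed": 42,
    "batch_size": 1,
    "mixed_precision": "no",
    # GradScaler settings read when mixed_precision == "no"
    "init_scale": 65536, "growth_factor": 2.0,
    "backoff_factor": 0.5, "growth_interval": 2000,
    "split_batches": False,
    "gradient_accumulate_step": 1,
    "use_cpu": True,
    "use_wandb": False,
    "tensorboard": False,
    "ema": False,
    "cmd_args": {}, "yaml_config": {},
})

state = ModelState(args, log_path_suffix="demo")
state.print(f"device: {state.device}, weight dtype: {state.weight_dtype}")
state.close("done")
```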
libs/metric/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
libs/metric/accuracy.py ADDED
@@ -0,0 +1,31 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+
7
+ def accuracy(output, target, topk=(1,)):
8
+ """
9
+ Computes the accuracy over the k top predictions for the specified values of k.
10
+
11
+ Args
12
+ output: logits or probs (num of batch, num of classes)
13
+ target: (num of batch, 1) or (num of batch, )
14
+ topk: list of returned k
15
+
16
+ refer: https://github.com/pytorch/examples/blob/master/imagenet/main.py
17
+ """
18
+ maxK = max(topk) # get k in top-k
19
+ batch_size = target.size(0)
20
+
21
+ _, pred = output.topk(k=maxK, dim=1, largest=True, sorted=True) # pred: [num of batch, k]
22
+ pred = pred.t() # pred: [k, num of batch]
23
+
24
+ # [1, num of batch] -> [k, num_of_batch] : bool
25
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
26
+
27
+ res = []
28
+ for k in topk:
29
+ correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
30
+ res.append(correct_k.mul_(100.0 / batch_size))
31
+ return res # np.shape(res): [k, 1]
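Since accuracy() returns one tensor per requested k, here is a small usage sketch with random tensors (purely illustrative):

```python
import torch
from libs.metric.accuracy import accuracy

# 8 samples, 10 classes: random logits and integer labels.
logits = torch.randn(8, 10)
target = torch.randint(0, 10, (8,))

top1, top5 = accuracy(logits, target, topk=(1, 5))
print(f"top-1: {top1.item():.2f}%  top-5: {top5.item():.2f}%")
```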
libs/metric/clip_score/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+ from .openaiCLIP_loss import CLIPScoreWrapper
7
+
8
+ __all__ = ['CLIPScoreWrapper']
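Based on the constructor and methods visible in openaiCLIP_loss.py below, the following is a hedged sketch of scoring an image against a prompt in the default text-only configuration. It assumes OpenAI's clip package is installed and that the parts of the file truncated from this view do not require extra setup for this code path.

```python
import torch
from libs.metric.clip_score import CLIPScoreWrapper

# Wrap an OpenAI CLIP backbone (weights are downloaded on first use).
scorer = CLIPScoreWrapper(clip_model_name="RN101", device="cpu")

# Dummy image batch in [0, 1]; normalize() applies CLIP's resize/crop/normalize transforms.
image = torch.rand(1, 3, 224, 224)
distance = scorer.compute_text_visual_distance(scorer.normalize(image), "a sketch of a cat")
print(distance.item())  # negative cosine similarity: lower means a better text-image match
```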
libs/metric/clip_score/openaiCLIP_loss.py ADDED
@@ -0,0 +1,305 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+ from typing import Union, List, Tuple
7
+ from collections import OrderedDict
8
+ from functools import partial
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torchvision.transforms as transforms
14
+
15
+
16
+ class CLIPScoreWrapper(nn.Module):
17
+
18
+ def __init__(self,
19
+ clip_model_name: str,
20
+ download_root: str = None,
21
+ device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
22
+ jit: bool = False,
23
+ # additional params
24
+ visual_score: bool = False,
25
+ feats_loss_type: str = None,
26
+ feats_loss_weights: List[float] = None,
27
+ fc_loss_weight: float = None,
28
+ context_length: int = 77):
29
+ super().__init__()
30
+
31
+ import clip # local import
32
+
33
+ # check model info
34
+ self.clip_model_name = clip_model_name
35
+ self.device = device
36
+ self.available_models = clip.available_models()
37
+ assert clip_model_name in self.available_models, f"CLIP backbone {clip_model_name} does not exist"
38
+
39
+ # load CLIP
40
+ self.model, self.preprocess = clip.load(clip_model_name, device, jit=jit, download_root=download_root)
41
+ self.model.eval()
42
+
43
+ # load tokenize
44
+ self.tokenize_fn = partial(clip.tokenize, context_length=context_length)
45
+
46
+ # load CLIP visual
47
+ self.visual_encoder = VisualEncoderWrapper(self.model, clip_model_name).to(device)
48
+ self.visual_encoder.eval()
49
+
50
+ # check loss weights
51
+ self.visual_score = visual_score
52
+ if visual_score:
53
+ assert feats_loss_type in ["l1", "l2", "cosine"], f"{feats_loss_type} does not exist."
54
+ if clip_model_name.startswith("ViT"): assert len(feats_loss_weights) == 12
55
+ if clip_model_name.startswith("RN"): assert len(feats_loss_weights) == 5
56
+
57
+ # load visual loss wrapper
58
+ self.visual_loss_fn = CLIPVisualLossWrapper(self.visual_encoder, feats_loss_type,
59
+ feats_loss_weights,
60
+ fc_loss_weight)
61
+
62
+ @property
63
+ def input_resolution(self):
64
+ return self.model.visual.input_resolution # default: 224
65
+
66
+ @property
67
+ def resize(self): # Resize only
68
+ return transforms.Compose([self.preprocess.transforms[0]])
69
+
70
+ @property
71
+ def normalize(self):
72
+ return transforms.Compose([
73
+ self.preprocess.transforms[0], # Resize
74
+ self.preprocess.transforms[1], # CenterCrop
75
+ self.preprocess.transforms[-1], # Normalize
76
+ ])
77
+
78
+ @property
79
+ def norm_(self): # Normalize only
80
+ return transforms.Compose([self.preprocess.transforms[-1]])
81
+
82
+ def encode_image_layer_wise(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
83
+ semantic_vec, feature_maps = self.visual_encoder(x)
84
+ return semantic_vec, feature_maps
85
+
86
+ def encode_text(self, text: Union[str, List[str]], norm: bool = True) -> torch.Tensor:
87
+ tokens = self.tokenize_fn(text).to(self.device)
88
+ text_features = self.model.encode_text(tokens)
89
+ if norm:
90
+ text_features = text_features.mean(axis=0, keepdim=True)
91
+ text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
92
+ return text_features_norm
93
+ return text_features
94
+
95
+ def encode_image(self, image: torch.Tensor, norm: bool = True) -> torch.Tensor:
96
+ image_features = self.model.encode_image(image)
97
+ if norm:
98
+ image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
99
+ return image_features_norm
100
+ return image_features
101
+
102
+ @torch.no_grad()
103
+ def predict(self,
104
+ image: torch.Tensor,
105
+ text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]:
106
+ image_features = self.model.encode_image(image)
107
+ text_tokenize = self.tokenize_fn(text).to(self.device)
108
+ text_features = self.model.encode_text(text_tokenize)
109
+ logits_per_image, logits_per_text = self.model(image, text_tokenize)
110
+ probs = logits_per_image.softmax(dim=-1).cpu().numpy()
111
+ return image_features, text_features, probs
112
+
113
+ def compute_text_visual_distance(
114
+ self, image: torch.Tensor, text: Union[str, List[str]]
115
+ ) -> torch.Tensor:
116
+ image_features = self.model.encode_image(image)
117
+ text_tokenize = self.tokenize_fn(text).to(self.device)
118
+ text_features = self.model.encode_text(text_tokenize)
119
+ text_features = text_features.to(self.device)
120
+
121
+ image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
122
+ text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
123
+ loss = - (image_features_norm @ text_features_norm.T)
124
+ return loss.mean()
125
+
126
+ def directional_loss(self, src_text, src_img, tar_text, tar_img, thresh=None):
127
+ # delta img
128
+ img_direction = (tar_img - src_img)
129
+ img_direction_norm = img_direction / img_direction.norm(dim=-1, keepdim=True)
130
+ # # delta text
131
+ text_direction = (1 * tar_text - src_text).repeat(tar_img.size(0), 1)
132
+ text_direction_norm = text_direction / text_direction.norm(dim=-1, keepdim=True)
133
+ # Directional CLIP Loss
134
+ loss_dir = (1 - torch.cosine_similarity(img_direction_norm, text_direction_norm, dim=1))
135
+ if thresh is not None:
136
+ loss_dir[loss_dir < thresh] = 0  # zero out values below the threshold
137
+ loss_dir = loss_dir.mean()
138
+ return loss_dir
139
+ else:
140
+ return loss_dir.mean()
141
+
142
+ def compute_visual_distance(
143
+ self, x: torch.Tensor, y: torch.Tensor, clip_norm: bool = True,
144
+ ) -> Tuple[torch.Tensor, List]:
145
+ # return a fc loss and the list of feat loss
146
+ assert self.visual_score is True
147
+ assert x.shape == y.shape
148
+ assert x.shape[-1] == self.input_resolution and x.shape[-2] == self.input_resolution
149
+ assert y.shape[-1] == self.input_resolution and y.shape[-2] == self.input_resolution
150
+
151
+ if clip_norm:
152
+ return self.visual_loss_fn(self.normalize(x), self.normalize(y))
153
+ else:
154
+ return self.visual_loss_fn(x, y)
155
+
156
+
157
+ class VisualEncoderWrapper(nn.Module):
158
+ """
159
+ Obtains semantic features and layer-by-layer feature maps from the CLIP visual encoder.
160
+ """
161
+
162
+ def __init__(self, clip_model: nn.Module, clip_model_name: str):
163
+ super().__init__()
164
+ self.clip_model = clip_model
165
+ self.clip_model_name = clip_model_name
166
+
167
+ if clip_model_name.startswith("ViT"):
168
+ self.feature_maps = OrderedDict()
169
+ for i in range(12): # 12 ResBlocks in ViT visual transformer
170
+ self.clip_model.visual.transformer.resblocks[i].register_forward_hook(
171
+ self.make_hook(i)
172
+ )
173
+
174
+ if clip_model_name.startswith("RN"):
175
+ layers = list(self.clip_model.visual.children())
176
+ init_layers = torch.nn.Sequential(*layers)[:8]
177
+ self.layer1 = layers[8]
178
+ self.layer2 = layers[9]
179
+ self.layer3 = layers[10]
180
+ self.layer4 = layers[11]
181
+ self.att_pool2d = layers[12]
182
+
183
+ def make_hook(self, name):
184
+ def hook(module, input, output):
185
+ if len(output.shape) == 3:
186
+ # LND -> NLD (B, 77, 768)
187
+ self.feature_maps[name] = output.permute(1, 0, 2)
188
+ else:
189
+ self.feature_maps[name] = output
190
+
191
+ return hook
192
+
193
+ def _forward_vit(self, x: torch.Tensor) -> Tuple[torch.Tensor, List]:
194
+ fc_feature = self.clip_model.encode_image(x).float()
195
+ feature_maps = [self.feature_maps[k] for k in range(12)]
196
+
197
+ # fc_feature len: 1 ,feature_maps len: 12
198
+ return fc_feature, feature_maps
199
+
200
+ def _forward_resnet(self, x: torch.Tensor) -> Tuple[torch.Tensor, List]:
201
+ def stem(m, x):
202
+ for conv, bn, relu in [(m.conv1, m.bn1, m.relu1), (m.conv2, m.bn2, m.relu2), (m.conv3, m.bn3, m.relu3)]:
203
+ x = torch.relu(bn(conv(x)))
204
+ x = m.avgpool(x)
205
+ return x
206
+
207
+ x = x.type(self.clip_model.visual.conv1.weight.dtype)
208
+ x = stem(self.clip_model.visual, x)
209
+ x1 = self.layer1(x)
210
+ x2 = self.layer2(x1)
211
+ x3 = self.layer3(x2)
212
+ x4 = self.layer4(x3)
213
+ y = self.att_pool2d(x4)
214
+
215
+ # fc_features len: 1 ,feature_maps len: 5
216
+ return y, [x, x1, x2, x3, x4]
217
+
218
+ def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]:
219
+ if self.clip_model_name.startswith("ViT"):
220
+ fc_feat, visual_feat_maps = self._forward_vit(x)
221
+ if self.clip_model_name.startswith("RN"):
222
+ fc_feat, visual_feat_maps = self._forward_resnet(x)
223
+
224
+ return fc_feat, visual_feat_maps
225
+
226
+
227
+ class CLIPVisualLossWrapper(nn.Module):
228
+ """
229
+ Visual Feature Loss + FC loss
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ visual_encoder: nn.Module,
235
+ feats_loss_type: str = None,
236
+ feats_loss_weights: List[float] = None,
237
+ fc_loss_weight: float = None,
238
+ ):
239
+ super().__init__()
240
+ self.visual_encoder = visual_encoder
241
+ self.feats_loss_weights = feats_loss_weights
242
+ self.fc_loss_weight = fc_loss_weight
243
+
244
+ self.layer_criterion = layer_wise_distance(feats_loss_type)
245
+
246
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
247
+ x_fc_feature, x_feat_maps = self.visual_encoder(x)
248
+ y_fc_feature, y_feat_maps = self.visual_encoder(y)
249
+
250
+ # visual feature loss
251
+ if sum(self.feats_loss_weights) == 0:
252
+ feats_loss_list = [torch.tensor(0, device=x.device)]
253
+ else:
254
+ feats_loss = self.layer_criterion(x_feat_maps, y_feat_maps, self.visual_encoder.clip_model_name)
255
+ feats_loss_list = []
256
+ for layer, w in enumerate(self.feats_loss_weights):
257
+ if w:
258
+ feats_loss_list.append(feats_loss[layer] * w)
259
+
260
+ # visual fc loss, default: cosine similarity
261
+ if self.fc_loss_weight == 0:
262
+ fc_loss = torch.tensor(0, device=x.device)
263
+ else:
264
+ fc_loss = (1 - torch.cosine_similarity(x_fc_feature, y_fc_feature, dim=1)).mean()
265
+ fc_loss = fc_loss * self.fc_loss_weight
266
+
267
+ return fc_loss, feats_loss_list
268
+
269
+
270
+ #################################################################################
271
+ # layer wise metric #
272
+ #################################################################################
273
+
274
+ def layer_wise_distance(metric_name: str):
275
+ return {
276
+ "l1": l1_layer_wise,
277
+ "l2": l2_layer_wise,
278
+ "cosine": cosine_layer_wise
279
+ }.get(metric_name.lower())
280
+
281
+
282
+ def l2_layer_wise(x_features, y_features, clip_model_name):
283
+ return [
284
+ torch.square(x_conv - y_conv).mean()
285
+ for x_conv, y_conv in zip(x_features, y_features)
286
+ ]
287
+
288
+
289
+ def l1_layer_wise(x_features, y_features, clip_model_name):
290
+ return [
291
+ torch.abs(x_conv - y_conv).mean()
292
+ for x_conv, y_conv in zip(x_features, y_features)
293
+ ]
294
+
295
+
296
+ def cosine_layer_wise(x_features, y_features, clip_model_name):
297
+ if clip_model_name.startswith("RN"):
298
+ return [
299
+ (1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean()
300
+ for x_conv, y_conv in zip(x_features, y_features)
301
+ ]
302
+ return [
303
+ (1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean()
304
+ for x_conv, y_conv in zip(x_features, y_features)
305
+ ]
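A minimal usage sketch for the wrapper above, assuming the openai `clip` package is installed and that the module is importable as `libs.metric.clip_score.openaiCLIP_loss` (the import path mirrors this repo layout and is an assumption, not a documented API):

```python
# Hypothetical usage sketch: score a rendered sketch against a text prompt with CLIPScoreWrapper.
import torch
from libs.metric.clip_score.openaiCLIP_loss import CLIPScoreWrapper  # assumed import path

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_scorer = CLIPScoreWrapper("ViT-B/32", device=device)

image = torch.rand(1, 3, 224, 224, device=device)  # stand-in for a rendered SVG raster in [0, 1]
image = clip_scorer.normalize(image)               # Resize + CenterCrop + CLIP normalization
loss = clip_scorer.compute_text_visual_distance(image, "a sketch of a cat")
print(loss.item())                                 # lower means image and text embeddings are closer
```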
libs/metric/lpips_origin/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .lpips import LPIPS
2
+
3
+ __all__ = ['LPIPS']
libs/metric/lpips_origin/lpips.py ADDED
@@ -0,0 +1,184 @@
1
+ from __future__ import absolute_import
2
+
3
+ import os
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from . import pretrained_networks as pretrained_torch_models
9
+
10
+
11
+ def spatial_average(x, keepdim=True):
12
+ return x.mean([2, 3], keepdim=keepdim)
13
+
14
+
15
+ def upsample(x):
16
+ return nn.Upsample(size=x.shape[2:], mode='bilinear', align_corners=False)(x)
17
+
18
+
19
+ def normalize_tensor(in_feat, eps=1e-10):
20
+ norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True))
21
+ return in_feat / (norm_factor + eps)
22
+
23
+
24
+ # Learned perceptual metric
25
+ class LPIPS(nn.Module):
26
+
27
+ def __init__(self,
28
+ pretrained=True,
29
+ net='alex',
30
+ version='0.1',
31
+ lpips=True,
32
+ spatial=False,
33
+ pnet_rand=False,
34
+ pnet_tune=False,
35
+ use_dropout=True,
36
+ model_path=None,
37
+ eval_mode=True,
38
+ verbose=True):
39
+ """ Initializes a perceptual loss torch.nn.Module
40
+
41
+ Parameters (default listed first)
42
+ ---------------------------------
43
+ lpips : bool
44
+ [True] use linear layers on top of base/trunk network
45
+ [False] means no linear layers; each layer is averaged together
46
+ pretrained : bool
47
+ This flag controls the linear layers, which are only in effect when lpips=True above
48
+ [True] means linear layers are calibrated with human perceptual judgments
49
+ [False] means linear layers are randomly initialized
50
+ pnet_rand : bool
51
+ [False] means trunk loaded with ImageNet classification weights
52
+ [True] means randomly initialized trunk
53
+ net : str
54
+ ['alex','vgg','squeeze'] are the base/trunk networks available
55
+ version : str
56
+ ['v0.1'] is the default and latest
57
+ ['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1)
58
+ model_path : 'str'
59
+ [None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1
60
+
61
+ The following parameters should only be changed if training the network:
62
+
63
+ eval_mode : bool
64
+ [True] is for test mode (default)
65
+ [False] is for training mode
66
+ pnet_tune
67
+ [False] keep base/trunk frozen
68
+ [True] tune the base/trunk network
69
+ use_dropout : bool
70
+ [True] to use dropout when training linear layers
71
+ [False] for no dropout when training linear layers
72
+ """
73
+ super(LPIPS, self).__init__()
74
+ if verbose:
75
+ print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' %
76
+ ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))
77
+
78
+ self.pnet_type = net
79
+ self.pnet_tune = pnet_tune
80
+ self.pnet_rand = pnet_rand
81
+ self.spatial = spatial
82
+ self.lpips = lpips # false means baseline of just averaging all layers
83
+ self.version = version
84
+ self.scaling_layer = ScalingLayer()
85
+
86
+ if self.pnet_type in ['vgg', 'vgg16']:
87
+ net_type = pretrained_torch_models.vgg16
88
+ self.chns = [64, 128, 256, 512, 512]
89
+ elif self.pnet_type == 'alex':
90
+ net_type = pretrained_torch_models.alexnet
91
+ self.chns = [64, 192, 384, 256, 256]
92
+ elif self.pnet_type == 'squeeze':
93
+ net_type = pretrained_torch_models.squeezenet
94
+ self.chns = [64, 128, 256, 384, 384, 512, 512]
95
+ self.L = len(self.chns)
96
+
97
+ self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
98
+
99
+ if lpips:
100
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
101
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
102
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
103
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
104
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
105
+ self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
106
+ if self.pnet_type == 'squeeze': # 7 layers for squeezenet
107
+ self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
108
+ self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
109
+ self.lins += [self.lin5, self.lin6]
110
+ self.lins = nn.ModuleList(self.lins)
111
+
112
+ if pretrained:
113
+ if model_path is None:
114
+ model_path = os.path.join(
115
+ os.path.dirname(os.path.abspath(__file__)),
116
+ f"weights/v{version}/{net}.pth"
117
+ )
118
+ if verbose:
119
+ print('Loading model from: %s' % model_path)
120
+ self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)
121
+
122
+ if eval_mode:
123
+ self.eval()
124
+
125
+ def forward(self, in0, in1, return_per_layer=False, normalize=False):
126
+ if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, 1]
127
+ in0 = 2 * in0 - 1
128
+ in1 = 2 * in1 - 1
129
+
130
+ # Note: v0.0 - the original release had a bug where the input was not scaled
131
+ if self.version == '0.1':
132
+ in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1))
133
+ else:
134
+ in0_input, in1_input = in0, in1
135
+
136
+ # model forward
137
+ outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
138
+
139
+ feats0, feats1, diffs = {}, {}, {}
140
+ for kk in range(self.L):
141
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
142
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
143
+
144
+ if self.lpips:
145
+ if self.spatial:
146
+ res = [upsample(self.lins[kk](diffs[kk])) for kk in range(self.L)]
147
+ else:
148
+ res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
149
+ else:
150
+ if self.spatial:
151
+ res = [upsample(diffs[kk].sum(dim=1, keepdim=True)) for kk in range(self.L)]
152
+ else:
153
+ res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]
154
+
155
+ loss = sum(res)
156
+
157
+ if return_per_layer:
158
+ return loss, res
159
+ else:
160
+ return loss
161
+
162
+
163
+ class ScalingLayer(nn.Module):
164
+ def __init__(self):
165
+ super(ScalingLayer, self).__init__()
166
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
167
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
168
+
169
+ def forward(self, inp):
170
+ return (inp - self.shift) / self.scale
171
+
172
+
173
+ class NetLinLayer(nn.Module):
174
+ """A single linear layer which does a 1x1 conv"""
175
+
176
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
177
+ super(NetLinLayer, self).__init__()
178
+
179
+ layers = [nn.Dropout(), ] if (use_dropout) else []
180
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
181
+ self.model = nn.Sequential(*layers)
182
+
183
+ def forward(self, x):
184
+ return self.model(x)
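A short sketch of how the LPIPS module above is typically called (assuming the package imports as `libs.metric.lpips_origin`; the bundled `weights/v0.1/*.pth` files are loaded automatically):

```python
# Hypothetical usage sketch: perceptual distance between two image batches with LPIPS.
import torch
from libs.metric.lpips_origin import LPIPS  # assumed import path

lpips_fn = LPIPS(net='vgg', verbose=False)  # picks up weights/v0.1/vgg.pth shipped in this folder
x = torch.rand(1, 3, 256, 256)              # e.g. rendered sketch, values in [0, 1]
y = torch.rand(1, 3, 256, 256)              # e.g. reference image, values in [0, 1]
with torch.no_grad():
    d = lpips_fn(x, y, normalize=True)      # normalize=True rescales [0, 1] inputs to [-1, 1]
print(d.shape)                              # torch.Size([1, 1, 1, 1]); lower = more similar
```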
libs/metric/lpips_origin/pretrained_networks.py ADDED
@@ -0,0 +1,196 @@
1
+ from collections import namedtuple
2
+
3
+ import torch
4
+ import torchvision.models as tv_models
5
+
6
+
7
+ class squeezenet(torch.nn.Module):
8
+ def __init__(self, requires_grad=False, pretrained=True):
9
+ super(squeezenet, self).__init__()
10
+ weights = tv_models.SqueezeNet1_1_Weights.IMAGENET1K_V1 if pretrained else None
+ pretrained_features = tv_models.squeezenet1_1(weights=weights).features
11
+ self.slice1 = torch.nn.Sequential()
12
+ self.slice2 = torch.nn.Sequential()
13
+ self.slice3 = torch.nn.Sequential()
14
+ self.slice4 = torch.nn.Sequential()
15
+ self.slice5 = torch.nn.Sequential()
16
+ self.slice6 = torch.nn.Sequential()
17
+ self.slice7 = torch.nn.Sequential()
18
+ self.N_slices = 7
19
+ for x in range(2):
20
+ self.slice1.add_module(str(x), pretrained_features[x])
21
+ for x in range(2, 5):
22
+ self.slice2.add_module(str(x), pretrained_features[x])
23
+ for x in range(5, 8):
24
+ self.slice3.add_module(str(x), pretrained_features[x])
25
+ for x in range(8, 10):
26
+ self.slice4.add_module(str(x), pretrained_features[x])
27
+ for x in range(10, 11):
28
+ self.slice5.add_module(str(x), pretrained_features[x])
29
+ for x in range(11, 12):
30
+ self.slice6.add_module(str(x), pretrained_features[x])
31
+ for x in range(12, 13):
32
+ self.slice7.add_module(str(x), pretrained_features[x])
33
+ if not requires_grad:
34
+ for param in self.parameters():
35
+ param.requires_grad = False
36
+
37
+ def forward(self, X):
38
+ h = self.slice1(X)
39
+ h_relu1 = h
40
+ h = self.slice2(h)
41
+ h_relu2 = h
42
+ h = self.slice3(h)
43
+ h_relu3 = h
44
+ h = self.slice4(h)
45
+ h_relu4 = h
46
+ h = self.slice5(h)
47
+ h_relu5 = h
48
+ h = self.slice6(h)
49
+ h_relu6 = h
50
+ h = self.slice7(h)
51
+ h_relu7 = h
52
+ vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7'])
53
+ out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
54
+
55
+ return out
56
+
57
+
58
+ class alexnet(torch.nn.Module):
59
+ def __init__(self, requires_grad=False, pretrained=True):
60
+ super(alexnet, self).__init__()
61
+ weights = tv_models.AlexNet_Weights.IMAGENET1K_V1 if pretrained else None
62
+ alexnet_pretrained_features = tv_models.alexnet(weights=weights).features
63
+ self.slice1 = torch.nn.Sequential()
64
+ self.slice2 = torch.nn.Sequential()
65
+ self.slice3 = torch.nn.Sequential()
66
+ self.slice4 = torch.nn.Sequential()
67
+ self.slice5 = torch.nn.Sequential()
68
+ self.N_slices = 5
69
+ for x in range(2):
70
+ self.slice1.add_module(str(x), alexnet_pretrained_features[x])
71
+ for x in range(2, 5):
72
+ self.slice2.add_module(str(x), alexnet_pretrained_features[x])
73
+ for x in range(5, 8):
74
+ self.slice3.add_module(str(x), alexnet_pretrained_features[x])
75
+ for x in range(8, 10):
76
+ self.slice4.add_module(str(x), alexnet_pretrained_features[x])
77
+ for x in range(10, 12):
78
+ self.slice5.add_module(str(x), alexnet_pretrained_features[x])
79
+
80
+ if not requires_grad:
81
+ for param in self.parameters():
82
+ param.requires_grad = False
83
+
84
+ def forward(self, X):
85
+ h = self.slice1(X)
86
+ h_relu1 = h
87
+ h = self.slice2(h)
88
+ h_relu2 = h
89
+ h = self.slice3(h)
90
+ h_relu3 = h
91
+ h = self.slice4(h)
92
+ h_relu4 = h
93
+ h = self.slice5(h)
94
+ h_relu5 = h
95
+ alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
96
+ out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
97
+
98
+ return out
99
+
100
+
101
+ class vgg16(torch.nn.Module):
102
+ def __init__(self, requires_grad=False, pretrained=True):
103
+ super(vgg16, self).__init__()
104
+ weights = tv_models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None
105
+ vgg_pretrained_features = tv_models.vgg16(weights=weights).features
106
+ self.slice1 = torch.nn.Sequential()
107
+ self.slice2 = torch.nn.Sequential()
108
+ self.slice3 = torch.nn.Sequential()
109
+ self.slice4 = torch.nn.Sequential()
110
+ self.slice5 = torch.nn.Sequential()
111
+ self.N_slices = 5
112
+ for x in range(4):
113
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
114
+ for x in range(4, 9):
115
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
116
+ for x in range(9, 16):
117
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
118
+ for x in range(16, 23):
119
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
120
+ for x in range(23, 30):
121
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
122
+
123
+ if not requires_grad:
124
+ for param in self.parameters():
125
+ param.requires_grad = False
126
+
127
+ def forward(self, X):
128
+ h = self.slice1(X)
129
+ h_relu1_2 = h
130
+ h = self.slice2(h)
131
+ h_relu2_2 = h
132
+ h = self.slice3(h)
133
+ h_relu3_3 = h
134
+ h = self.slice4(h)
135
+ h_relu4_3 = h
136
+ h = self.slice5(h)
137
+ h_relu5_3 = h
138
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
139
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
140
+
141
+ return out
142
+
143
+
144
+ class resnet(torch.nn.Module):
145
+ def __init__(self, requires_grad=False, pretrained=True, num=18):
146
+ super(resnet, self).__init__()
147
+
148
+ if num == 18:
149
+ weights = tv_models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
150
+ self.net = tv_models.resnet18(weights=weights)
151
+ elif num == 34:
152
+ weights = tv_models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
153
+ self.net = tv_models.resnet34(weights=weights)
154
+ elif num == 50:
155
+ weights = tv_models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
156
+ self.net = tv_models.resnet50(weights=weights)
157
+ elif num == 101:
158
+ weights = tv_models.ResNet101_Weights.IMAGENET1K_V2 if pretrained else None
159
+ self.net = tv_models.resnet101(weights=weights)
160
+ elif num == 152:
161
+ weights = tv_models.ResNet152_Weights.IMAGENET1K_V2 if pretrained else None
162
+ self.net = tv_models.resnet152(weights=weights)
163
+ self.N_slices = 5
164
+
165
+ if not requires_grad:
166
+ for param in self.net.parameters():
167
+ param.requires_grad = False
168
+
169
+ self.conv1 = self.net.conv1
170
+ self.bn1 = self.net.bn1
171
+ self.relu = self.net.relu
172
+ self.maxpool = self.net.maxpool
173
+ self.layer1 = self.net.layer1
174
+ self.layer2 = self.net.layer2
175
+ self.layer3 = self.net.layer3
176
+ self.layer4 = self.net.layer4
177
+
178
+ def forward(self, X):
179
+ h = self.conv1(X)
180
+ h = self.bn1(h)
181
+ h = self.relu(h)
182
+ h_relu1 = h
183
+ h = self.maxpool(h)
184
+ h = self.layer1(h)
185
+ h_conv2 = h
186
+ h = self.layer2(h)
187
+ h_conv3 = h
188
+ h = self.layer3(h)
189
+ h_conv4 = h
190
+ h = self.layer4(h)
191
+ h_conv5 = h
192
+
193
+ outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])
194
+ out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
195
+
196
+ return out
libs/metric/lpips_origin/weights/v0.1/alex.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0
3
+ size 6009
libs/metric/lpips_origin/weights/v0.1/squeeze.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76
3
+ size 10811
libs/metric/lpips_origin/weights/v0.1/vgg.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
3
+ size 7289
libs/metric/piq/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+ # install: pip install piq
7
+ # repo: https://github.com/photosynthesis-team/piq
libs/metric/piq/functional/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ from .base import ifftshift, get_meshgrid, similarity_map, gradient_map, pow_for_complex, crop_patches
2
+ from .colour_conversion import rgb2lmn, rgb2xyz, xyz2lab, rgb2lab, rgb2yiq, rgb2lhm
3
+ from .filters import haar_filter, hann_filter, scharr_filter, prewitt_filter, gaussian_filter
4
+ from .filters import binomial_filter1d, average_filter2d
5
+ from .layers import L2Pool2d
6
+ from .resize import imresize
7
+
8
+ __all__ = [
9
+ 'ifftshift', 'get_meshgrid', 'similarity_map', 'gradient_map', 'pow_for_complex', 'crop_patches',
10
+ 'rgb2lmn', 'rgb2xyz', 'xyz2lab', 'rgb2lab', 'rgb2yiq', 'rgb2lhm',
11
+ 'haar_filter', 'hann_filter', 'scharr_filter', 'prewitt_filter', 'gaussian_filter',
12
+ 'binomial_filter1d', 'average_filter2d',
13
+ 'L2Pool2d',
14
+ 'imresize',
15
+ ]
libs/metric/piq/functional/base.py ADDED
@@ -0,0 +1,111 @@
1
+ r"""General purpose functions"""
2
+ from typing import Tuple, Union, Optional
3
+ import torch
4
+ from ..utils import _parse_version
5
+
6
+
7
+ def ifftshift(x: torch.Tensor) -> torch.Tensor:
8
+ r""" Similar to np.fft.ifftshift but applies to PyTorch Tensors"""
9
+ shift = [-(ax // 2) for ax in x.size()]
10
+ return torch.roll(x, shift, tuple(range(len(shift))))
11
+
12
+
13
+ def get_meshgrid(size: Tuple[int, int], device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
14
+ r"""Return coordinate grid matrices centered at zero point.
15
+ Args:
16
+ size: Shape of meshgrid to create
17
+ device: device to use for creation
18
+ dtype: dtype to use for creation
19
+ Returns:
20
+ Meshgrid of size on device with dtype values.
21
+ """
22
+ if size[0] % 2:
23
+ # Odd
24
+ x = torch.arange(-(size[0] - 1) / 2, size[0] / 2, device=device, dtype=dtype) / (size[0] - 1)
25
+ else:
26
+ # Even
27
+ x = torch.arange(- size[0] / 2, size[0] / 2, device=device, dtype=dtype) / size[0]
28
+
29
+ if size[1] % 2:
30
+ # Odd
31
+ y = torch.arange(-(size[1] - 1) / 2, size[1] / 2, device=device, dtype=dtype) / (size[1] - 1)
32
+ else:
33
+ # Even
34
+ y = torch.arange(- size[1] / 2, size[1] / 2, device=device, dtype=dtype) / size[1]
35
+ # Use indexing param depending on torch version
36
+ recommended_torch_version = _parse_version("1.10.0")
37
+ torch_version = _parse_version(torch.__version__)
38
+ if len(torch_version) > 0 and torch_version >= recommended_torch_version:
39
+ return torch.meshgrid(x, y, indexing='ij')
40
+ return torch.meshgrid(x, y)
41
+
42
+
43
+ def similarity_map(map_x: torch.Tensor, map_y: torch.Tensor, constant: float, alpha: float = 0.0) -> torch.Tensor:
44
+ r""" Compute similarity_map between two tensors using Dice-like equation.
45
+
46
+ Args:
47
+ map_x: Tensor with map to be compared
48
+ map_y: Tensor with map to be compared
49
+ constant: Used for numerical stability
50
+ alpha: Masking coefficient. Subtracts `alpha` * map_x * map_y from both the numerator and the denominator
51
+ """
52
+ return (2.0 * map_x * map_y - alpha * map_x * map_y + constant) / \
53
+ (map_x ** 2 + map_y ** 2 - alpha * map_x * map_y + constant)
54
+
55
+
56
+ def gradient_map(x: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor:
57
+ r""" Compute gradient map for a given tensor and stack of kernels.
58
+
59
+ Args:
60
+ x: Tensor with shape (N, C, H, W).
61
+ kernels: Stack of tensors for gradient computation with shape (k_N, k_H, k_W)
62
+ Returns:
63
+ Gradients of x per-channel with shape (N, C, H, W)
64
+ """
65
+ padding = kernels.size(-1) // 2
66
+ grads = torch.nn.functional.conv2d(x, kernels, padding=padding)
67
+
68
+ return torch.sqrt(torch.sum(grads ** 2, dim=-3, keepdim=True))
69
+
70
+
71
+ def pow_for_complex(base: torch.Tensor, exp: Union[int, float]) -> torch.Tensor:
72
+ r""" Takes the power of each element in a 4D tensor with negative values or 5D tensor with complex values.
73
+ Complex numbers are represented by modulus and argument: r * \exp(i * \phi).
74
+
75
+ It will likely be redundant with the introduction of torch.ComplexTensor.
76
+
77
+ Args:
78
+ base: Tensor with shape (N, C, H, W) or (N, C, H, W, 2).
79
+ exp: Exponent
80
+ Returns:
81
+ Complex tensor with shape (N, C, H, W, 2).
82
+ """
83
+ if base.dim() == 4:
84
+ x_complex_r = base.abs()
85
+ x_complex_phi = torch.atan2(torch.zeros_like(base), base)
86
+ elif base.dim() == 5 and base.size(-1) == 2:
87
+ x_complex_r = base.pow(2).sum(dim=-1).sqrt()
88
+ x_complex_phi = torch.atan2(base[..., 1], base[..., 0])
89
+ else:
90
+ raise ValueError(f'Expected real or complex tensor, got {base.size()}')
91
+
92
+ x_complex_pow_r = x_complex_r ** exp
93
+ x_complex_pow_phi = x_complex_phi * exp
94
+ x_real_pow = x_complex_pow_r * torch.cos(x_complex_pow_phi)
95
+ x_imag_pow = x_complex_pow_r * torch.sin(x_complex_pow_phi)
96
+ return torch.stack((x_real_pow, x_imag_pow), dim=-1)
97
+
98
+
99
+ def crop_patches(x: torch.Tensor, size=64, stride=32) -> torch.Tensor:
100
+ r"""Crop tensor with images into small patches
101
+ Args:
102
+ x: Tensor with shape (N, C, H, W), expected to be images-like entities
103
+ size: Size of a square patch
104
+ stride: Step between patches
105
+ """
106
+ assert (x.shape[2] >= size) and (x.shape[3] >= size), \
107
+ f"Images must be bigger than patch size. Got ({x.shape[2], x.shape[3]}) and ({size}, {size})"
108
+ channels = x.shape[1]
109
+ patches = x.unfold(1, channels, channels).unfold(2, size, stride).unfold(3, size, stride)
110
+ patches = patches.reshape(-1, channels, size, size)
111
+ return patches
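A quick sketch of `crop_patches` from this module (import path assumed to resolve via `libs.metric.piq.functional`):

```python
# Hypothetical usage sketch: split an image batch into overlapping 64x64 patches.
import torch
from libs.metric.piq.functional import crop_patches  # assumed import path

x = torch.rand(2, 3, 128, 128)               # batch of 2 RGB images
patches = crop_patches(x, size=64, stride=32)
print(patches.shape)                         # torch.Size([18, 3, 64, 64]): 3x3 patches per image
```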
libs/metric/piq/functional/colour_conversion.py ADDED
@@ -0,0 +1,136 @@
1
+ r"""Colour space conversion functions"""
2
+ from typing import Union, Dict
3
+ import torch
4
+
5
+
6
+ def rgb2lmn(x: torch.Tensor) -> torch.Tensor:
7
+ r"""Convert a batch of RGB images to a batch of LMN images
8
+
9
+ Args:
10
+ x: Batch of images with shape (N, 3, H, W). RGB colour space.
11
+
12
+ Returns:
13
+ Batch of images with shape (N, 3, H, W). LMN colour space.
14
+ """
15
+ weights_rgb_to_lmn = torch.tensor([[0.06, 0.63, 0.27],
16
+ [0.30, 0.04, -0.35],
17
+ [0.34, -0.6, 0.17]], dtype=x.dtype, device=x.device).t()
18
+ x_lmn = torch.matmul(x.permute(0, 2, 3, 1), weights_rgb_to_lmn).permute(0, 3, 1, 2)
19
+ return x_lmn
20
+
21
+
22
+ def rgb2xyz(x: torch.Tensor) -> torch.Tensor:
23
+ r"""Convert a batch of RGB images to a batch of XYZ images
24
+
25
+ Args:
26
+ x: Batch of images with shape (N, 3, H, W). RGB colour space.
27
+
28
+ Returns:
29
+ Batch of images with shape (N, 3, H, W). XYZ colour space.
30
+ """
31
+ mask_below = (x <= 0.04045).type(x.dtype)
32
+ mask_above = (x > 0.04045).type(x.dtype)
33
+
34
+ tmp = x / 12.92 * mask_below + torch.pow((x + 0.055) / 1.055, 2.4) * mask_above
35
+
36
+ weights_rgb_to_xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
37
+ [0.2126729, 0.7151522, 0.0721750],
38
+ [0.0193339, 0.1191920, 0.9503041]], dtype=x.dtype, device=x.device)
39
+
40
+ x_xyz = torch.matmul(tmp.permute(0, 2, 3, 1), weights_rgb_to_xyz.t()).permute(0, 3, 1, 2)
41
+ return x_xyz
42
+
43
+
44
+ def xyz2lab(x: torch.Tensor, illuminant: str = 'D50', observer: str = '2') -> torch.Tensor:
45
+ r"""Convert a batch of XYZ images to a batch of LAB images
46
+
47
+ Args:
48
+ x: Batch of images with shape (N, 3, H, W). XYZ colour space.
49
+ illuminant: {“A”, “D50”, “D55”, “D65”, “D75”, “E”}, optional. The name of the illuminant.
50
+ observer: {“2”, “10”}, optional. The aperture angle of the observer.
51
+
52
+ Returns:
53
+ Batch of images with shape (N, 3, H, W). LAB colour space.
54
+ """
55
+ epsilon = 0.008856
56
+ kappa = 903.3
57
+ illuminants: Dict[str, Dict] = \
58
+ {"A": {'2': (1.098466069456375, 1, 0.3558228003436005),
59
+ '10': (1.111420406956693, 1, 0.3519978321919493)},
60
+ "D50": {'2': (0.9642119944211994, 1, 0.8251882845188288),
61
+ '10': (0.9672062750333777, 1, 0.8142801513128616)},
62
+ "D55": {'2': (0.956797052643698, 1, 0.9214805860173273),
63
+ '10': (0.9579665682254781, 1, 0.9092525159847462)},
64
+ "D65": {'2': (0.95047, 1., 1.08883), # This was: `lab_ref_white`
65
+ '10': (0.94809667673716, 1, 1.0730513595166162)},
66
+ "D75": {'2': (0.9497220898840717, 1, 1.226393520724154),
67
+ '10': (0.9441713925645873, 1, 1.2064272211720228)},
68
+ "E": {'2': (1.0, 1.0, 1.0),
69
+ '10': (1.0, 1.0, 1.0)}}
70
+
71
+ illuminants_to_use = torch.tensor(illuminants[illuminant][observer],
72
+ dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
73
+
74
+ tmp = x / illuminants_to_use
75
+
76
+ mask_below = (tmp <= epsilon).type(x.dtype)
77
+ mask_above = (tmp > epsilon).type(x.dtype)
78
+ tmp = torch.pow(tmp, 1. / 3.) * mask_above + (kappa * tmp + 16.) / 116. * mask_below
79
+
80
+ weights_xyz_to_lab = torch.tensor([[0, 116., 0],
81
+ [500., -500., 0],
82
+ [0, 200., -200.]], dtype=x.dtype, device=x.device)
83
+ bias_xyz_to_lab = torch.tensor([-16., 0., 0.], dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
84
+
85
+ x_lab = torch.matmul(tmp.permute(0, 2, 3, 1), weights_xyz_to_lab.t()).permute(0, 3, 1, 2) + bias_xyz_to_lab
86
+ return x_lab
87
+
88
+
89
+ def rgb2lab(x: torch.Tensor, data_range: Union[int, float] = 255) -> torch.Tensor:
90
+ r"""Convert a batch of RGB images to a batch of LAB images
91
+
92
+ Args:
93
+ x: Batch of images with shape (N, 3, H, W). RGB colour space.
94
+ data_range: dynamic range of the input image.
95
+
96
+ Returns:
97
+ Batch of images with shape (N, 3, H, W). LAB colour space.
98
+ """
99
+ return xyz2lab(rgb2xyz(x / float(data_range)))
100
+
101
+
102
+ def rgb2yiq(x: torch.Tensor) -> torch.Tensor:
103
+ r"""Convert a batch of RGB images to a batch of YIQ images
104
+
105
+ Args:
106
+ x: Batch of images with shape (N, 3, H, W). RGB colour space.
107
+
108
+ Returns:
109
+ Batch of images with shape (N, 3, H, W). YIQ colour space.
110
+ """
111
+ yiq_weights = torch.tensor([
112
+ [0.299, 0.587, 0.114],
113
+ [0.5959, -0.2746, -0.3213],
114
+ [0.2115, -0.5227, 0.3112]], dtype=x.dtype, device=x.device).t()
115
+ x_yiq = torch.matmul(x.permute(0, 2, 3, 1), yiq_weights).permute(0, 3, 1, 2)
116
+ return x_yiq
117
+
118
+
119
+ def rgb2lhm(x: torch.Tensor) -> torch.Tensor:
120
+ r"""Convert a batch of RGB images to a batch of LHM images
121
+
122
+ Args:
123
+ x: Batch of images with shape (N, 3, H, W). RGB colour space.
124
+
125
+ Returns:
126
+ Batch of images with shape (N, 3, H, W). LHM colour space.
127
+
128
+ Reference:
129
+ https://arxiv.org/pdf/1608.07433.pdf
130
+ """
131
+ lhm_weights = torch.tensor([
132
+ [0.2989, 0.587, 0.114],
133
+ [0.3, 0.04, -0.35],
134
+ [0.34, -0.6, 0.17]], dtype=x.dtype, device=x.device).t()
135
+ x_lhm = torch.matmul(x.permute(0, 2, 3, 1), lhm_weights).permute(0, 3, 1, 2)
136
+ return x_lhm
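A brief sketch of the colour-space helpers above (import path assumed):

```python
# Hypothetical usage sketch: convert a [0, 1] RGB batch to LAB and YIQ.
import torch
from libs.metric.piq.functional import rgb2lab, rgb2yiq  # assumed import path

rgb = torch.rand(4, 3, 64, 64)        # RGB values in [0, 1]
lab = rgb2lab(rgb, data_range=1.)     # pass data_range=255 for uint8-style inputs
yiq = rgb2yiq(rgb)
print(lab.shape, yiq.shape)           # torch.Size([4, 3, 64, 64]) for both
```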
libs/metric/piq/functional/filters.py ADDED
@@ -0,0 +1,111 @@
1
+ r"""Filters for gradient computation, bluring, etc."""
2
+ import torch
3
+ import numpy as np
4
+ from typing import Optional
5
+
6
+
7
+ def haar_filter(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
8
+ r"""Creates Haar kernel
9
+
10
+ Args:
11
+ kernel_size: size of the kernel
12
+ device: target device for kernel generation
13
+ dtype: target data type for kernel generation
14
+ Returns:
15
+ kernel: Tensor with shape (1, kernel_size, kernel_size)
16
+ """
17
+ kernel = torch.ones((kernel_size, kernel_size), device=device, dtype=dtype) / kernel_size
18
+ kernel[kernel_size // 2:, :] = - kernel[kernel_size // 2:, :]
19
+ return kernel.unsqueeze(0)
20
+
21
+
22
+ def hann_filter(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
23
+ r"""Creates Hann kernel
24
+ Args:
25
+ kernel_size: size of the kernel
26
+ device: target device for kernel generation
27
+ dtype: target data type for kernel generation
28
+ Returns:
29
+ kernel: Tensor with shape (1, kernel_size, kernel_size)
30
+ """
31
+ # Take bigger window and drop borders
32
+ window = torch.hann_window(kernel_size + 2, periodic=False, device=device, dtype=dtype)[1:-1]
33
+ kernel = window[:, None] * window[None, :]
34
+ # Normalize and reshape kernel
35
+ return kernel.view(1, kernel_size, kernel_size) / kernel.sum()
36
+
37
+
38
+ def gaussian_filter(kernel_size: int, sigma: float, device: Optional[str] = None,
39
+ dtype: Optional[type] = None) -> torch.Tensor:
40
+ r"""Returns 2D Gaussian kernel N(0,`sigma`^2)
41
+ Args:
42
+ kernel_size: Size of the kernel
43
+ sigma: Std of the distribution
44
+ device: target device for kernel generation
45
+ dtype: target data type for kernel generation
46
+ Returns:
47
+ gaussian_kernel: Tensor with shape (1, kernel_size, kernel_size)
48
+ """
49
+ coords = torch.arange(kernel_size, dtype=dtype, device=device)
50
+ coords -= (kernel_size - 1) / 2.
51
+
52
+ g = coords ** 2
53
+ g = (- (g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma ** 2)).exp()
54
+
55
+ g /= g.sum()
56
+ return g.unsqueeze(0)
57
+
58
+
59
+ # Gradient operator kernels
60
+ def scharr_filter(device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
61
+ r"""Utility function that returns a normalized 3x3 Scharr kernel in X direction
62
+
63
+ Args:
64
+ device: target device for kernel generation
65
+ dtype: target data type for kernel generation
66
+ Returns:
67
+ kernel: Tensor with shape (1, 3, 3)
68
+ """
69
+ return torch.tensor([[[-3., 0., 3.], [-10., 0., 10.], [-3., 0., 3.]]], device=device, dtype=dtype) / 16
70
+
71
+
72
+ def prewitt_filter(device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
73
+ r"""Utility function that returns a normalized 3x3 Prewitt kernel in X direction
74
+
75
+ Args:
76
+ device: target device for kernel generation
77
+ dtype: target data type for kernel generation
78
+ Returns:
79
+ kernel: Tensor with shape (1, 3, 3)"""
80
+ return torch.tensor([[[-1., 0., 1.], [-1., 0., 1.], [-1., 0., 1.]]], device=device, dtype=dtype) / 3
81
+
82
+
83
+ def binomial_filter1d(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
84
+ r"""Creates 1D normalized binomial filter
85
+
86
+ Args:
87
+ kernel_size (int): kernel size
88
+ device: target device for kernel generation
89
+ dtype: target data type for kernel generation
90
+
91
+ Returns:
92
+ Binomial kernel with shape (1, 1, kernel_size)
93
+ """
94
+ kernel = np.poly1d([0.5, 0.5]) ** (kernel_size - 1)
95
+ return torch.tensor(kernel.c, dtype=dtype, device=device).view(1, 1, kernel_size)
96
+
97
+
98
+ def average_filter2d(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor:
99
+ r"""Creates 2D normalized average filter
100
+
101
+ Args:
102
+ kernel_size (int): kernel size
103
+ device: target device for kernel generation
104
+ dtype: target data type for kernel generation
105
+
106
+ Returns:
107
+ kernel: Tensor with shape (1, kernel_size, kernel_size)
108
+ """
109
+ window = torch.ones(kernel_size, dtype=dtype, device=device) / kernel_size
110
+ kernel = window[:, None] * window[None, :]
111
+ return kernel.unsqueeze(0)
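A small sketch combining the filter builders above with `gradient_map` from `base.py` (import path assumed; a float dtype is passed explicitly because `torch.arange` without a dtype yields an integer tensor):

```python
# Hypothetical usage sketch: Gaussian blur followed by a Scharr gradient-magnitude map.
import torch
import torch.nn.functional as F
from libs.metric.piq.functional import gaussian_filter, scharr_filter, gradient_map  # assumed import path

x = torch.rand(1, 1, 64, 64)

g = gaussian_filter(kernel_size=5, sigma=1.5, dtype=torch.float32)  # shape (1, 5, 5)
blurred = F.conv2d(x, g.unsqueeze(0), padding=2)                    # weight shape (1, 1, 5, 5)

# Stack x- and y-direction Scharr kernels; gradient_map returns the per-pixel magnitude.
kernels = torch.stack([scharr_filter(), scharr_filter().transpose(-1, -2)])
grads = gradient_map(blurred, kernels)
print(blurred.shape, grads.shape)  # torch.Size([1, 1, 64, 64]) for both
```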
libs/metric/piq/functional/layers.py ADDED
@@ -0,0 +1,33 @@
1
+ r"""Custom layers used in metrics computations"""
2
+ import torch
3
+ from typing import Optional
4
+
5
+ from .filters import hann_filter
6
+
7
+
8
+ class L2Pool2d(torch.nn.Module):
9
+ r"""Applies L2 pooling with Hann window of size 3x3
10
+ Args:
11
+ x: Tensor with shape (N, C, H, W)"""
12
+ EPS = 1e-12
13
+
14
+ def __init__(self, kernel_size: int = 3, stride: int = 2, padding=1) -> None:
15
+ super().__init__()
16
+ self.kernel_size = kernel_size
17
+ self.stride = stride
18
+ self.padding = padding
19
+
20
+ self.kernel: Optional[torch.Tensor] = None
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ if self.kernel is None:
24
+ C = x.size(1)
25
+ self.kernel = hann_filter(self.kernel_size).repeat((C, 1, 1, 1)).to(x)
26
+
27
+ out = torch.nn.functional.conv2d(
28
+ x ** 2, self.kernel,
29
+ stride=self.stride,
30
+ padding=self.padding,
31
+ groups=x.shape[1]
32
+ )
33
+ return (out + self.EPS).sqrt()
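A minimal sketch of `L2Pool2d` (import path assumed):

```python
# Hypothetical usage sketch: L2 pooling with the default 3x3 Hann window.
import torch
from libs.metric.piq.functional import L2Pool2d  # assumed import path

pool = L2Pool2d(kernel_size=3, stride=2, padding=1)
x = torch.rand(1, 8, 32, 32)
y = pool(x)                     # the Hann kernel is built lazily on the first forward pass
print(y.shape)                  # torch.Size([1, 8, 16, 16])
```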
libs/metric/piq/functional/resize.py ADDED
@@ -0,0 +1,426 @@
1
+ """
2
+ A standalone PyTorch implementation for fast and efficient bicubic resampling.
3
+ The resulting values are the same as those of the MATLAB function imresize('bicubic').
4
+ ## Author: Sanghyun Son
5
+ ## Email: [email protected] (primary), [email protected] (secondary)
6
+ ## Version: 1.2.0
7
+ ## Last update: July 9th, 2020 (KST)
8
+ Dependency: torch
9
+ Example::
10
+ >>> import torch
11
+ >>> import core
12
+ >>> x = torch.arange(16).float().view(1, 1, 4, 4)
13
+ >>> y = core.imresize(x, sizes=(3, 3))
14
+ >>> print(y)
15
+ tensor([[[[ 0.7506, 2.1004, 3.4503],
16
+ [ 6.1505, 7.5000, 8.8499],
17
+ [11.5497, 12.8996, 14.2494]]]])
18
+ """
19
+
20
+ import math
21
+ import typing
22
+
23
+ import torch
24
+ from torch.nn import functional as F
25
+
26
+ __all__ = ['imresize']
27
+
28
+ _I = typing.Optional[int]
29
+ _D = typing.Optional[torch.dtype]
30
+
31
+
32
+ def nearest_contribution(x: torch.Tensor) -> torch.Tensor:
33
+ range_around_0 = torch.logical_and(x.gt(-0.5), x.le(0.5))
34
+ cont = range_around_0.to(dtype=x.dtype)
35
+ return cont
36
+
37
+
38
+ def linear_contribution(x: torch.Tensor) -> torch.Tensor:
39
+ ax = x.abs()
40
+ range_01 = ax.le(1)
41
+ cont = (1 - ax) * range_01.to(dtype=x.dtype)
42
+ return cont
43
+
44
+
45
+ def cubic_contribution(x: torch.Tensor, a: float = -0.5) -> torch.Tensor:
46
+ ax = x.abs()
47
+ ax2 = ax * ax
48
+ ax3 = ax * ax2
49
+
50
+ range_01 = ax.le(1)
51
+ range_12 = torch.logical_and(ax.gt(1), ax.le(2))
52
+
53
+ cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1
54
+ cont_01 = cont_01 * range_01.to(dtype=x.dtype)
55
+
56
+ cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a)
57
+ cont_12 = cont_12 * range_12.to(dtype=x.dtype)
58
+
59
+ cont = cont_01 + cont_12
60
+ return cont
61
+
62
+
63
+ def gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor:
64
+ range_3sigma = (x.abs() <= 3 * sigma + 1)
65
+ # Normalization will be done after
66
+ cont = torch.exp(-x.pow(2) / (2 * sigma ** 2))
67
+ cont = cont * range_3sigma.to(dtype=x.dtype)
68
+ return cont
69
+
70
+
71
+ def discrete_kernel(
72
+ kernel: str, scale: float, antialiasing: bool = True) -> torch.Tensor:
73
+ '''
74
+ For downsampling with integer scale only.
75
+ '''
76
+ downsampling_factor = int(1 / scale)
77
+ if kernel == 'cubic':
78
+ kernel_size_orig = 4
79
+ else:
80
+ raise ValueError('{} kernel is not supported!'.format(kernel))
81
+
82
+ if antialiasing:
83
+ kernel_size = kernel_size_orig * downsampling_factor
84
+ else:
85
+ kernel_size = kernel_size_orig
86
+
87
+ if downsampling_factor % 2 == 0:
88
+ a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size))
89
+ else:
90
+ kernel_size -= 1
91
+ a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1))
92
+
93
+ with torch.no_grad():
94
+ r = torch.linspace(-a, a, steps=kernel_size)
95
+ k = cubic_contribution(r).view(-1, 1)
96
+ k = torch.matmul(k, k.t())
97
+ k /= k.sum()
98
+
99
+ return k
100
+
101
+
102
+ def reflect_padding(
103
+ x: torch.Tensor,
104
+ dim: int,
105
+ pad_pre: int,
106
+ pad_post: int) -> torch.Tensor:
107
+ '''
108
+ Apply reflect padding to the given Tensor.
109
+ Note that it is slightly different from the PyTorch functional.pad,
110
+ where boundary elements are used only once.
111
+ Instead, we follow the MATLAB implementation
112
+ which uses boundary elements twice.
113
+ For example,
114
+ [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation,
115
+ while our implementation yields [a, a, b, c, d, d].
116
+ '''
117
+ b, c, h, w = x.size()
118
+ if dim == 2 or dim == -2:
119
+ padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w)
120
+ padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x)
121
+ for p in range(pad_pre):
122
+ padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :])
123
+ for p in range(pad_post):
124
+ padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :])
125
+ else:
126
+ padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post)
127
+ padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x)
128
+ for p in range(pad_pre):
129
+ padding_buffer[..., pad_pre - p - 1].copy_(x[..., p])
130
+ for p in range(pad_post):
131
+ padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)])
132
+
133
+ return padding_buffer
134
+
135
+
136
+ def padding(
137
+ x: torch.Tensor,
138
+ dim: int,
139
+ pad_pre: int,
140
+ pad_post: int,
141
+ padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor:
142
+ if padding_type is None:
143
+ return x
144
+ elif padding_type == 'reflect':
145
+ x_pad = reflect_padding(x, dim, pad_pre, pad_post)
146
+ else:
147
+ raise ValueError('{} padding is not supported!'.format(padding_type))
148
+
149
+ return x_pad
150
+
151
+
152
+ def get_padding(
153
+ base: torch.Tensor,
154
+ kernel_size: int,
155
+ x_size: int) -> typing.Tuple[int, int, torch.Tensor]:
156
+ base = base.long()
157
+ r_min = base.min()
158
+ r_max = base.max() + kernel_size - 1
159
+
160
+ if r_min <= 0:
161
+ pad_pre = -r_min
162
+ pad_pre = pad_pre.item()
163
+ base += pad_pre
164
+ else:
165
+ pad_pre = 0
166
+
167
+ if r_max >= x_size:
168
+ pad_post = r_max - x_size + 1
169
+ pad_post = pad_post.item()
170
+ else:
171
+ pad_post = 0
172
+
173
+ return pad_pre, pad_post, base
174
+
175
+
176
+ def get_weight(
177
+ dist: torch.Tensor,
178
+ kernel_size: int,
179
+ kernel: str = 'cubic',
180
+ sigma: float = 2.0,
181
+ antialiasing_factor: float = 1) -> torch.Tensor:
182
+ buffer_pos = dist.new_zeros(kernel_size, len(dist))
183
+ for idx, buffer_sub in enumerate(buffer_pos):
184
+ buffer_sub.copy_(dist - idx)
185
+
186
+ # Expand (downsampling) / Shrink (upsampling) the receptive field.
187
+ buffer_pos *= antialiasing_factor
188
+ if kernel == 'cubic':
189
+ weight = cubic_contribution(buffer_pos)
190
+ elif kernel == 'gaussian':
191
+ weight = gaussian_contribution(buffer_pos, sigma=sigma)
192
+ else:
193
+ raise ValueError('{} kernel is not supported!'.format(kernel))
194
+
195
+ weight /= weight.sum(dim=0, keepdim=True)
196
+ return weight
197
+
198
+
199
+ def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor:
200
+ # Resize height
201
+ if dim == 2 or dim == -2:
202
+ k = (kernel_size, 1)
203
+ h_out = x.size(-2) - kernel_size + 1
204
+ w_out = x.size(-1)
205
+ # Resize width
206
+ else:
207
+ k = (1, kernel_size)
208
+ h_out = x.size(-2)
209
+ w_out = x.size(-1) - kernel_size + 1
210
+
211
+ unfold = F.unfold(x, k)
212
+ unfold = unfold.view(unfold.size(0), -1, h_out, w_out)
213
+ return unfold
214
+
215
+
216
+ def reshape_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, int, int]:
217
+ if x.dim() == 4:
218
+ b, c, h, w = x.size()
219
+ elif x.dim() == 3:
220
+ c, h, w = x.size()
221
+ b = None
222
+ elif x.dim() == 2:
223
+ h, w = x.size()
224
+ b = c = None
225
+ else:
226
+ raise ValueError('{}-dim Tensor is not supported!'.format(x.dim()))
227
+
228
+ x = x.view(-1, 1, h, w)
229
+ return x, b, c, h, w
230
+
231
+
232
+ def reshape_output(x: torch.Tensor, b: _I, c: _I) -> torch.Tensor:
233
+ rh = x.size(-2)
234
+ rw = x.size(-1)
235
+ # Back to the original dimension
236
+ if b is not None:
237
+ x = x.view(b, c, rh, rw) # 4-dim
238
+ else:
239
+ if c is not None:
240
+ x = x.view(c, rh, rw) # 3-dim
241
+ else:
242
+ x = x.view(rh, rw) # 2-dim
243
+
244
+ return x
245
+
246
+
247
+ def cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]:
248
+ if x.dtype != torch.float32 and x.dtype != torch.float64:
249
+ dtype = x.dtype
250
+ x = x.float()
251
+ else:
252
+ dtype = None
253
+
254
+ return x, dtype
255
+
256
+
257
+ def cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor:
258
+ if dtype is not None:
259
+ if not dtype.is_floating_point:
260
+ x = x.round()
261
+ # To prevent over/underflow when converting types
262
+ if dtype is torch.uint8:
263
+ x = x.clamp(0, 255)
264
+
265
+ x = x.to(dtype=dtype)
266
+
267
+ return x
268
+
269
+
270
+ def resize_1d(
271
+ x: torch.Tensor,
272
+ dim: int,
273
+ size: int,
274
+ scale: float,
275
+ kernel: str = 'cubic',
276
+ sigma: float = 2.0,
277
+ padding_type: str = 'reflect',
278
+ antialiasing: bool = True) -> torch.Tensor:
279
+ '''
280
+ Args:
281
+ x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W).
282
+ dim (int):
283
+ scale (float):
284
+ size (int):
285
+ Return:
286
+ '''
287
+ # Identity case
288
+ if scale == 1:
289
+ return x
290
+
291
+ # Default bicubic kernel with antialiasing (only when downsampling)
292
+ if kernel == 'cubic':
293
+ kernel_size = 4
294
+ else:
295
+ kernel_size = math.floor(6 * sigma)
296
+
297
+ if antialiasing and (scale < 1):
298
+ antialiasing_factor = scale
299
+ kernel_size = math.ceil(kernel_size / antialiasing_factor)
300
+ else:
301
+ antialiasing_factor = 1
302
+
303
+ # We allow margin to both sizes
304
+ kernel_size += 2
305
+
306
+ # Weights only depend on the shape of input and output,
307
+ # so we do not calculate gradients here.
308
+ with torch.no_grad():
309
+ pos = torch.linspace(
310
+ 0, size - 1, steps=size, dtype=x.dtype, device=x.device,
311
+ )
312
+ pos = (pos + 0.5) / scale - 0.5
313
+ base = pos.floor() - (kernel_size // 2) + 1
314
+ dist = pos - base
315
+ weight = get_weight(
316
+ dist,
317
+ kernel_size,
318
+ kernel=kernel,
319
+ sigma=sigma,
320
+ antialiasing_factor=antialiasing_factor,
321
+ )
322
+ pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim))
323
+
324
+ # To backpropagate through x
325
+ x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type)
326
+ unfold = reshape_tensor(x_pad, dim, kernel_size)
327
+ # Subsampling first
328
+ if dim == 2 or dim == -2:
329
+ sample = unfold[..., base, :]
330
+ weight = weight.view(1, kernel_size, sample.size(2), 1)
331
+ else:
332
+ sample = unfold[..., base]
333
+ weight = weight.view(1, kernel_size, 1, sample.size(3))
334
+
335
+ # Apply the kernel
336
+ x = sample * weight
337
+ x = x.sum(dim=1, keepdim=True)
338
+ return x
339
+
340
+
341
+ def downsampling_2d(
342
+ x: torch.Tensor,
343
+ k: torch.Tensor,
344
+ scale: int,
345
+ padding_type: str = 'reflect') -> torch.Tensor:
346
+ c = x.size(1)
347
+ k_h = k.size(-2)
348
+ k_w = k.size(-1)
349
+
350
+ k = k.to(dtype=x.dtype, device=x.device)
351
+ k = k.view(1, 1, k_h, k_w)
352
+ k = k.repeat(c, c, 1, 1)
353
+ e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False)
354
+ e = e.view(c, c, 1, 1)
355
+ k = k * e
356
+
357
+ pad_h = (k_h - scale) // 2
358
+ pad_w = (k_w - scale) // 2
359
+ x = padding(x, -2, pad_h, pad_h, padding_type=padding_type)
360
+ x = padding(x, -1, pad_w, pad_w, padding_type=padding_type)
361
+ y = F.conv2d(x, k, padding=0, stride=scale)
362
+ return y
363
+
364
+
365
+ def imresize(
366
+ x: torch.Tensor,
367
+ scale: typing.Optional[float] = None,
368
+ sizes: typing.Optional[typing.Tuple[int, int]] = None,
369
+ kernel: typing.Union[str, torch.Tensor] = 'cubic',
370
+ sigma: float = 2,
371
+ rotation_degree: float = 0,
372
+ padding_type: str = 'reflect',
373
+ antialiasing: bool = True) -> torch.Tensor:
374
+ """
375
+ Args:
376
+ x (torch.Tensor):
377
+ scale (float):
378
+ sizes (tuple(int, int)):
379
+ kernel (str, default='cubic'):
380
+ sigma (float, default=2):
381
+ rotation_degree (float, default=0):
382
+ padding_type (str, default='reflect'):
383
+ antialiasing (bool, default=True):
384
+ Return:
385
+ torch.Tensor:
386
+ """
387
+ if scale is None and sizes is None:
388
+ raise ValueError('One of scale or sizes must be specified!')
389
+ if scale is not None and sizes is not None:
390
+ raise ValueError('Please specify scale or sizes to avoid conflict!')
391
+
392
+ x, b, c, h, w = reshape_input(x)
393
+
394
+ if sizes is None and scale is not None:
395
+ '''
396
+ # Check if we can apply the convolution algorithm
397
+ scale_inv = 1 / scale
398
+ if isinstance(kernel, str) and scale_inv.is_integer():
399
+ kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing)
400
+ elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer():
401
+ raise ValueError(
402
+ 'An integer downsampling factor '
403
+ 'should be used with a predefined kernel!'
404
+ )
405
+ '''
406
+ # Determine output size
407
+ sizes = (math.ceil(h * scale), math.ceil(w * scale))
408
+ scales = (scale, scale)
409
+
410
+ if scale is None and sizes is not None:
411
+ scales = (sizes[0] / h, sizes[1] / w)
412
+
413
+ x, dtype = cast_input(x)
414
+
415
+ if isinstance(kernel, str) and sizes is not None:
416
+ # Core resizing module
417
+ x = resize_1d(x, -2, size=sizes[0], scale=scales[0], kernel=kernel, sigma=sigma, padding_type=padding_type,
418
+ antialiasing=antialiasing)
419
+ x = resize_1d(x, -1, size=sizes[1], scale=scales[1], kernel=kernel, sigma=sigma, padding_type=padding_type,
420
+ antialiasing=antialiasing)
421
+ elif isinstance(kernel, torch.Tensor) and scale is not None:
422
+ x = downsampling_2d(x, kernel, scale=int(1 / scale))
423
+
424
+ x = reshape_output(x, b, c)
425
+ x = cast_output(x, dtype)
426
+ return x
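A short sketch of `imresize`, complementing the docstring example above (import path assumed):

```python
# Hypothetical usage sketch: MATLAB-style bicubic resizing of an image batch.
import torch
from libs.metric.piq.functional import imresize  # assumed import path

x = torch.rand(1, 3, 256, 256)
y_half = imresize(x, scale=0.5)          # antialiased bicubic downsampling -> (1, 3, 128, 128)
y_fix = imresize(x, sizes=(224, 224))    # resize to an explicit (H, W); scale and sizes are exclusive
print(y_half.shape, y_fix.shape)
```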
libs/metric/piq/perceptual.py ADDED
@@ -0,0 +1,496 @@
1
+ """
2
+ Implementation of Content loss, Style loss, LPIPS and DISTS metrics
3
+ References:
4
+ .. [1] Gatys, Leon and Ecker, Alexander and Bethge, Matthias
5
+ (2016). A Neural Algorithm of Artistic Style
6
+ Association for Research in Vision and Ophthalmology (ARVO)
7
+ https://arxiv.org/abs/1508.06576
8
+ .. [2] Zhang, Richard and Isola, Phillip and Efros, et al.
9
+ (2018) The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
10
+ 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition
11
+ https://arxiv.org/abs/1801.03924
12
+ """
13
+ from typing import List, Union, Collection
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.nn.modules.loss import _Loss
18
+ from torchvision.models import vgg16, vgg19, VGG16_Weights, VGG19_Weights
19
+
20
+ from .utils import _validate_input, _reduce
21
+ from .functional import similarity_map, L2Pool2d
22
+
23
+ # Map VGG names to corresponding number in torchvision layer
24
+ VGG16_LAYERS = {
25
+ "conv1_1": '0', "relu1_1": '1',
26
+ "conv1_2": '2', "relu1_2": '3',
27
+ "pool1": '4',
28
+ "conv2_1": '5', "relu2_1": '6',
29
+ "conv2_2": '7', "relu2_2": '8',
30
+ "pool2": '9',
31
+ "conv3_1": '10', "relu3_1": '11',
32
+ "conv3_2": '12', "relu3_2": '13',
33
+ "conv3_3": '14', "relu3_3": '15',
34
+ "pool3": '16',
35
+ "conv4_1": '17', "relu4_1": '18',
36
+ "conv4_2": '19', "relu4_2": '20',
37
+ "conv4_3": '21', "relu4_3": '22',
38
+ "pool4": '23',
39
+ "conv5_1": '24', "relu5_1": '25',
40
+ "conv5_2": '26', "relu5_2": '27',
41
+ "conv5_3": '28', "relu5_3": '29',
42
+ "pool5": '30',
43
+ }
44
+
45
+ VGG19_LAYERS = {
46
+ "conv1_1": '0', "relu1_1": '1',
47
+ "conv1_2": '2', "relu1_2": '3',
48
+ "pool1": '4',
49
+ "conv2_1": '5', "relu2_1": '6',
50
+ "conv2_2": '7', "relu2_2": '8',
51
+ "pool2": '9',
52
+ "conv3_1": '10', "relu3_1": '11',
53
+ "conv3_2": '12', "relu3_2": '13',
54
+ "conv3_3": '14', "relu3_3": '15',
55
+ "conv3_4": '16', "relu3_4": '17',
56
+ "pool3": '18',
57
+ "conv4_1": '19', "relu4_1": '20',
58
+ "conv4_2": '21', "relu4_2": '22',
59
+ "conv4_3": '23', "relu4_3": '24',
60
+ "conv4_4": '25', "relu4_4": '26',
61
+ "pool4": '27',
62
+ "conv5_1": '28', "relu5_1": '29',
63
+ "conv5_2": '30', "relu5_2": '31',
64
+ "conv5_3": '32', "relu5_3": '33',
65
+ "conv5_4": '34', "relu5_4": '35',
66
+ "pool5": '36',
67
+ }
68
+
69
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
70
+ IMAGENET_STD = [0.229, 0.224, 0.225]
71
+
72
+ # Constant used in feature normalization to avoid zero division
73
+ EPS = 1e-10
74
+
75
+
76
+ class ContentLoss(_Loss):
77
+ r"""Creates Content loss that can be used for image style transfer or as a measure for image to image tasks.
78
+ Uses pretrained VGG models from torchvision.
79
+ Expects input to be in range [0, 1] or normalized with ImageNet statistics into range [-1, 1]
80
+
81
+ Args:
82
+ feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``.
83
+ layers: List of strings with layer names. Default: ``'relu3_3'``
84
+ weights: List of float weights to balance different layers
85
+ replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details.
86
+ distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
87
+ reduction: Specifies the reduction type:
88
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
89
+ mean: List of float values used for data standardization. Default: ImageNet mean.
90
+ If there is no need to normalize data, use [0., 0., 0.].
91
+ std: List of float values used for data standardization. Default: ImageNet std.
92
+ If there is no need to normalize data, use [1., 1., 1.].
93
+ normalize_features: If true, unit-normalize each feature in channel dimension before scaling
94
+ and computing distance. See references for details.
95
+
96
+ Examples:
97
+ >>> loss = ContentLoss()
98
+ >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
99
+ >>> y = torch.rand(3, 3, 256, 256)
100
+ >>> output = loss(x, y)
101
+ >>> output.backward()
102
+
103
+ References:
104
+ Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
105
+ A Neural Algorithm of Artistic Style
106
+ Association for Research in Vision and Ophthalmology (ARVO)
107
+ https://arxiv.org/abs/1508.06576
108
+
109
+ Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
110
+ The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
111
+ IEEE/CVF Conference on Computer Vision and Pattern Recognition
112
+ https://arxiv.org/abs/1801.03924
113
+ """
114
+
115
+ def __init__(self, feature_extractor: Union[str, torch.nn.Module] = "vgg16", layers: Collection[str] = ("relu3_3",),
116
+ weights: List[Union[float, torch.Tensor]] = [1.], replace_pooling: bool = False,
117
+ distance: str = "mse", reduction: str = "mean", mean: List[float] = IMAGENET_MEAN,
118
+ std: List[float] = IMAGENET_STD, normalize_features: bool = False,
119
+ allow_layers_weights_mismatch: bool = False) -> None:
120
+
121
+ assert allow_layers_weights_mismatch or len(layers) == len(weights), \
122
+ f'Lengths of provided layers and weights mismatch ({len(weights)} weights and {len(layers)} layers), ' \
123
+ f'which will cause incorrect results. Please provide weight for each layer.'
124
+
125
+ super().__init__()
126
+
127
+ if callable(feature_extractor):
128
+ self.model = feature_extractor
129
+ self.layers = layers
130
+ else:
131
+ if feature_extractor == "vgg16":
132
+ # self.model = vgg16(pretrained=True, progress=False).features
133
+ self.model = vgg16(weights=VGG16_Weights.DEFAULT, progress=False).features
134
+ self.layers = [VGG16_LAYERS[l] for l in layers]
135
+ elif feature_extractor == "vgg19":
136
+ # self.model = vgg19(pretrained=True, progress=False).features
137
+ self.model = vgg19(weights=VGG19_Weights.DEFAULT, progress=False).features
138
+ self.layers = [VGG19_LAYERS[l] for l in layers]
139
+ else:
140
+ raise ValueError("Unknown feature extractor")
141
+
142
+ if replace_pooling:
143
+ self.model = self.replace_pooling(self.model)
144
+
145
+ # Disable gradients
146
+ for param in self.model.parameters():
147
+ param.requires_grad_(False)
148
+
149
+ self.distance = {
150
+ "mse": nn.MSELoss,
151
+ "mae": nn.L1Loss,
152
+ }[distance](reduction='none')
153
+
154
+ self.weights = [torch.tensor(w) if not isinstance(w, torch.Tensor) else w for w in weights]
155
+
156
+ mean = torch.tensor(mean)
157
+ std = torch.tensor(std)
158
+ self.mean = mean.view(1, -1, 1, 1)
159
+ self.std = std.view(1, -1, 1, 1)
160
+
161
+ self.normalize_features = normalize_features
162
+ self.reduction = reduction
163
+
164
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
165
+ r"""Computation of Content loss between feature representations of prediction :math:`x` and
166
+ target :math:`y` tensors.
167
+
168
+ Args:
169
+ x: An input tensor. Shape :math:`(N, C, H, W)`.
170
+ y: A target tensor. Shape :math:`(N, C, H, W)`.
171
+
172
+ Returns:
173
+ Content loss between feature representations
174
+ """
175
+ _validate_input([x, y], dim_range=(4, 4), data_range=(0, -1))
176
+
177
+ self.model.to(x)
178
+ x_features = self.get_features(x)
179
+ y_features = self.get_features(y)
180
+
181
+ distances = self.compute_distance(x_features, y_features)
182
+
183
+ # Scale distances, then average in spatial dimensions, then stack and sum in channels dimension
184
+ loss = torch.cat([(d * w.to(d)).mean(dim=[2, 3]) for d, w in zip(distances, self.weights)], dim=1).sum(dim=1)
185
+
186
+ return _reduce(loss, self.reduction)
187
+
188
+ def compute_distance(self, x_features: List[torch.Tensor], y_features: List[torch.Tensor]) -> List[torch.Tensor]:
189
+ r"""Take L2 or L1 distance between feature maps depending on ``distance``.
190
+
191
+ Args:
192
+ x_features: Features of the input tensor.
193
+ y_features: Features of the target tensor.
194
+
195
+ Returns:
196
+ Distance between feature maps
197
+ """
198
+ return [self.distance(x, y) for x, y in zip(x_features, y_features)]
199
+
200
+ def get_features(self, x: torch.Tensor) -> List[torch.Tensor]:
201
+ r"""
202
+ Args:
203
+ x: Tensor. Shape :math:`(N, C, H, W)`.
204
+
205
+ Returns:
206
+ List of features extracted from intermediate layers
207
+ """
208
+ # Normalize input
209
+ x = (x - self.mean.to(x)) / self.std.to(x)
210
+
211
+ features = []
212
+ for name, module in self.model._modules.items():
213
+ x = module(x)
214
+ if name in self.layers:
215
+ features.append(self.normalize(x) if self.normalize_features else x)
216
+
217
+ return features
218
+
219
+ @staticmethod
220
+ def normalize(x: torch.Tensor) -> torch.Tensor:
221
+ r"""Normalize feature maps in channel direction to unit length.
222
+
223
+ Args:
224
+ x: Tensor. Shape :math:`(N, C, H, W)`.
225
+
226
+ Returns:
227
+ Normalized input
228
+ """
229
+ norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
230
+ return x / (norm_factor + EPS)
231
+
232
+ def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module:
233
+ r"""Turn All MaxPool layers into AveragePool
234
+
235
+ Args:
236
+ module: Module to change MaxPool into AveragePool
237
+
238
+ Returns:
239
+ Module with AveragePool instead of MaxPool
240
+
241
+ """
242
+ module_output = module
243
+ if isinstance(module, torch.nn.MaxPool2d):
244
+ module_output = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
245
+
246
+ for name, child in module.named_children():
247
+ module_output.add_module(name, self.replace_pooling(child))
248
+ return module_output
249
+
250
+
251
+ class StyleLoss(ContentLoss):
252
+ r"""Creates Style loss that can be used for image style transfer or as a measure in
253
+ image to image tasks. Computes distance between Gram matrices of feature maps.
254
+ Uses pretrained VGG models from torchvision.
255
+
256
+ By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1].
257
+ If no normalisation is required, change `mean` and `std` values accordingly.
258
+
259
+ Args:
260
+ feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``.
261
+ layers: List of strings with layer names. Default: ``'relu3_3'``
262
+ weights: List of float weights to balance different layers
263
+ replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details.
264
+ distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
265
+ reduction: Specifies the reduction type:
266
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
267
+ mean: List of float values used for data standardization. Default: ImageNet mean.
268
+ If there is no need to normalize data, use [0., 0., 0.].
269
+ std: List of float values used for data standardization. Default: ImageNet std.
270
+ If there is no need to normalize data, use [1., 1., 1.].
271
+ normalize_features: If true, unit-normalize each feature in channel dimension before scaling
272
+ and computing distance. See references for details.
273
+
274
+ Examples:
275
+ >>> loss = StyleLoss()
276
+ >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
277
+ >>> y = torch.rand(3, 3, 256, 256)
278
+ >>> output = loss(x, y)
279
+ >>> output.backward()
280
+
281
+ References:
282
+ Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
283
+ A Neural Algorithm of Artistic Style
284
+ Association for Research in Vision and Ophthalmology (ARVO)
285
+ https://arxiv.org/abs/1508.06576
286
+
287
+ Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
288
+ The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
289
+ IEEE/CVF Conference on Computer Vision and Pattern Recognition
290
+ https://arxiv.org/abs/1801.03924
291
+ """
292
+
293
+ def compute_distance(self, x_features: torch.Tensor, y_features: torch.Tensor):
294
+ r"""Take L2 or L1 distance between Gram matrices of feature maps depending on ``distance``.
295
+
296
+ Args:
297
+ x_features: Features of the input tensor.
298
+ y_features: Features of the target tensor.
299
+
300
+ Returns:
301
+ Distance between Gram matrices
302
+ """
303
+ x_gram = [self.gram_matrix(x) for x in x_features]
304
+ y_gram = [self.gram_matrix(x) for x in y_features]
305
+ return [self.distance(x, y) for x, y in zip(x_gram, y_gram)]
306
+
307
+ @staticmethod
308
+ def gram_matrix(x: torch.Tensor) -> torch.Tensor:
309
+ r"""Compute Gram matrix for batch of features.
310
+
311
+ Args:
312
+ x: Tensor. Shape :math:`(N, C, H, W)`.
313
+
314
+ Returns:
315
+ Gram matrix for given input
316
+ """
317
+ B, C, H, W = x.size()
318
+ gram = []
319
+ for i in range(B):
320
+ features = x[i].view(C, H * W)
321
+
322
+ # Add fake channel dimension
323
+ gram.append(torch.mm(features, features.t()).unsqueeze(0))
324
+
325
+ return torch.stack(gram)
326
+
327
+
328
+ class LPIPS(ContentLoss):
329
+ r"""Learned Perceptual Image Patch Similarity metric. Only VGG16 learned weights are supported.
330
+
331
+ By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1].
332
+ If no normalisation is required, change `mean` and `std` values accordingly.
333
+
334
+ Args:
335
+ replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details.
336
+ distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
337
+ reduction: Specifies the reduction type:
338
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
339
+ mean: List of float values used for data standardization. Default: ImageNet mean.
340
+ If there is no need to normalize data, use [0., 0., 0.].
341
+ std: List of float values used for data standardization. Default: ImageNet std.
342
+ If there is no need to normalize data, use [1., 1., 1.].
343
+
344
+ Examples:
345
+ >>> loss = LPIPS()
346
+ >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
347
+ >>> y = torch.rand(3, 3, 256, 256)
348
+ >>> output = loss(x, y)
349
+ >>> output.backward()
350
+
351
+ References:
352
+ Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
353
+ A Neural Algorithm of Artistic Style
354
+ Association for Research in Vision and Ophthalmology (ARVO)
355
+ https://arxiv.org/abs/1508.06576
356
+
357
+ Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
358
+ The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
359
+ IEEE/CVF Conference on Computer Vision and Pattern Recognition
360
+ https://arxiv.org/abs/1801.03924
361
+ https://github.com/richzhang/PerceptualSimilarity
362
+ """
363
+ _weights_url = "https://github.com/photosynthesis-team/" + \
364
+ "photosynthesis.metrics/releases/download/v0.4.0/lpips_weights.pt"
365
+
366
+ def __init__(self, replace_pooling: bool = False, distance: str = "mse", reduction: str = "mean",
367
+ mean: List[float] = IMAGENET_MEAN, std: List[float] = IMAGENET_STD, ) -> None:
368
+ lpips_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
369
+ lpips_weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False)
370
+ super().__init__("vgg16", layers=lpips_layers, weights=lpips_weights,
371
+ replace_pooling=replace_pooling, distance=distance,
372
+ reduction=reduction, mean=mean, std=std,
373
+ normalize_features=True)
374
+
375
+
376
+ class DISTS(ContentLoss):
377
+ r"""Deep Image Structure and Texture Similarity metric.
378
+
379
+ By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1].
380
+ If no normalisation is required, change `mean` and `std` values accordingly.
381
+
382
+ Args:
383
+ reduction: Specifies the reduction type:
384
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
385
+ mean: List of float values used for data standardization. Default: ImageNet mean.
386
+ If there is no need to normalize data, use [0., 0., 0.].
387
+ std: List of float values used for data standardization. Default: ImageNet std.
388
+ If there is no need to normalize data, use [1., 1., 1.].
389
+
390
+ Examples:
391
+ >>> loss = DISTS()
392
+ >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
393
+ >>> y = torch.rand(3, 3, 256, 256)
394
+ >>> output = loss(x, y)
395
+ >>> output.backward()
396
+
397
+ References:
398
+ Keyan Ding, Kede Ma, Shiqi Wang, Eero P. Simoncelli (2020).
399
+ Image Quality Assessment: Unifying Structure and Texture Similarity.
400
+ https://arxiv.org/abs/2004.07728
401
+ https://github.com/dingkeyan93/DISTS
402
+ """
403
+ _weights_url = "https://github.com/photosynthesis-team/piq/releases/download/v0.4.1/dists_weights.pt"
404
+
405
+ def __init__(self, reduction: str = "mean", mean: List[float] = IMAGENET_MEAN,
406
+ std: List[float] = IMAGENET_STD) -> None:
407
+ dists_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
408
+ channels = [3, 64, 128, 256, 512, 512]
409
+
410
+ weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False)
411
+ dists_weights = list(torch.split(weights['alpha'], channels, dim=1))
412
+ dists_weights.extend(torch.split(weights['beta'], channels, dim=1))
413
+
414
+ super().__init__("vgg16", layers=dists_layers, weights=dists_weights,
415
+ replace_pooling=True, reduction=reduction, mean=mean, std=std,
416
+ normalize_features=False, allow_layers_weights_mismatch=True)
417
+
418
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
419
+ r"""
420
+
421
+ Args:
422
+ x: An input tensor. Shape :math:`(N, C, H, W)`.
423
+ y: A target tensor. Shape :math:`(N, C, H, W)`.
424
+
425
+ Returns:
426
+ Deep Image Structure and Texture Similarity loss, i.e. ``1-DISTS`` in range [0, 1].
427
+ """
428
+ _, _, H, W = x.shape
429
+
430
+ if min(H, W) > 256:
431
+ x = torch.nn.functional.interpolate(
432
+ x, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear')
433
+ y = torch.nn.functional.interpolate(
434
+ y, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear')
435
+
436
+ loss = super().forward(x, y)
437
+ return 1 - loss
438
+
439
+ def compute_distance(self, x_features: torch.Tensor, y_features: torch.Tensor) -> List[torch.Tensor]:
440
+ r"""Compute structure similarity between feature maps
441
+
442
+ Args:
443
+ x_features: Features of the input tensor.
444
+ y_features: Features of the target tensor.
445
+
446
+ Returns:
447
+ Structural similarity distance between feature maps
448
+ """
449
+ structure_distance, texture_distance = [], []
450
+ # Small constant for numerical stability
451
+ EPS = 1e-6
452
+
453
+ for x, y in zip(x_features, y_features):
454
+ x_mean = x.mean([2, 3], keepdim=True)
455
+ y_mean = y.mean([2, 3], keepdim=True)
456
+ structure_distance.append(similarity_map(x_mean, y_mean, constant=EPS))
457
+
458
+ x_var = ((x - x_mean) ** 2).mean([2, 3], keepdim=True)
459
+ y_var = ((y - y_mean) ** 2).mean([2, 3], keepdim=True)
460
+ xy_cov = (x * y).mean([2, 3], keepdim=True) - x_mean * y_mean
461
+ texture_distance.append((2 * xy_cov + EPS) / (x_var + y_var + EPS))
462
+
463
+ return structure_distance + texture_distance
464
+
465
+ def get_features(self, x: torch.Tensor) -> List[torch.Tensor]:
466
+ r"""
467
+
468
+ Args:
469
+ x: Input tensor
470
+
471
+ Returns:
472
+ List of features extracted from input tensor
473
+ """
474
+ features = super().get_features(x)
475
+
476
+ # Add input tensor as an additional feature
477
+ features.insert(0, x)
478
+ return features
479
+
480
+ def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module:
481
+ r"""Turn All MaxPool layers into L2Pool
482
+
483
+ Args:
484
+ module: Module to change MaxPool into L2Pool
485
+
486
+ Returns:
487
+ Module with L2Pool instead of MaxPool
488
+ """
489
+ module_output = module
490
+ if isinstance(module, torch.nn.MaxPool2d):
491
+ module_output = L2Pool2d(kernel_size=3, stride=2, padding=1)
492
+
493
+ for name, child in module.named_children():
494
+ module_output.add_module(name, self.replace_pooling(child))
495
+
496
+ return module_output
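
The four losses above share the `ContentLoss` forward signature, so they can be swapped behind one interface. A minimal sketch on random tensors, assuming the repository root is on `PYTHONPATH` (pretrained VGG, LPIPS and DISTS weights are downloaded on first use):

```python
import torch

from libs.metric.piq.perceptual import ContentLoss, StyleLoss, LPIPS, DISTS

x = torch.rand(2, 3, 256, 256)  # prediction batch in [0, 1]
y = torch.rand(2, 3, 256, 256)  # target batch in [0, 1]

criteria = {
    "content": ContentLoss(layers=("relu3_3",)),
    "style": StyleLoss(),   # distance between Gram matrices of the same features
    "lpips": LPIPS(),       # fetches the LPIPS VGG16 weights via torch.hub
    "dists": DISTS(),       # fetches the DISTS weights via torch.hub
}
for name, criterion in criteria.items():
    print(name, float(criterion(x, y)))
```
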
libs/metric/piq/utils/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .common import _validate_input, _reduce, _parse_version
2
+
3
+ __all__ = [
4
+ "_validate_input",
5
+ "_reduce",
6
+ '_parse_version'
7
+ ]
libs/metric/piq/utils/common.py ADDED
@@ -0,0 +1,158 @@
1
+ import torch
2
+ import re
3
+ import warnings
4
+
5
+ from typing import Tuple, List, Optional, Union, Dict, Any
6
+
7
+ SEMVER_VERSION_PATTERN = re.compile(
8
+ r"""
9
+ ^
10
+ (?P<major>0|[1-9]\d*)
11
+ \.
12
+ (?P<minor>0|[1-9]\d*)
13
+ \.
14
+ (?P<patch>0|[1-9]\d*)
15
+ (?:-(?P<prerelease>
16
+ (?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)
17
+ (?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*
18
+ ))?
19
+ (?:\+(?P<build>
20
+ [0-9a-zA-Z-]+
21
+ (?:\.[0-9a-zA-Z-]+)*
22
+ ))?
23
+ $
24
+ """,
25
+ re.VERBOSE,
26
+ )
27
+
28
+
29
+ PEP_440_VERSION_PATTERN = r"""
30
+ v?
31
+ (?:
32
+ (?:(?P<epoch>[0-9]+)!)? # epoch
33
+ (?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
34
+ (?P<pre> # pre-release
35
+ [-_\.]?
36
+ (?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
37
+ [-_\.]?
38
+ (?P<pre_n>[0-9]+)?
39
+ )?
40
+ (?P<post> # post release
41
+ (?:-(?P<post_n1>[0-9]+))
42
+ |
43
+ (?:
44
+ [-_\.]?
45
+ (?P<post_l>post|rev|r)
46
+ [-_\.]?
47
+ (?P<post_n2>[0-9]+)?
48
+ )
49
+ )?
50
+ (?P<dev> # dev release
51
+ [-_\.]?
52
+ (?P<dev_l>dev)
53
+ [-_\.]?
54
+ (?P<dev_n>[0-9]+)?
55
+ )?
56
+ )
57
+ (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
58
+ """
59
+
60
+
61
+ def _validate_input(
62
+ tensors: List[torch.Tensor],
63
+ dim_range: Tuple[int, int] = (0, -1),
64
+ data_range: Tuple[float, float] = (0., -1.),
65
+ # size_dim_range: Tuple[float, float] = (0., -1.),
66
+ size_range: Optional[Tuple[int, int]] = None,
67
+ ) -> None:
68
+ r"""Check that input(-s) satisfies the requirements
69
+ Args:
70
+ tensors: Tensors to check
71
+ dim_range: Allowed number of dimensions. (min, max)
72
+ data_range: Allowed range of values in tensors. (min, max)
73
+ size_range: Dimensions to include in size comparison. (start_dim, end_dim + 1)
74
+ """
75
+
76
+ if not __debug__:
77
+ return
78
+
79
+ x = tensors[0]
80
+
81
+ for t in tensors:
82
+ assert torch.is_tensor(t), f'Expected torch.Tensor, got {type(t)}'
83
+ assert t.device == x.device, f'Expected tensors to be on {x.device}, got {t.device}'
84
+
85
+ if size_range is None:
86
+ assert t.size() == x.size(), f'Expected tensors with same size, got {t.size()} and {x.size()}'
87
+ else:
88
+ assert t.size()[size_range[0]: size_range[1]] == x.size()[size_range[0]: size_range[1]], \
89
+ f'Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}'
90
+
91
+ if dim_range[0] == dim_range[1]:
92
+ assert t.dim() == dim_range[0], f'Expected number of dimensions to be {dim_range[0]}, got {t.dim()}'
93
+ elif dim_range[0] < dim_range[1]:
94
+ assert dim_range[0] <= t.dim() <= dim_range[1], \
95
+ f'Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}'
96
+
97
+ if data_range[0] < data_range[1]:
98
+ assert data_range[0] <= t.min(), \
99
+ f'Expected values to be greater or equal to {data_range[0]}, got {t.min()}'
100
+ assert t.max() <= data_range[1], \
101
+ f'Expected values to be lower or equal to {data_range[1]}, got {t.max()}'
102
+
103
+
104
+ def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
105
+ r"""Reduce input in batch dimension if needed.
106
+
107
+ Args:
108
+ x: Tensor with shape (N, *).
109
+ reduction: Specifies the reduction type:
110
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
111
+ """
112
+ if reduction == 'none':
113
+ return x
114
+ elif reduction == 'mean':
115
+ return x.mean(dim=0)
116
+ elif reduction == 'sum':
117
+ return x.sum(dim=0)
118
+ else:
119
+ raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
120
+
121
+
122
+ def _parse_version(version: Union[str, bytes]) -> Tuple[int, ...]:
123
+ """ Parses valid Python versions according to Semver and PEP 440 specifications.
124
+ For more on Semver check: https://semver.org/
125
+ For more on PEP 440 check: https://www.python.org/dev/peps/pep-0440/.
126
+
127
+ Implementation is inspired by:
128
+ - https://github.com/python-semver
129
+ - https://github.com/pypa/packaging
130
+
131
+ Args:
132
+ version: unparsed information about the library of interest.
133
+
134
+ Returns:
135
+ parsed information about the library of interest.
136
+ """
137
+ if isinstance(version, bytes):
138
+ version = version.decode("UTF-8")
139
+ elif not isinstance(version, str) and not isinstance(version, bytes):
140
+ raise TypeError(f"not expecting type {type(version)}")
141
+
142
+ # Semver processing
143
+ match = SEMVER_VERSION_PATTERN.match(version)
144
+ if match:
145
+ matched_version_parts: Dict[str, Any] = match.groupdict()
146
+ release = tuple([int(matched_version_parts[k]) for k in ['major', 'minor', 'patch']])
147
+ return release
148
+
149
+ # PEP 440 processing
150
+ regex = re.compile(r"^\s*" + PEP_440_VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
151
+ match = regex.search(version)
152
+
153
+ if match is None:
154
+ warnings.warn(f"{version} is not a valid SemVer or PEP 440 string")
155
+ return tuple()
156
+
157
+ release = tuple(int(i) for i in match.group("release").split("."))
158
+ return release
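
A quick sketch of the helpers above on typical inputs (handy when checking version gates or reduction behaviour):

```python
import torch

from libs.metric.piq.utils import _parse_version, _reduce, _validate_input

print(_parse_version("1.13.1"))       # SemVer string -> (1, 13, 1)
print(_parse_version("2.0.0+cu118"))  # build metadata is ignored -> (2, 0, 0)

per_sample = torch.tensor([0.2, 0.4, 0.6])    # one value per batch element
print(_reduce(per_sample, reduction="mean"))  # tensor(0.4000)
print(_reduce(per_sample, reduction="none"))  # returned unchanged

# Raises an AssertionError if shapes, devices, dims or value ranges disagree
_validate_input([torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8)],
                dim_range=(4, 4), data_range=(0., 1.))
```
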
libs/metric/pytorch_fid/__init__.py ADDED
@@ -0,0 +1,54 @@
1
+ __version__ = '0.3.0'
2
+
3
+ import torch
4
+ from einops import rearrange, repeat
5
+
6
+ from .inception import InceptionV3
7
+ from .fid_score import calculate_frechet_distance
8
+
9
+
10
+ class PytorchFIDFactory(torch.nn.Module):
11
+ """
12
+
13
+ Args:
14
+ channels: Number of channels in the input samples; 1-channel batches are repeated to 3 channels. Default: 3.
15
+ inception_block_idx: Feature dimensionality of the InceptionV3 block to use, one of InceptionV3.BLOCK_INDEX_BY_DIM. Default: 2048.
16
+
17
+ Examples:
18
+ >>> fid_factory = PytorchFIDFactory()
19
+ >>> fid_score = fid_factory.score(real_samples=data, fake_samples=all_images)
20
+ >>> print(fid_score)
21
+ """
22
+
23
+ def __init__(self, channels: int = 3, inception_block_idx: int = 2048):
24
+ super().__init__()
25
+ self.channels = channels
26
+
27
+ # load models
28
+ assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM
29
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
30
+ self.inception_v3 = InceptionV3([block_idx])
31
+
32
+ @torch.no_grad()
33
+ def calculate_activation_statistics(self, samples):
34
+ features = self.inception_v3(samples)[0]
35
+ features = rearrange(features, '... 1 1 -> ...')
36
+
37
+ mu = torch.mean(features, dim=0).cpu()
38
+ sigma = torch.cov(features.t()).cpu()  # torch.cov treats rows as variables, so transpose the (N, dims) feature matrix
39
+ return mu, sigma
40
+
41
+ def score(self, real_samples, fake_samples):
42
+ if self.channels == 1:
43
+ real_samples, fake_samples = map(
44
+ lambda t: repeat(t, 'b 1 ... -> b c ...', c=3), (real_samples, fake_samples)
45
+ )
46
+
47
+ min_batch = min(real_samples.shape[0], fake_samples.shape[0])
48
+ real_samples, fake_samples = map(lambda t: t[:min_batch], (real_samples, fake_samples))
49
+
50
+ m1, s1 = self.calculate_activation_statistics(real_samples)
51
+ m2, s2 = self.calculate_activation_statistics(fake_samples)
52
+
53
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
54
+ return fid_value
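
A self-contained sketch of the factory above; random tensors stand in for decoded image batches in [0, 1], so the resulting number only illustrates the call pattern:

```python
import torch

from libs.metric.pytorch_fid import PytorchFIDFactory

fid_factory = PytorchFIDFactory(channels=3, inception_block_idx=2048)

real = torch.rand(8, 3, 128, 128)  # placeholder for real images
fake = torch.rand(8, 3, 128, 128)  # placeholder for generated images

# Both batches are truncated to the smaller batch size before computing statistics
print(float(fid_factory.score(real_samples=real, fake_samples=fake)))
```
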
libs/metric/pytorch_fid/fid_score.py ADDED
@@ -0,0 +1,322 @@
1
+ """Calculates the Frechet Inception Distance (FID) to evalulate GANs
2
+
3
+ The FID metric calculates the distance between two distributions of images.
4
+ Typically, we have summary statistics (mean & covariance matrix) of one
5
+ of these distributions, while the 2nd distribution is given by a GAN.
6
+
7
+ When run as a stand-alone program, it compares the distribution of
8
+ images that are stored as PNG/JPEG at a specified location with a
9
+ distribution given by summary statistics (in pickle format).
10
+
11
+ The FID is calculated by assuming that X_1 and X_2 are the activations of
12
+ the pool_3 layer of the inception net for generated samples and real world
13
+ samples respectively.
14
+
15
+ See --help to see further details.
16
+
17
+ Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
18
+ of Tensorflow
19
+
20
+ Copyright 2018 Institute of Bioinformatics, JKU Linz
21
+
22
+ Licensed under the Apache License, Version 2.0 (the "License");
23
+ you may not use this file except in compliance with the License.
24
+ You may obtain a copy of the License at
25
+
26
+ http://www.apache.org/licenses/LICENSE-2.0
27
+
28
+ Unless required by applicable law or agreed to in writing, software
29
+ distributed under the License is distributed on an "AS IS" BASIS,
30
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31
+ See the License for the specific language governing permissions and
32
+ limitations under the License.
33
+ """
34
+ import os
35
+ import pathlib
36
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
37
+
38
+ import numpy as np
39
+ import torch
40
+ import torchvision.transforms as TF
41
+ from PIL import Image
42
+ from scipy import linalg
43
+ from torch.nn.functional import adaptive_avg_pool2d
44
+
45
+ try:
46
+ from tqdm import tqdm
47
+ except ImportError:
48
+ # If tqdm is not available, provide a mock version of it
49
+ def tqdm(x):
50
+ return x
51
+
52
+ from .inception import InceptionV3
53
+
54
+ parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
55
+ parser.add_argument('--batch-size', type=int, default=50,
56
+ help='Batch size to use')
57
+ parser.add_argument('--num-workers', type=int,
58
+ help=('Number of processes to use for data loading. '
59
+ 'Defaults to `min(8, num_cpus)`'))
60
+ parser.add_argument('--device', type=str, default=None,
61
+ help='Device to use. Like cuda, cuda:0 or cpu')
62
+ parser.add_argument('--dims', type=int, default=2048,
63
+ choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
64
+ help=('Dimensionality of Inception features to use. '
65
+ 'By default, uses pool3 features'))
66
+ parser.add_argument('--save-stats', action='store_true',
67
+ help=('Generate an npz archive from a directory of samples. '
68
+ 'The first path is used as input and the second as output.'))
69
+ parser.add_argument('path', type=str, nargs=2,
70
+ help=('Paths to the generated images or '
71
+ 'to .npz statistic files'))
72
+
73
+ IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
74
+ 'tif', 'tiff', 'webp'}
75
+
76
+
77
+ class ImagePathDataset(torch.utils.data.Dataset):
78
+ def __init__(self, files, transforms=None):
79
+ self.files = files
80
+ self.transforms = transforms
81
+
82
+ def __len__(self):
83
+ return len(self.files)
84
+
85
+ def __getitem__(self, i):
86
+ path = self.files[i]
87
+ img = Image.open(path).convert('RGB')
88
+ if self.transforms is not None:
89
+ img = self.transforms(img)
90
+ return img
91
+
92
+
93
+ def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
94
+ num_workers=1):
95
+ """Calculates the activations of the pool_3 layer for all images.
96
+
97
+ Params:
98
+ -- files : List of image files paths
99
+ -- model : Instance of inception model
100
+ -- batch_size : Batch size of images for the model to process at once.
101
+ Make sure that the number of samples is a multiple of
102
+ the batch size, otherwise some samples are ignored. This
103
+ behavior is retained to match the original FID score
104
+ implementation.
105
+ -- dims : Dimensionality of features returned by Inception
106
+ -- device : Device to run calculations
107
+ -- num_workers : Number of parallel dataloader workers
108
+
109
+ Returns:
110
+ -- A numpy array of dimension (num images, dims) that contains the
111
+ activations of the given tensor when feeding inception with the
112
+ query tensor.
113
+ """
114
+ model.eval()
115
+
116
+ if batch_size > len(files):
117
+ print(('Warning: batch size is bigger than the data size. '
118
+ 'Setting batch size to data size'))
119
+ batch_size = len(files)
120
+
121
+ dataset = ImagePathDataset(files, transforms=TF.ToTensor())
122
+ dataloader = torch.utils.data.DataLoader(dataset,
123
+ batch_size=batch_size,
124
+ shuffle=False,
125
+ drop_last=False,
126
+ num_workers=num_workers)
127
+
128
+ pred_arr = np.empty((len(files), dims))
129
+
130
+ start_idx = 0
131
+
132
+ for batch in tqdm(dataloader):
133
+ batch = batch.to(device)
134
+
135
+ with torch.no_grad():
136
+ pred = model(batch)[0]
137
+
138
+ # If model output is not scalar, apply global spatial average pooling.
139
+ # This happens if you choose a dimensionality not equal 2048.
140
+ if pred.size(2) != 1 or pred.size(3) != 1:
141
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
142
+
143
+ pred = pred.squeeze(3).squeeze(2).cpu().numpy()
144
+
145
+ pred_arr[start_idx:start_idx + pred.shape[0]] = pred
146
+
147
+ start_idx = start_idx + pred.shape[0]
148
+
149
+ return pred_arr
150
+
151
+
152
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
153
+ """Numpy implementation of the Frechet Distance.
154
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
155
+ and X_2 ~ N(mu_2, C_2) is
156
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
157
+
158
+ Stable version by Dougal J. Sutherland.
159
+
160
+ Params:
161
+ -- mu1 : Numpy array containing the activations of a layer of the
162
+ inception net (like returned by the function 'get_activations')
163
+ for generated samples.
164
+ -- mu2 : The sample mean over activations, precalculated on an
165
+ representative data set.
166
+ -- sigma1: The covariance matrix over activations for generated samples.
167
+ -- sigma2: The covariance matrix over activations, precalculated on an
168
+ representative data set.
169
+
170
+ Returns:
171
+ -- : The Frechet Distance.
172
+ """
173
+
174
+ mu1 = np.atleast_1d(mu1)
175
+ mu2 = np.atleast_1d(mu2)
176
+
177
+ sigma1 = np.atleast_2d(sigma1)
178
+ sigma2 = np.atleast_2d(sigma2)
179
+
180
+ assert mu1.shape == mu2.shape, \
181
+ 'Training and test mean vectors have different lengths'
182
+ assert sigma1.shape == sigma2.shape, \
183
+ 'Training and test covariances have different dimensions'
184
+
185
+ diff = mu1 - mu2
186
+
187
+ # Product might be almost singular
188
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
189
+ if not np.isfinite(covmean).all():
190
+ msg = ('fid calculation produces singular product; '
191
+ 'adding %s to diagonal of cov estimates') % eps
192
+ print(msg)
193
+ offset = np.eye(sigma1.shape[0]) * eps
194
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
195
+
196
+ # Numerical error might give slight imaginary component
197
+ if np.iscomplexobj(covmean):
198
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
199
+ m = np.max(np.abs(covmean.imag))
200
+ raise ValueError('Imaginary component {}'.format(m))
201
+ covmean = covmean.real
202
+
203
+ tr_covmean = np.trace(covmean)
204
+
205
+ return (diff.dot(diff) + np.trace(sigma1)
206
+ + np.trace(sigma2) - 2 * tr_covmean)
207
+
208
+
209
+ def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
210
+ device='cpu', num_workers=1):
211
+ """Calculation of the statistics used by the FID.
212
+ Params:
213
+ -- files : List of image files paths
214
+ -- model : Instance of inception model
215
+ -- batch_size : The images numpy array is split into batches with
216
+ batch size batch_size. A reasonable batch size
217
+ depends on the hardware.
218
+ -- dims : Dimensionality of features returned by Inception
219
+ -- device : Device to run calculations
220
+ -- num_workers : Number of parallel dataloader workers
221
+
222
+ Returns:
223
+ -- mu : The mean over samples of the activations of the pool_3 layer of
224
+ the inception model.
225
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
226
+ the inception model.
227
+ """
228
+ act = get_activations(files, model, batch_size, dims, device, num_workers)
229
+ mu = np.mean(act, axis=0)
230
+ sigma = np.cov(act, rowvar=False)
231
+ return mu, sigma
232
+
233
+
234
+ def compute_statistics_of_path(path, model, batch_size, dims, device,
235
+ num_workers=1):
236
+ if path.endswith('.npz'):
237
+ with np.load(path) as f:
238
+ m, s = f['mu'][:], f['sigma'][:]
239
+ else:
240
+ path = pathlib.Path(path)
241
+ files = sorted([file for ext in IMAGE_EXTENSIONS
242
+ for file in path.glob('*.{}'.format(ext))])
243
+ m, s = calculate_activation_statistics(files, model, batch_size,
244
+ dims, device, num_workers)
245
+
246
+ return m, s
247
+
248
+
249
+ def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
250
+ """Calculates the FID of two paths"""
251
+ for p in paths:
252
+ if not os.path.exists(p):
253
+ raise RuntimeError('Invalid path: %s' % p)
254
+
255
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
256
+
257
+ model = InceptionV3([block_idx]).to(device)
258
+
259
+ m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
260
+ dims, device, num_workers)
261
+ m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
262
+ dims, device, num_workers)
263
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
264
+
265
+ return fid_value
266
+
267
+
268
+ def save_fid_stats(paths, batch_size, device, dims, num_workers=1):
269
+ """Calculates the FID of two paths"""
270
+ if not os.path.exists(paths[0]):
271
+ raise RuntimeError('Invalid path: %s' % paths[0])
272
+
273
+ if os.path.exists(paths[1]):
274
+ raise RuntimeError('Existing output file: %s' % paths[1])
275
+
276
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
277
+
278
+ model = InceptionV3([block_idx]).to(device)
279
+
280
+ print(f"Saving statistics for {paths[0]}")
281
+
282
+ m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
283
+ dims, device, num_workers)
284
+
285
+ np.savez_compressed(paths[1], mu=m1, sigma=s1)
286
+
287
+
288
+ def main():
289
+ args = parser.parse_args()
290
+
291
+ if args.device is None:
292
+ device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
293
+ else:
294
+ device = torch.device(args.device)
295
+
296
+ if args.num_workers is None:
297
+ try:
298
+ num_cpus = len(os.sched_getaffinity(0))
299
+ except AttributeError:
300
+ # os.sched_getaffinity is not available under Windows, use
301
+ # os.cpu_count instead (which may not return the *available* number
302
+ # of CPUs).
303
+ num_cpus = os.cpu_count()
304
+
305
+ num_workers = min(num_cpus, 8) if num_cpus is not None else 0
306
+ else:
307
+ num_workers = args.num_workers
308
+
309
+ if args.save_stats:
310
+ save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers)
311
+ return
312
+
313
+ fid_value = calculate_fid_given_paths(args.path,
314
+ args.batch_size,
315
+ device,
316
+ args.dims,
317
+ num_workers)
318
+ print('FID: ', fid_value)
319
+
320
+
321
+ if __name__ == '__main__':
322
+ main()
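
Besides the CLI entry point above, the same computation can be driven from Python. A minimal sketch; the two folder paths are placeholders and must point at existing image directories (or precomputed `.npz` statistics):

```python
import torch

from libs.metric.pytorch_fid.fid_score import calculate_fid_given_paths

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

fid_value = calculate_fid_given_paths(
    ['data/real_images', 'outputs/generated_images'],  # placeholder paths
    batch_size=50, device=device, dims=2048, num_workers=2,
)
print('FID:', fid_value)
```
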
libs/metric/pytorch_fid/inception.py ADDED
@@ -0,0 +1,341 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+
6
+ try:
7
+ from torchvision.models.utils import load_state_dict_from_url
8
+ except ImportError:
9
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
10
+
11
+ # Inception weights ported to Pytorch from
12
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
14
+
15
+
16
+ class InceptionV3(nn.Module):
17
+ """Pretrained InceptionV3 network returning feature maps"""
18
+
19
+ # Index of default block of inception to return,
20
+ # corresponds to output of final average pooling
21
+ DEFAULT_BLOCK_INDEX = 3
22
+
23
+ # Maps feature dimensionality to their output blocks indices
24
+ BLOCK_INDEX_BY_DIM = {
25
+ 64: 0, # First max pooling features
26
+ 192: 1, # Second max pooling features
27
+ 768: 2, # Pre-aux classifier features
28
+ 2048: 3 # Final average pooling features
29
+ }
30
+
31
+ def __init__(self,
32
+ output_blocks=(DEFAULT_BLOCK_INDEX,),
33
+ resize_input=True,
34
+ normalize_input=True,
35
+ requires_grad=False,
36
+ use_fid_inception=True):
37
+ """Build pretrained InceptionV3
38
+
39
+ Parameters
40
+ ----------
41
+ output_blocks : list of int
42
+ Indices of blocks to return features of. Possible values are:
43
+ - 0: corresponds to output of first max pooling
44
+ - 1: corresponds to output of second max pooling
45
+ - 2: corresponds to output which is fed to aux classifier
46
+ - 3: corresponds to output of final average pooling
47
+ resize_input : bool
48
+ If true, bilinearly resizes input to width and height 299 before
49
+ feeding input to model. As the network without fully connected
50
+ layers is fully convolutional, it should be able to handle inputs
51
+ of arbitrary size, so resizing might not be strictly needed
52
+ normalize_input : bool
53
+ If true, scales the input from range (0, 1) to the range the
54
+ pretrained Inception network expects, namely (-1, 1)
55
+ requires_grad : bool
56
+ If true, parameters of the model require gradients. Possibly useful
57
+ for finetuning the network
58
+ use_fid_inception : bool
59
+ If true, uses the pretrained Inception model used in Tensorflow's
60
+ FID implementation. If false, uses the pretrained Inception model
61
+ available in torchvision. The FID Inception model has different
62
+ weights and a slightly different structure from torchvision's
63
+ Inception model. If you want to compute FID scores, you are
64
+ strongly advised to set this parameter to true to get comparable
65
+ results.
66
+ """
67
+ super(InceptionV3, self).__init__()
68
+
69
+ self.resize_input = resize_input
70
+ self.normalize_input = normalize_input
71
+ self.output_blocks = sorted(output_blocks)
72
+ self.last_needed_block = max(output_blocks)
73
+
74
+ assert self.last_needed_block <= 3, \
75
+ 'Last possible output block index is 3'
76
+
77
+ self.blocks = nn.ModuleList()
78
+
79
+ if use_fid_inception:
80
+ inception = fid_inception_v3()
81
+ else:
82
+ inception = _inception_v3(weights='DEFAULT')
83
+
84
+ # Block 0: input to maxpool1
85
+ block0 = [
86
+ inception.Conv2d_1a_3x3,
87
+ inception.Conv2d_2a_3x3,
88
+ inception.Conv2d_2b_3x3,
89
+ nn.MaxPool2d(kernel_size=3, stride=2)
90
+ ]
91
+ self.blocks.append(nn.Sequential(*block0))
92
+
93
+ # Block 1: maxpool1 to maxpool2
94
+ if self.last_needed_block >= 1:
95
+ block1 = [
96
+ inception.Conv2d_3b_1x1,
97
+ inception.Conv2d_4a_3x3,
98
+ nn.MaxPool2d(kernel_size=3, stride=2)
99
+ ]
100
+ self.blocks.append(nn.Sequential(*block1))
101
+
102
+ # Block 2: maxpool2 to aux classifier
103
+ if self.last_needed_block >= 2:
104
+ block2 = [
105
+ inception.Mixed_5b,
106
+ inception.Mixed_5c,
107
+ inception.Mixed_5d,
108
+ inception.Mixed_6a,
109
+ inception.Mixed_6b,
110
+ inception.Mixed_6c,
111
+ inception.Mixed_6d,
112
+ inception.Mixed_6e,
113
+ ]
114
+ self.blocks.append(nn.Sequential(*block2))
115
+
116
+ # Block 3: aux classifier to final avgpool
117
+ if self.last_needed_block >= 3:
118
+ block3 = [
119
+ inception.Mixed_7a,
120
+ inception.Mixed_7b,
121
+ inception.Mixed_7c,
122
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
123
+ ]
124
+ self.blocks.append(nn.Sequential(*block3))
125
+
126
+ for param in self.parameters():
127
+ param.requires_grad = requires_grad
128
+
129
+ def forward(self, inp):
130
+ """Get Inception feature maps
131
+
132
+ Parameters
133
+ ----------
134
+ inp : torch.autograd.Variable
135
+ Input tensor of shape Bx3xHxW. Values are expected to be in
136
+ range (0, 1)
137
+
138
+ Returns
139
+ -------
140
+ List of torch.autograd.Variable, corresponding to the selected output
141
+ block, sorted ascending by index
142
+ """
143
+ outp = []
144
+ x = inp
145
+
146
+ if self.resize_input:
147
+ x = F.interpolate(x,
148
+ size=(299, 299),
149
+ mode='bilinear',
150
+ align_corners=False)
151
+
152
+ if self.normalize_input:
153
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154
+
155
+ for idx, block in enumerate(self.blocks):
156
+ x = block(x)
157
+ if idx in self.output_blocks:
158
+ outp.append(x)
159
+
160
+ if idx == self.last_needed_block:
161
+ break
162
+
163
+ return outp
164
+
165
+
166
+ def _inception_v3(*args, **kwargs):
167
+ """Wraps `torchvision.models.inception_v3`"""
168
+ try:
169
+ version = tuple(map(int, torchvision.__version__.split('.')[:2]))
170
+ except ValueError:
171
+ # Just a caution against weird version strings
172
+ version = (0,)
173
+
174
+ # Skips default weight initialization if supported by torchvision
175
+ # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
176
+ if version >= (0, 6):
177
+ kwargs['init_weights'] = False
178
+
179
+ # Backwards compatibility: `weights` argument was handled by `pretrained`
180
+ # argument prior to version 0.13.
181
+ if version < (0, 13) and 'weights' in kwargs:
182
+ if kwargs['weights'] == 'DEFAULT':
183
+ kwargs['pretrained'] = True
184
+ elif kwargs['weights'] is None:
185
+ kwargs['pretrained'] = False
186
+ else:
187
+ raise ValueError(
188
+ 'weights=={} not supported in torchvision {}'.format(
189
+ kwargs['weights'], torchvision.__version__
190
+ )
191
+ )
192
+ del kwargs['weights']
193
+
194
+ return torchvision.models.inception_v3(*args, **kwargs)
195
+
196
+
197
+ def fid_inception_v3():
198
+ """Build pretrained Inception model for FID computation
199
+
200
+ The Inception model for FID computation uses a different set of weights
201
+ and has a slightly different structure than torchvision's Inception.
202
+
203
+ This method first constructs torchvision's Inception and then patches the
204
+ necessary parts that are different in the FID Inception model.
205
+ """
206
+ inception = _inception_v3(num_classes=1008,
207
+ aux_logits=False,
208
+ weights=None)
209
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
210
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
211
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
212
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
213
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
214
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
215
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
216
+ inception.Mixed_7b = FIDInceptionE_1(1280)
217
+ inception.Mixed_7c = FIDInceptionE_2(2048)
218
+
219
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
220
+ inception.load_state_dict(state_dict)
221
+ return inception
222
+
223
+
224
+ class FIDInceptionA(torchvision.models.inception.InceptionA):
225
+ """InceptionA block patched for FID computation"""
226
+ def __init__(self, in_channels, pool_features):
227
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
228
+
229
+ def forward(self, x):
230
+ branch1x1 = self.branch1x1(x)
231
+
232
+ branch5x5 = self.branch5x5_1(x)
233
+ branch5x5 = self.branch5x5_2(branch5x5)
234
+
235
+ branch3x3dbl = self.branch3x3dbl_1(x)
236
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
237
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
238
+
239
+ # Patch: Tensorflow's average pool does not use the padded zeros in
240
+ # its average calculation
241
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
242
+ count_include_pad=False)
243
+ branch_pool = self.branch_pool(branch_pool)
244
+
245
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
246
+ return torch.cat(outputs, 1)
247
+
248
+
249
+ class FIDInceptionC(torchvision.models.inception.InceptionC):
250
+ """InceptionC block patched for FID computation"""
251
+ def __init__(self, in_channels, channels_7x7):
252
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
253
+
254
+ def forward(self, x):
255
+ branch1x1 = self.branch1x1(x)
256
+
257
+ branch7x7 = self.branch7x7_1(x)
258
+ branch7x7 = self.branch7x7_2(branch7x7)
259
+ branch7x7 = self.branch7x7_3(branch7x7)
260
+
261
+ branch7x7dbl = self.branch7x7dbl_1(x)
262
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
263
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
264
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
265
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
266
+
267
+ # Patch: Tensorflow's average pool does not use the padded zeros in
268
+ # its average calculation
269
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
270
+ count_include_pad=False)
271
+ branch_pool = self.branch_pool(branch_pool)
272
+
273
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
274
+ return torch.cat(outputs, 1)
275
+
276
+
277
+ class FIDInceptionE_1(torchvision.models.inception.InceptionE):
278
+ """First InceptionE block patched for FID computation"""
279
+ def __init__(self, in_channels):
280
+ super(FIDInceptionE_1, self).__init__(in_channels)
281
+
282
+ def forward(self, x):
283
+ branch1x1 = self.branch1x1(x)
284
+
285
+ branch3x3 = self.branch3x3_1(x)
286
+ branch3x3 = [
287
+ self.branch3x3_2a(branch3x3),
288
+ self.branch3x3_2b(branch3x3),
289
+ ]
290
+ branch3x3 = torch.cat(branch3x3, 1)
291
+
292
+ branch3x3dbl = self.branch3x3dbl_1(x)
293
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
294
+ branch3x3dbl = [
295
+ self.branch3x3dbl_3a(branch3x3dbl),
296
+ self.branch3x3dbl_3b(branch3x3dbl),
297
+ ]
298
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
299
+
300
+ # Patch: Tensorflow's average pool does not use the padded zeros in
301
+ # its average calculation
302
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
303
+ count_include_pad=False)
304
+ branch_pool = self.branch_pool(branch_pool)
305
+
306
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
307
+ return torch.cat(outputs, 1)
308
+
309
+
310
+ class FIDInceptionE_2(torchvision.models.inception.InceptionE):
311
+ """Second InceptionE block patched for FID computation"""
312
+ def __init__(self, in_channels):
313
+ super(FIDInceptionE_2, self).__init__(in_channels)
314
+
315
+ def forward(self, x):
316
+ branch1x1 = self.branch1x1(x)
317
+
318
+ branch3x3 = self.branch3x3_1(x)
319
+ branch3x3 = [
320
+ self.branch3x3_2a(branch3x3),
321
+ self.branch3x3_2b(branch3x3),
322
+ ]
323
+ branch3x3 = torch.cat(branch3x3, 1)
324
+
325
+ branch3x3dbl = self.branch3x3dbl_1(x)
326
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
327
+ branch3x3dbl = [
328
+ self.branch3x3dbl_3a(branch3x3dbl),
329
+ self.branch3x3dbl_3b(branch3x3dbl),
330
+ ]
331
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
332
+
333
+ # Patch: The FID Inception model uses max pooling instead of average
334
+ # pooling. This is likely an error in this specific Inception
335
+ # implementation, as other Inception models use average pooling here
336
+ # (which matches the description in the paper).
337
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
338
+ branch_pool = self.branch_pool(branch_pool)
339
+
340
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
341
+ return torch.cat(outputs, 1)
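
A sketch of using the wrapper above directly as a frozen feature extractor (the FID-specific weights are fetched from `FID_WEIGHTS_URL` on first construction; the output shapes in the comments follow the standard Inception v3 geometry):

```python
import torch

from libs.metric.pytorch_fid.inception import InceptionV3

# Request the 192-d and 2048-d blocks; outputs come back sorted by block index
blocks = [InceptionV3.BLOCK_INDEX_BY_DIM[192], InceptionV3.BLOCK_INDEX_BY_DIM[2048]]
model = InceptionV3(output_blocks=blocks).eval()

with torch.no_grad():
    feats = model(torch.rand(2, 3, 299, 299))  # values expected in (0, 1)

for f in feats:
    print(tuple(f.shape))  # (2, 192, 35, 35) then (2, 2048, 1, 1)
```
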
libs/modules/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
libs/modules/edge_map/DoG/XDoG.py ADDED
@@ -0,0 +1,78 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+ import numpy as np
7
+ import cv2
8
+ from scipy import ndimage as ndi
9
+ from skimage import filters
10
+
11
+
12
+ class XDoG:
13
+
14
+ def __init__(self,
15
+ gamma=0.98,
16
+ phi=200,
17
+ eps=-0.1,
18
+ sigma=0.8,
19
+ k=10,
20
+ binarize: bool = True):
21
+ """
22
+ XDoG algorithm.
23
+
24
+ Args:
25
+ gamma: Control the size of the Gaussian filter
26
+ phi: Control changes in edge strength
27
+ eps: Threshold for controlling edge strength
28
+ sigma: Standard deviation of the Gaussian filter; controls the degree of smoothing
29
+ k: Controls the size ratio of the two Gaussian filters (typically k=10 or k=1.6)
30
+ binarize(bool): Whether to binarize the output
31
+ """
32
+
33
+ super(XDoG, self).__init__()
34
+
35
+ self.gamma = gamma
36
+ assert 0 <= self.gamma <= 1
37
+
38
+ self.phi = phi
39
+ assert 0 <= self.phi <= 1500
40
+
41
+ self.eps = eps
42
+ assert -1 <= self.eps <= 1
43
+
44
+ self.sigma = sigma
45
+ assert 0.1 <= self.sigma <= 10
46
+
47
+ self.k = k
48
+ assert 1 <= self.k <= 100
49
+
50
+ self.binarize = binarize
51
+
52
+ def __call__(self, img):
53
+ # to gray if image is not already grayscale
54
+ if len(img.shape) == 3 and img.shape[2] == 3:
55
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
56
+ elif len(img.shape) == 3 and img.shape[2] == 4:
57
+ img = cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY)
58
+
59
+ if np.isnan(img).any():
60
+ img[np.isnan(img)] = np.mean(img[~np.isnan(img)])
61
+
62
+ # gaussian filter
63
+ imf1 = ndi.gaussian_filter(img, self.sigma)
64
+ imf2 = ndi.gaussian_filter(img, self.sigma * self.k)
65
+ imdiff = imf1 - self.gamma * imf2
66
+
67
+ # XDoG
68
+ imdiff = (imdiff < self.eps) * 1.0 + (imdiff >= self.eps) * (1.0 + np.tanh(self.phi * imdiff))
69
+
70
+ # normalize
71
+ imdiff -= imdiff.min()
72
+ imdiff /= imdiff.max()
73
+
74
+ if self.binarize:
75
+ th = filters.threshold_otsu(imdiff)
76
+ imdiff = (imdiff >= th).astype('float32')
77
+
78
+ return imdiff
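
A short sketch of running the detector above on an image file. The input must be floating point, because the NaN check inside `__call__` uses `np.isnan`, which rejects integer arrays; the file path is a placeholder:

```python
import cv2

from libs.modules.edge_map.DoG import XDoG

xdog = XDoG(gamma=0.98, phi=200, eps=-0.1, sigma=0.8, k=10, binarize=True)

img = cv2.imread('input.png').astype('float32')  # placeholder path; BGR, cast to float
edges = xdog(img)                                # edge map in [0, 1] (binary here)
cv2.imwrite('input_xdog.png', (edges * 255).astype('uint8'))
```
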
libs/modules/edge_map/DoG/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+ from .XDoG import XDoG
7
+
8
+ __all__ = ['XDoG']
libs/modules/edge_map/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
libs/modules/edge_map/canny/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+ import cv2
7
+
8
+
9
+ class CannyDetector:
10
+
11
+ def __call__(self, img, low_threshold, high_threshold, L2gradient=False):
12
+ return cv2.Canny(img, low_threshold, high_threshold, L2gradient=L2gradient)  # keyword: cv2.Canny's 4th positional argument is `edges`, not L2gradient
13
+
14
+
15
+ __all__ = ['CannyDetector']
libs/modules/edge_map/image_grads/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+ from .laplacian import LaplacianDetector
7
+
8
+ __all__ = ['LaplacianDetector']
libs/modules/edge_map/image_grads/laplacian.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+
7
+ import cv2
8
+
9
+
10
+ class LaplacianDetector:
11
+
12
+ def __call__(self, img):
13
+ return cv2.Laplacian(img, cv2.CV_64F)
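
The Canny and Laplacian wrappers above (together with the XDoG class) expose the same call-style interface; a combined sketch with a placeholder path:

```python
import cv2

from libs.modules.edge_map.canny import CannyDetector
from libs.modules.edge_map.image_grads import LaplacianDetector

gray = cv2.imread('input.png', cv2.IMREAD_GRAYSCALE)  # placeholder path

canny_map = CannyDetector()(gray, low_threshold=100, high_threshold=200)  # uint8 edges
laplace_map = LaplacianDetector()(gray)                                   # float64 gradients

cv2.imwrite('input_canny.png', canny_map)
cv2.imwrite('input_laplacian.png', cv2.convertScaleAbs(laplace_map))
```
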
libs/modules/ema.py ADDED
@@ -0,0 +1,198 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Description: EMA model
4
+
5
+ import copy
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ __all__ = ['EMA']
11
+
12
+
13
+ class EMA(nn.Module):
14
+ """
15
+ Implements exponential moving average shadowing for your model.
16
+ Utilizes an inverse decay schedule to manage longer term training runs.
17
+ By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
18
+ @crowsonkb's notes on EMA Warmup:
19
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
20
+ good values for models you plan to train for a million or more steps (reaches decay
21
+ factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
22
+ you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
23
+ 215.4k steps).
24
+ Args:
25
+ inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
26
+ power (float): Exponential factor of EMA warmup. Default: 1.
27
+ min_value (float): The minimum EMA decay rate. Default: 0.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ model,
33
+ # if your model has lazylinears or other types of non-deepcopyable modules,
34
+ # you can pass in your own ema model
35
+ ema_model=None,
36
+ beta=0.9999,
37
+ update_after_step=100,
38
+ update_every=10,
39
+ inv_gamma=1.0,
40
+ power=2 / 3,
41
+ min_value=0.0,
42
+ param_or_buffer_names_no_ema=set(),
43
+ ignore_names=set(),
44
+ ignore_startswith_names=set(),
45
+ # set this to False if you do not wish for the online model to be
46
+ # saved along with the ema model (managed externally)
47
+ include_online_model=True
48
+ ):
49
+ super().__init__()
50
+ self.beta = beta
51
+
52
+ # whether to include the online model within the module tree, so that state_dict also saves it
53
+ self.include_online_model = include_online_model
54
+
55
+ if include_online_model:
56
+ self.online_model = model
57
+ else:
58
+ self.online_model = [model] # hack
59
+
60
+ # ema model
61
+ self.ema_model = ema_model
62
+
63
+ if not exists(self.ema_model):
64
+ try:
65
+ self.ema_model = copy.deepcopy(model)
66
+ except:
67
+ print('Your model was not copyable. Please make sure you are not using any LazyLinear')
68
+ exit()
69
+
70
+ self.ema_model.requires_grad_(False)
71
+
72
+ self.parameter_names = {name for name, param in self.ema_model.named_parameters() if param.dtype == torch.float}
73
+ self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if buffer.dtype == torch.float}
74
+
75
+ self.update_every = update_every
76
+ self.update_after_step = update_after_step
77
+
78
+ self.inv_gamma = inv_gamma
79
+ self.power = power
80
+ self.min_value = min_value
81
+
82
+ assert isinstance(param_or_buffer_names_no_ema, (set, list))
83
+ self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer
84
+
85
+ self.ignore_names = ignore_names
86
+ self.ignore_startswith_names = ignore_startswith_names
87
+
88
+ self.register_buffer('initted', torch.Tensor([False]))
89
+ self.register_buffer('step', torch.tensor([0]))
90
+
91
+ @property
92
+ def model(self):
93
+ return self.online_model if self.include_online_model else self.online_model[0]
94
+
95
+ def restore_ema_model_device(self):
96
+ device = self.initted.device
97
+ self.ema_model.to(device)
98
+
99
+ def get_params_iter(self, model):
100
+ for name, param in model.named_parameters():
101
+ if name not in self.parameter_names:
102
+ continue
103
+ yield name, param
104
+
105
+ def get_buffers_iter(self, model):
106
+ for name, buffer in model.named_buffers():
107
+ if name not in self.buffer_names:
108
+ continue
109
+ yield name, buffer
110
+
111
+ def copy_params_from_model_to_ema(self):
112
+ for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model),
113
+ self.get_params_iter(self.model)):
114
+ ma_params.data.copy_(current_params.data)
115
+
116
+ for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model),
117
+ self.get_buffers_iter(self.model)):
118
+ ma_buffers.data.copy_(current_buffers.data)
119
+
120
+ def get_current_decay(self):
121
+ epoch = clamp(self.step.item() - self.update_after_step - 1, min_value=0.)
122
+ value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
123
+
124
+ if epoch <= 0:
125
+ return 0.
126
+
127
+ return clamp(value, min_value=self.min_value, max_value=self.beta)
128
+
129
+ def update(self):
130
+ step = self.step.item()
131
+ self.step += 1
132
+
133
+ if (step % self.update_every) != 0:
134
+ return
135
+
136
+ if step <= self.update_after_step:
137
+ self.copy_params_from_model_to_ema()
138
+ return
139
+
140
+ if not self.initted.item():
141
+ self.copy_params_from_model_to_ema()
142
+ self.initted.data.copy_(torch.Tensor([True]))
143
+
144
+ self.update_moving_average(self.ema_model, self.model)
145
+
146
+ @torch.no_grad()
147
+ def update_moving_average(self, ma_model, current_model):
148
+ current_decay = self.get_current_decay()
149
+
150
+ for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model),
151
+ self.get_params_iter(ma_model)):
152
+ if name in self.ignore_names:
153
+ continue
154
+
155
+ if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
156
+ continue
157
+
158
+ if name in self.param_or_buffer_names_no_ema:
159
+ ma_params.data.copy_(current_params.data)
160
+ continue
161
+
162
+ ma_params.data.lerp_(current_params.data, 1. - current_decay)
163
+
164
+ for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model),
165
+ self.get_buffers_iter(ma_model)):
166
+ if name in self.ignore_names:
167
+ continue
168
+
169
+ if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
170
+ continue
171
+
172
+ if name in self.param_or_buffer_names_no_ema:
173
+ ma_buffer.data.copy_(current_buffer.data)
174
+ continue
175
+
176
+ ma_buffer.data.lerp_(current_buffer.data, 1. - current_decay)
177
+
178
+ def __call__(self, *args, **kwargs):
179
+ return self.ema_model(*args, **kwargs)
180
+
181
+
182
+ def exists(val):
183
+ return val is not None
184
+
185
+
186
+ def is_float_dtype(dtype):
187
+ return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
188
+
189
+
190
+ def clamp(value, min_value=None, max_value=None):
191
+ assert exists(min_value) or exists(max_value)
192
+ if exists(min_value):
193
+ value = max(value, min_value)
194
+
195
+ if exists(max_value):
196
+ value = min(value, max_value)
197
+
198
+ return value
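In a training loop, the `EMA` wrapper is constructed once around the online model, `update()` is called after each optimizer step (it copies weights until warmup ends, then lerps with the decay `1 - (1 + step / inv_gamma) ** -power`, clamped to `beta`), and evaluation uses the shadow weights by calling the wrapper itself. A minimal sketch with a toy model:

```python
import torch
import torch.nn as nn
from libs.modules.ema import EMA

model = nn.Linear(16, 4)
ema = EMA(model, beta=0.9999, update_after_step=100, update_every=10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(1000):
    x = torch.randn(8, 16)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update()  # no-op except every `update_every` steps; copies during warmup, then EMA lerp

with torch.no_grad():
    y = ema(torch.randn(8, 16))  # __call__ dispatches to the EMA (shadow) model
```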
libs/modules/vision/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+ from .inception import inception_v3
7
+ from .vgg import VGG
8
+
9
+ __all__ = [
10
+ 'inception_v3',
11
+ 'VGG'
12
+ ]
libs/modules/vision/inception.py ADDED
@@ -0,0 +1,482 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+ from collections import namedtuple
7
+ import warnings
8
+ from typing import Callable, Any, Optional, Tuple, List
9
+
10
+ import torch
11
+ from torch import nn, Tensor
12
+ import torch.nn.functional as F
13
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
14
+
15
+ __all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs']
16
+
17
+ model_urls = {
18
+ # Inception v3 ported from TensorFlow
19
+ 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth',
20
+ }
21
+
22
+ InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
23
+ InceptionOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]}
24
+
25
+ # Script annotations failed with _GoogleNetOutputs = namedtuple ...
26
+ # _InceptionOutputs set here for backwards compat
27
+ _InceptionOutputs = InceptionOutputs
28
+
29
+
30
+ def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
31
+ r"""Inception v3 model architecture from
32
+ `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
33
+
34
+ .. note::
35
+ **Important**: In contrast to the other models, inception_v3 expects tensors with a size of
36
+ N x 3 x 299 x 299, so ensure your images are sized accordingly.
37
+
38
+ Args:
39
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
40
+ progress (bool): If True, displays a progress bar of the download to stderr
41
+ aux_logits (bool): If True, add an auxiliary branch that can improve training.
42
+ Default: *True*
43
+ transform_input (bool): If True, preprocesses the input according to the method with which it
44
+ was trained on ImageNet. Default: *False*
45
+ """
46
+ if pretrained:
47
+ if 'transform_input' not in kwargs:
48
+ kwargs['transform_input'] = True
49
+ if 'aux_logits' in kwargs:
50
+ original_aux_logits = kwargs['aux_logits']
51
+ kwargs['aux_logits'] = True
52
+ else:
53
+ original_aux_logits = True
54
+ kwargs['init_weights'] = False # we are loading weights from a pretrained model
55
+ model = Inception3(**kwargs)
56
+ state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
57
+ progress=progress)
58
+ model.load_state_dict(state_dict)
59
+ if not original_aux_logits:
60
+ model.aux_logits = False
61
+ model.AuxLogits = None
62
+ return model
63
+
64
+ return Inception3(**kwargs)
65
+
66
+
67
+ class Inception3(nn.Module):
68
+
69
+ def __init__(
70
+ self,
71
+ num_classes: int = 1000,
72
+ aux_logits: bool = True,
73
+ transform_input: bool = False,
74
+ inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
75
+ init_weights: Optional[bool] = None
76
+ ) -> None:
77
+ super(Inception3, self).__init__()
78
+ if inception_blocks is None:
79
+ inception_blocks = [
80
+ BasicConv2d, InceptionA, InceptionB, InceptionC,
81
+ InceptionD, InceptionE, InceptionAux
82
+ ]
83
+ if init_weights is None:
84
+ warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of '
85
+ 'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
86
+ ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
87
+ init_weights = True
88
+ assert len(inception_blocks) == 7
89
+ conv_block = inception_blocks[0]
90
+ inception_a = inception_blocks[1]
91
+ inception_b = inception_blocks[2]
92
+ inception_c = inception_blocks[3]
93
+ inception_d = inception_blocks[4]
94
+ inception_e = inception_blocks[5]
95
+ inception_aux = inception_blocks[6]
96
+
97
+ self.aux_logits = aux_logits
98
+ self.transform_input = transform_input
99
+ self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
100
+ self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
101
+ self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
102
+ self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
103
+ self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
104
+ self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
105
+ self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
106
+ self.Mixed_5b = inception_a(192, pool_features=32)
107
+ self.Mixed_5c = inception_a(256, pool_features=64)
108
+ self.Mixed_5d = inception_a(288, pool_features=64)
109
+ self.Mixed_6a = inception_b(288)
110
+ self.Mixed_6b = inception_c(768, channels_7x7=128)
111
+ self.Mixed_6c = inception_c(768, channels_7x7=160)
112
+ self.Mixed_6d = inception_c(768, channels_7x7=160)
113
+ self.Mixed_6e = inception_c(768, channels_7x7=192)
114
+ self.AuxLogits: Optional[nn.Module] = None
115
+ if aux_logits:
116
+ self.AuxLogits = inception_aux(768, num_classes)
117
+ self.Mixed_7a = inception_d(768)
118
+ self.Mixed_7b = inception_e(1280)
119
+ self.Mixed_7c = inception_e(2048)
120
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
121
+ self.dropout = nn.Dropout()
122
+ self.fc = nn.Linear(2048, num_classes)
123
+ if init_weights:
124
+ for m in self.modules():
125
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
126
+ import scipy.stats as stats
127
+ stddev = m.stddev if hasattr(m, 'stddev') else 0.1
128
+ X = stats.truncnorm(-2, 2, scale=stddev)
129
+ values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
130
+ values = values.view(m.weight.size())
131
+ with torch.no_grad():
132
+ m.weight.copy_(values)
133
+ elif isinstance(m, nn.BatchNorm2d):
134
+ nn.init.constant_(m.weight, 1)
135
+ nn.init.constant_(m.bias, 0)
136
+
137
+ def _transform_input(self, x: Tensor) -> Tensor:
138
+ if self.transform_input:
139
+ x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
140
+ x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
141
+ x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
142
+ x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
143
+ return x
144
+
145
+ def _forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
146
+ # N x 3 x 299 x 299
147
+ x = self.Conv2d_1a_3x3(x)
148
+ # N x 32 x 149 x 149
149
+ x = self.Conv2d_2a_3x3(x)
150
+ # N x 32 x 147 x 147
151
+ x = self.Conv2d_2b_3x3(x)
152
+ # N x 64 x 147 x 147
153
+ feat = self.maxpool1(x)
154
+ # N x 64 x 73 x 73
155
+ x = self.Conv2d_3b_1x1(feat)
156
+ # N x 80 x 73 x 73
157
+ x = self.Conv2d_4a_3x3(x)
158
+ # N x 192 x 71 x 71
159
+ x = self.maxpool2(x)
160
+ # N x 192 x 35 x 35
161
+ x = self.Mixed_5b(x)
162
+ # N x 256 x 35 x 35
163
+ x = self.Mixed_5c(x)
164
+ # N x 288 x 35 x 35
165
+ x = self.Mixed_5d(x)
166
+ # N x 288 x 35 x 35
167
+ x = self.Mixed_6a(x)
168
+ # N x 768 x 17 x 17
169
+ x = self.Mixed_6b(x)
170
+ # N x 768 x 17 x 17
171
+ x = self.Mixed_6c(x)
172
+ # N x 768 x 17 x 17
173
+ x = self.Mixed_6d(x)
174
+ # N x 768 x 17 x 17
175
+ x = self.Mixed_6e(x)
176
+ # N x 768 x 17 x 17
177
+ aux: Optional[Tensor] = None
178
+ if self.AuxLogits is not None:
179
+ if self.training:
180
+ aux = self.AuxLogits(x)
181
+ # N x 768 x 17 x 17
182
+ x = self.Mixed_7a(x)
183
+ # N x 1280 x 8 x 8
184
+ x = self.Mixed_7b(x)
185
+ # N x 2048 x 8 x 8
186
+ x = self.Mixed_7c(x)
187
+ # N x 2048 x 8 x 8
188
+ # Adaptive average pooling
189
+ x = self.avgpool(x)
190
+ # N x 2048 x 1 x 1
191
+ x = self.dropout(x)
192
+ # N x 2048 x 1 x 1
193
+ x = torch.flatten(x, 1)
194
+ # N x 2048
195
+ x = self.fc(x)
196
+ # N x 1000 (num_classes)
197
+ return feat, x, aux
198
+
199
+ @torch.jit.unused
200
+ def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
201
+ if self.training and self.aux_logits:
202
+ return InceptionOutputs(x, aux)
203
+ else:
204
+ return x # type: ignore[return-value]
205
+
206
+ def forward(self, x: Tensor) -> Tuple[Tensor, InceptionOutputs]:
207
+ x = self._transform_input(x)
208
+ feat, x, aux = self._forward(x)
209
+ aux_defined = self.training and self.aux_logits
210
+ if torch.jit.is_scripting():
211
+ if not aux_defined:
212
+ warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
213
+ return feat, InceptionOutputs(x, aux)
214
+ else:
215
+ return feat, self.eager_outputs(x, aux)
216
+
217
+
218
+ class InceptionA(nn.Module):
219
+
220
+ def __init__(
221
+ self,
222
+ in_channels: int,
223
+ pool_features: int,
224
+ conv_block: Optional[Callable[..., nn.Module]] = None
225
+ ) -> None:
226
+ super(InceptionA, self).__init__()
227
+ if conv_block is None:
228
+ conv_block = BasicConv2d
229
+ self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
230
+
231
+ self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
232
+ self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
233
+
234
+ self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
235
+ self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
236
+ self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
237
+
238
+ self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
239
+
240
+ def _forward(self, x: Tensor) -> List[Tensor]:
241
+ branch1x1 = self.branch1x1(x)
242
+
243
+ branch5x5 = self.branch5x5_1(x)
244
+ branch5x5 = self.branch5x5_2(branch5x5)
245
+
246
+ branch3x3dbl = self.branch3x3dbl_1(x)
247
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
248
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
249
+
250
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
251
+ branch_pool = self.branch_pool(branch_pool)
252
+
253
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
254
+ return outputs
255
+
256
+ def forward(self, x: Tensor) -> Tensor:
257
+ outputs = self._forward(x)
258
+ return torch.cat(outputs, 1)
259
+
260
+
261
+ class InceptionB(nn.Module):
262
+
263
+ def __init__(
264
+ self,
265
+ in_channels: int,
266
+ conv_block: Optional[Callable[..., nn.Module]] = None
267
+ ) -> None:
268
+ super(InceptionB, self).__init__()
269
+ if conv_block is None:
270
+ conv_block = BasicConv2d
271
+ self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
272
+
273
+ self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
274
+ self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
275
+ self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
276
+
277
+ def _forward(self, x: Tensor) -> List[Tensor]:
278
+ branch3x3 = self.branch3x3(x)
279
+
280
+ branch3x3dbl = self.branch3x3dbl_1(x)
281
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
282
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
283
+
284
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
285
+
286
+ outputs = [branch3x3, branch3x3dbl, branch_pool]
287
+ return outputs
288
+
289
+ def forward(self, x: Tensor) -> Tensor:
290
+ outputs = self._forward(x)
291
+ return torch.cat(outputs, 1)
292
+
293
+
294
+ class InceptionC(nn.Module):
295
+
296
+ def __init__(
297
+ self,
298
+ in_channels: int,
299
+ channels_7x7: int,
300
+ conv_block: Optional[Callable[..., nn.Module]] = None
301
+ ) -> None:
302
+ super(InceptionC, self).__init__()
303
+ if conv_block is None:
304
+ conv_block = BasicConv2d
305
+ self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
306
+
307
+ c7 = channels_7x7
308
+ self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
309
+ self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
310
+ self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
311
+
312
+ self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
313
+ self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
314
+ self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
315
+ self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
316
+ self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
317
+
318
+ self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
319
+
320
+ def _forward(self, x: Tensor) -> List[Tensor]:
321
+ branch1x1 = self.branch1x1(x)
322
+
323
+ branch7x7 = self.branch7x7_1(x)
324
+ branch7x7 = self.branch7x7_2(branch7x7)
325
+ branch7x7 = self.branch7x7_3(branch7x7)
326
+
327
+ branch7x7dbl = self.branch7x7dbl_1(x)
328
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
329
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
330
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
331
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
332
+
333
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
334
+ branch_pool = self.branch_pool(branch_pool)
335
+
336
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
337
+ return outputs
338
+
339
+ def forward(self, x: Tensor) -> Tensor:
340
+ outputs = self._forward(x)
341
+ return torch.cat(outputs, 1)
342
+
343
+
344
+ class InceptionD(nn.Module):
345
+
346
+ def __init__(
347
+ self,
348
+ in_channels: int,
349
+ conv_block: Optional[Callable[..., nn.Module]] = None
350
+ ) -> None:
351
+ super(InceptionD, self).__init__()
352
+ if conv_block is None:
353
+ conv_block = BasicConv2d
354
+ self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
355
+ self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
356
+
357
+ self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
358
+ self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
359
+ self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
360
+ self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
361
+
362
+ def _forward(self, x: Tensor) -> List[Tensor]:
363
+ branch3x3 = self.branch3x3_1(x)
364
+ branch3x3 = self.branch3x3_2(branch3x3)
365
+
366
+ branch7x7x3 = self.branch7x7x3_1(x)
367
+ branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
368
+ branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
369
+ branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
370
+
371
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
372
+ outputs = [branch3x3, branch7x7x3, branch_pool]
373
+ return outputs
374
+
375
+ def forward(self, x: Tensor) -> Tensor:
376
+ outputs = self._forward(x)
377
+ return torch.cat(outputs, 1)
378
+
379
+
380
+ class InceptionE(nn.Module):
381
+
382
+ def __init__(
383
+ self,
384
+ in_channels: int,
385
+ conv_block: Optional[Callable[..., nn.Module]] = None
386
+ ) -> None:
387
+ super(InceptionE, self).__init__()
388
+ if conv_block is None:
389
+ conv_block = BasicConv2d
390
+ self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
391
+
392
+ self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
393
+ self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
394
+ self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
395
+
396
+ self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
397
+ self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
398
+ self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
399
+ self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
400
+
401
+ self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
402
+
403
+ def _forward(self, x: Tensor) -> List[Tensor]:
404
+ branch1x1 = self.branch1x1(x)
405
+
406
+ branch3x3 = self.branch3x3_1(x)
407
+ branch3x3 = [
408
+ self.branch3x3_2a(branch3x3),
409
+ self.branch3x3_2b(branch3x3),
410
+ ]
411
+ branch3x3 = torch.cat(branch3x3, 1)
412
+
413
+ branch3x3dbl = self.branch3x3dbl_1(x)
414
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
415
+ branch3x3dbl = [
416
+ self.branch3x3dbl_3a(branch3x3dbl),
417
+ self.branch3x3dbl_3b(branch3x3dbl),
418
+ ]
419
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
420
+
421
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
422
+ branch_pool = self.branch_pool(branch_pool)
423
+
424
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
425
+ return outputs
426
+
427
+ def forward(self, x: Tensor) -> Tensor:
428
+ outputs = self._forward(x)
429
+ return torch.cat(outputs, 1)
430
+
431
+
432
+ class InceptionAux(nn.Module):
433
+
434
+ def __init__(
435
+ self,
436
+ in_channels: int,
437
+ num_classes: int,
438
+ conv_block: Optional[Callable[..., nn.Module]] = None
439
+ ) -> None:
440
+ super(InceptionAux, self).__init__()
441
+ if conv_block is None:
442
+ conv_block = BasicConv2d
443
+ self.conv0 = conv_block(in_channels, 128, kernel_size=1)
444
+ self.conv1 = conv_block(128, 768, kernel_size=5)
445
+ self.conv1.stddev = 0.01 # type: ignore[assignment]
446
+ self.fc = nn.Linear(768, num_classes)
447
+ self.fc.stddev = 0.001 # type: ignore[assignment]
448
+
449
+ def forward(self, x: Tensor) -> Tensor:
450
+ # N x 768 x 17 x 17
451
+ x = F.avg_pool2d(x, kernel_size=5, stride=3)
452
+ # N x 768 x 5 x 5
453
+ x = self.conv0(x)
454
+ # N x 128 x 5 x 5
455
+ x = self.conv1(x)
456
+ # N x 768 x 1 x 1
457
+ # Adaptive average pooling
458
+ x = F.adaptive_avg_pool2d(x, (1, 1))
459
+ # N x 768 x 1 x 1
460
+ x = torch.flatten(x, 1)
461
+ # N x 768
462
+ x = self.fc(x)
463
+ # N x 1000
464
+ return x
465
+
466
+
467
+ class BasicConv2d(nn.Module):
468
+
469
+ def __init__(
470
+ self,
471
+ in_channels: int,
472
+ out_channels: int,
473
+ **kwargs: Any
474
+ ) -> None:
475
+ super(BasicConv2d, self).__init__()
476
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
477
+ self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
478
+
479
+ def forward(self, x: Tensor) -> Tensor:
480
+ x = self.conv(x)
481
+ x = self.bn(x)
482
+ return F.relu(x, inplace=True)
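Unlike stock torchvision, `Inception3.forward` here returns a `(feat, outputs)` pair: `feat` is the early feature map taken after `maxpool1`, and `outputs` are the usual logits (wrapped in `InceptionOutputs` only when training with `aux_logits=True`). A shape-check sketch with random weights (no download):

```python
import torch
from libs.modules.vision.inception import inception_v3

model = inception_v3(pretrained=False, aux_logits=False, init_weights=False).eval()
x = torch.randn(1, 3, 299, 299)  # inception_v3 expects N x 3 x 299 x 299
with torch.no_grad():
    feat, logits = model(x)
print(feat.shape, logits.shape)  # feat: 1 x 64 x 73 x 73, logits: 1 x 1000
```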
libs/modules/vision/vgg.py ADDED
@@ -0,0 +1,194 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+ from typing import Union, List, Dict, Any, cast
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
11
+
12
+ __all__ = [
13
+ 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
14
+ 'vgg19_bn', 'vgg19',
15
+ ]
16
+
17
+ model_urls = {
18
+ 'vgg11': 'https://download.pytorch.org/models/vgg11-8a719046.pth',
19
+ 'vgg13': 'https://download.pytorch.org/models/vgg13-19584684.pth',
20
+ 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
21
+ 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
22
+ 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
23
+ 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
24
+ 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
25
+ 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
26
+ }
27
+
28
+
29
+ class VGG(nn.Module):
30
+
31
+ def __init__(
32
+ self,
33
+ features: nn.Module,
34
+ num_classes: int = 1000,
35
+ init_weights: bool = True
36
+ ) -> None:
37
+ super(VGG, self).__init__()
38
+ self.features = features
39
+ self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
40
+ self.classifier = nn.Sequential(
41
+ nn.Linear(512 * 7 * 7, 4096),
42
+ nn.ReLU(True),
43
+ nn.Dropout(),
44
+ nn.Linear(4096, 4096),
45
+ nn.ReLU(True),
46
+ nn.Dropout(),
47
+ nn.Linear(4096, num_classes),
48
+ )
49
+ if init_weights:
50
+ self._initialize_weights()
51
+
52
+ def forward(self, x: torch.Tensor):
53
+ feat = self.features(x)
54
+ x = self.avgpool(feat)
55
+ x = torch.flatten(x, 1)
56
+ x = self.classifier(x)
57
+ return feat, x
58
+
59
+ def _initialize_weights(self) -> None:
60
+ for m in self.modules():
61
+ if isinstance(m, nn.Conv2d):
62
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
63
+ if m.bias is not None:
64
+ nn.init.constant_(m.bias, 0)
65
+ elif isinstance(m, nn.BatchNorm2d):
66
+ nn.init.constant_(m.weight, 1)
67
+ nn.init.constant_(m.bias, 0)
68
+ elif isinstance(m, nn.Linear):
69
+ nn.init.normal_(m.weight, 0, 0.01)
70
+ nn.init.constant_(m.bias, 0)
71
+
72
+
73
+ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
74
+ layers: List[nn.Module] = []
75
+ in_channels = 3
76
+ for v in cfg:
77
+ if v == 'M':
78
+ layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
79
+ else:
80
+ v = cast(int, v)
81
+ conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
82
+ if batch_norm:
83
+ layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
84
+ else:
85
+ layers += [conv2d, nn.ReLU(inplace=True)]
86
+ in_channels = v
87
+ return nn.Sequential(*layers)
88
+
89
+
90
+ cfgs: Dict[str, List[Union[str, int]]] = {
91
+ 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
92
+ 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
93
+ 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
94
+ 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
95
+ }
96
+
97
+
98
+ def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
99
+ if pretrained:
100
+ kwargs['init_weights'] = False
101
+ model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
102
+ if pretrained:
103
+ state_dict = load_state_dict_from_url(model_urls[arch],
104
+ progress=progress)
105
+ model.load_state_dict(state_dict)
106
+ return model
107
+
108
+
109
+ def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
110
+ r"""VGG 11-layer model (configuration "A") from
111
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
112
+
113
+ Args:
114
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
115
+ progress (bool): If True, displays a progress bar of the download to stderr
116
+ """
117
+ return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
118
+
119
+
120
+ def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
121
+ r"""VGG 11-layer model (configuration "A") with batch normalization
122
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
123
+
124
+ Args:
125
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
126
+ progress (bool): If True, displays a progress bar of the download to stderr
127
+ """
128
+ return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
129
+
130
+
131
+ def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
132
+ r"""VGG 13-layer model (configuration "B")
133
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
134
+
135
+ Args:
136
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
137
+ progress (bool): If True, displays a progress bar of the download to stderr
138
+ """
139
+ return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
140
+
141
+
142
+ def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
143
+ r"""VGG 13-layer model (configuration "B") with batch normalization
144
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
145
+
146
+ Args:
147
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
148
+ progress (bool): If True, displays a progress bar of the download to stderr
149
+ """
150
+ return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
151
+
152
+
153
+ def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
154
+ r"""VGG 16-layer model (configuration "D")
155
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
156
+
157
+ Args:
158
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
159
+ progress (bool): If True, displays a progress bar of the download to stderr
160
+ """
161
+ return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
162
+
163
+
164
+ def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
165
+ r"""VGG 16-layer model (configuration "D") with batch normalization
166
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
167
+
168
+ Args:
169
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
170
+ progress (bool): If True, displays a progress bar of the download to stderr
171
+ """
172
+ return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
173
+
174
+
175
+ def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
176
+ r"""VGG 19-layer model (configuration "E")
177
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
178
+
179
+ Args:
180
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
181
+ progress (bool): If True, displays a progress bar of the download to stderr
182
+ """
183
+ return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
184
+
185
+
186
+ def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
187
+ r"""VGG 19-layer model (configuration 'E') with batch normalization
188
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
189
+
190
+ Args:
191
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
192
+ progress (bool): If True, displays a progress bar of the download to stderr
193
+ """
194
+ return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
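`VGG.forward` is modified the same way and returns `(feat, logits)`, where `feat` is the last convolutional feature map before the adaptive pooling. A short sketch:

```python
import torch
from libs.modules.vision.vgg import vgg16

model = vgg16(pretrained=False).eval()  # pretrained=True would pull the torchvision weights
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feat, logits = model(x)
print(feat.shape, logits.shape)  # feat: 1 x 512 x 7 x 7, logits: 1 x 1000
```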
libs/modules/visual/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
libs/modules/visual/imshow.py ADDED
@@ -0,0 +1,177 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+ import pathlib
7
+ from typing import Union, List, Text, BinaryIO, AnyStr
8
+
9
+ import matplotlib.pyplot as plt
10
+ import torch
11
+ import torchvision.transforms as transforms
12
+ from torchvision.utils import make_grid
13
+
14
+ __all__ = [
15
+ 'sample2pil_transforms',
16
+ 'pt2numpy_transforms',
17
+ 'plt_pt_img',
18
+ 'save_grid_images_and_labels',
19
+ 'save_grid_images_and_captions',
20
+ ]
21
+
22
+ # convert a generated sample (in [-1, 1]) to a PIL image
23
+ sample2pil_transforms = transforms.Compose([
24
+ # unnormalizing to [0,1]
25
+ transforms.Lambda(lambda t: torch.clamp((t + 1) / 2, min=0.0, max=1.0)),
26
+ # Add 0.5 after unnormalizing to [0, 255]
27
+ transforms.Lambda(lambda t: torch.clamp(t * 255. + 0.5, min=0, max=255)),
28
+ # CHW to HWC
29
+ transforms.Lambda(lambda t: t.permute(1, 2, 0)),
30
+ # to numpy ndarray, dtype int8
31
+ transforms.Lambda(lambda t: t.to('cpu', torch.uint8).numpy()),
32
+ # Converts a numpy ndarray of shape H x W x C to a PIL Image
33
+ transforms.ToPILImage(),
34
+ ])
35
+
36
+ # convert a tensor in [0, 1] to a numpy ndarray (HWC, uint8)
37
+ pt2numpy_transforms = transforms.Compose([
38
+ # Add 0.5 after unnormalizing to [0, 255]
39
+ transforms.Lambda(lambda t: torch.clamp(t * 255. + 0.5, min=0, max=255)),
40
+ # CHW to HWC
41
+ transforms.Lambda(lambda t: t.permute(1, 2, 0)),
42
+ # to numpy ndarray, dtype int8
43
+ transforms.Lambda(lambda t: t.to('cpu', torch.uint8).numpy()),
44
+ ])
45
+
46
+
47
+ def plt_pt_img(
48
+ pt_img: torch.Tensor,
49
+ save_path: AnyStr = None,
50
+ title: AnyStr = None,
51
+ dpi: int = 300
52
+ ):
53
+ grid = make_grid(pt_img, normalize=True, pad_value=2)
54
+ ndarr = pt2numpy_transforms(grid)
55
+ plt.imshow(ndarr)
56
+ plt.axis("off")
57
+ plt.tight_layout()
58
+ if title is not None:
59
+ plt.title(f"{title}")
60
+
61
+ plt.show()
62
+ if save_path is not None:
63
+ plt.savefig(save_path, dpi=dpi)
64
+
65
+ plt.close()
66
+
67
+
68
+ @torch.no_grad()
69
+ def save_grid_images_and_labels(
70
+ images: Union[torch.Tensor, List[torch.Tensor]],
71
+ probs: Union[torch.Tensor, List[torch.Tensor]],
72
+ labels: Union[torch.Tensor, List[torch.Tensor]],
73
+ classes: Union[torch.Tensor, List[torch.Tensor]],
74
+ fp: Union[Text, pathlib.Path, BinaryIO],
75
+ nrow: int = 4,
76
+ normalize: bool = True
77
+ ) -> None:
78
+ """Save a given Tensor into an image file.
79
+ """
80
+ num_images = len(images)
81
+ num_rows, num_cols = _get_subplot_shape(num_images, nrow)
82
+
83
+ fig = plt.figure(figsize=(25, 20))
84
+
85
+ for i in range(num_images):
86
+ ax = fig.add_subplot(num_rows, num_cols, i + 1)
87
+
88
+ image, true_label, prob = images[i], labels[i], probs[i]
89
+
90
+ true_prob = prob[true_label]
91
+ incorrect_prob, incorrect_label = torch.max(prob, dim=0)
92
+ true_class = classes[true_label]
93
+
94
+ incorrect_class = classes[incorrect_label]
95
+
96
+ if normalize:
97
+ image = sample2pil_transforms(image)
98
+
99
+ ax.imshow(image)
100
+ title = f'true label: {true_class} ({true_prob:.3f})\n ' \
101
+ f'pred label: {incorrect_class} ({incorrect_prob:.3f})'
102
+ ax.set_title(title, fontsize=20)
103
+ ax.axis('off')
104
+
105
+ fig.subplots_adjust(hspace=0.3)
106
+
107
+ plt.savefig(fp)
108
+ plt.close()
109
+
110
+
111
+ @torch.no_grad()
112
+ def save_grid_images_and_captions(
113
+ images: Union[torch.Tensor, List[torch.Tensor]],
114
+ captions: List,
115
+ fp: Union[Text, pathlib.Path, BinaryIO],
116
+ nrow: int = 4,
117
+ normalize: bool = True
118
+ ) -> None:
119
+ """
120
+ Save a grid of images and their captions into an image file.
121
+
122
+ Args:
123
+ images (Union[torch.Tensor, List[torch.Tensor]]): A list of images to display.
124
+ captions (List): A list of captions for each image.
125
+ fp (Union[Text, pathlib.Path, BinaryIO]): The file path to save the image to.
126
+ nrow (int, optional): The number of images to display in each row. Defaults to 4.
127
+ normalize (bool, optional): Whether to normalize the image or not. Defaults to True.
128
+ """
129
+ num_images = len(images)
130
+ num_rows, num_cols = _get_subplot_shape(num_images, nrow)
131
+
132
+ fig = plt.figure(figsize=(25, 20))
133
+
134
+ for i in range(num_images):
135
+ ax = fig.add_subplot(num_rows, num_cols, i + 1)
136
+ image, caption = images[i], captions[i]
137
+
138
+ if normalize:
139
+ image = sample2pil_transforms(image)
140
+
141
+ ax.imshow(image)
142
+ title = f'"{caption}"' if num_images > 1 else f'"{captions}"'
143
+ title = _insert_newline(title)
144
+ ax.set_title(title, fontsize=20)
145
+ ax.axis('off')
146
+
147
+ fig.subplots_adjust(hspace=0.3)
148
+
149
+ plt.savefig(fp)
150
+ plt.close()
151
+
152
+
153
+ def _get_subplot_shape(num_images, nrow):
154
+ """
155
+ Calculate the number of rows and columns required to display images in a grid.
156
+
157
+ Args:
158
+ num_images (int): The total number of images to display.
159
+ nrow (int): The maximum number of images to display in each row.
160
+
161
+ Returns:
162
+ Tuple[int, int]: The number of rows and columns required to display images in a grid.
163
+ """
164
+ num_cols = min(num_images, nrow)
165
+ num_rows = (num_images + num_cols - 1) // num_cols
166
+ return num_rows, num_cols
167
+
168
+
169
+ def _insert_newline(string, point=9):
170
+ # split by blank
171
+ words = string.split()
172
+ if len(words) <= point:
173
+ return string
174
+
175
+ word_chunks = [words[i:i + point] for i in range(0, len(words), point)]
176
+ new_string = "\n".join(" ".join(chunk) for chunk in word_chunks)
177
+ return new_string
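The helpers above take a list of image tensors plus captions or labels and write a matplotlib grid to disk. A minimal sketch of `save_grid_images_and_captions` with random tensors (the values and output path are illustrative):

```python
import torch
from libs.modules.visual.imshow import save_grid_images_and_captions

# fake batch of 4 images in the [-1, 1] range expected by sample2pil_transforms
images = [torch.rand(3, 64, 64) * 2 - 1 for _ in range(4)]
captions = ["a cat", "a dog", "a house", "a tree"]
save_grid_images_and_captions(images, captions, fp="caption_grid.png", nrow=2, normalize=True)
```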
libs/modules/visual/video.py ADDED
@@ -0,0 +1,38 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+ from typing import Union
6
+ import pathlib
7
+
8
+ import cv2
9
+
10
+
11
+ def create_video(num_iter: int,
12
+ save_dir: Union[str, pathlib.Path],
13
+ video_frame_freq: int = 1,
14
+ fname: str = "rendering_process",
15
+ verbose: bool = True):
16
+ if not isinstance(save_dir, pathlib.Path):
17
+ save_dir = pathlib.Path(save_dir)
18
+
19
+ img_array = []
20
+ for i in range(0, num_iter):
21
+ if i % video_frame_freq == 0 or i == num_iter - 1:
22
+ filename = save_dir / f"iter{i}.png"
23
+ img = cv2.imread(filename.as_posix())
24
+ img_array.append(img)
25
+
26
+ video_name = save_dir / f"{fname}.mp4"
27
+ out = cv2.VideoWriter(
28
+ video_name.as_posix(),
29
+ cv2.VideoWriter_fourcc(*'mp4v'),
30
+ 30.0, # fps
31
+ (600, 600) # video size
32
+ )
33
+ for iii in range(len(img_array)):
34
+ out.write(img_array[iii])
35
+ out.release()
36
+
37
+ if verbose:
38
+ print(f"video saved in '{video_name}'.")
libs/solver/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
libs/solver/lr_scheduler.py ADDED
@@ -0,0 +1,350 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch optimization for diffusion models."""
16
+
17
+ import math
18
+ from enum import Enum
19
+ from typing import Optional, Union
20
+
21
+ from torch.optim import Optimizer
22
+ from torch.optim.lr_scheduler import LambdaLR
23
+
24
+
25
+ class SchedulerType(Enum):
26
+ LINEAR = "linear"
27
+ COSINE = "cosine"
28
+ COSINE_WITH_RESTARTS = "cosine_with_restarts"
29
+ POLYNOMIAL = "polynomial"
30
+ CONSTANT = "constant"
31
+ CONSTANT_WITH_WARMUP = "constant_with_warmup"
32
+ PIECEWISE_CONSTANT = "piecewise_constant"
33
+
34
+
35
+ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
36
+ """
37
+ Create a schedule with a constant learning rate, using the learning rate set in optimizer.
38
+
39
+ Args:
40
+ optimizer ([`~torch.optim.Optimizer`]):
41
+ The optimizer for which to schedule the learning rate.
42
+ last_epoch (`int`, *optional*, defaults to -1):
43
+ The index of the last epoch when resuming training.
44
+
45
+ Return:
46
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
47
+ """
48
+ return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
49
+
50
+
51
+ def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
52
+ """
53
+ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
54
+ increases linearly between 0 and the initial lr set in the optimizer.
55
+
56
+ Args:
57
+ optimizer ([`~torch.optim.Optimizer`]):
58
+ The optimizer for which to schedule the learning rate.
59
+ num_warmup_steps (`int`):
60
+ The number of steps for the warmup phase.
61
+ last_epoch (`int`, *optional*, defaults to -1):
62
+ The index of the last epoch when resuming training.
63
+
64
+ Return:
65
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
66
+ """
67
+
68
+ def lr_lambda(current_step: int):
69
+ if current_step < num_warmup_steps:
70
+ return float(current_step) / float(max(1.0, num_warmup_steps))
71
+ return 1.0
72
+
73
+ return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
74
+
75
+
76
+ def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
77
+ """
78
+ Create a schedule with a constant learning rate, using the learning rate set in optimizer.
79
+
80
+ Args:
81
+ optimizer ([`~torch.optim.Optimizer`]):
82
+ The optimizer for which to schedule the learning rate.
83
+ step_rules (`string`):
84
+ The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate
85
+ if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30
86
+ steps and multiple 0.005 for the other steps.
87
+ last_epoch (`int`, *optional*, defaults to -1):
88
+ The index of the last epoch when resuming training.
89
+
90
+ Return:
91
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
92
+ """
93
+
94
+ rules_dict = {}
95
+ rule_list = step_rules.split(",")
96
+ for rule_str in rule_list[:-1]:
97
+ value_str, steps_str = rule_str.split(":")
98
+ steps = int(steps_str)
99
+ value = float(value_str)
100
+ rules_dict[steps] = value
101
+ last_lr_multiple = float(rule_list[-1])
102
+
103
+ def create_rules_function(rules_dict, last_lr_multiple):
104
+ def rule_func(steps: int) -> float:
105
+ sorted_steps = sorted(rules_dict.keys())
106
+ for i, sorted_step in enumerate(sorted_steps):
107
+ if steps < sorted_step:
108
+ return rules_dict[sorted_steps[i]]
109
+ return last_lr_multiple
110
+
111
+ return rule_func
112
+
113
+ rules_func = create_rules_function(rules_dict, last_lr_multiple)
114
+
115
+ return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
116
+
117
+
118
+ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
119
+ """
120
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
121
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
122
+
123
+ Args:
124
+ optimizer ([`~torch.optim.Optimizer`]):
125
+ The optimizer for which to schedule the learning rate.
126
+ num_warmup_steps (`int`):
127
+ The number of steps for the warmup phase.
128
+ num_training_steps (`int`):
129
+ The total number of training steps.
130
+ last_epoch (`int`, *optional*, defaults to -1):
131
+ The index of the last epoch when resuming training.
132
+
133
+ Return:
134
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
135
+ """
136
+
137
+ def lr_lambda(current_step: int):
138
+ if current_step < num_warmup_steps:
139
+ return float(current_step) / float(max(1, num_warmup_steps))
140
+ return max(
141
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
142
+ )
143
+
144
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
145
+
146
+
147
+ def get_cosine_schedule_with_warmup(
148
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5,
149
+ last_epoch: int = -1
150
+ ):
151
+ """
152
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
153
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
154
+ initial lr set in the optimizer.
155
+
156
+ Args:
157
+ optimizer ([`~torch.optim.Optimizer`]):
158
+ The optimizer for which to schedule the learning rate.
159
+ num_warmup_steps (`int`):
160
+ The number of steps for the warmup phase.
161
+ num_training_steps (`int`):
162
+ The total number of training steps.
163
+ num_periods (`float`, *optional*, defaults to 0.5):
164
+ The number of periods of the cosine function in a schedule (the default is to just decrease from the max
165
+ value to 0 following a half-cosine).
166
+ last_epoch (`int`, *optional*, defaults to -1):
167
+ The index of the last epoch when resuming training.
168
+
169
+ Return:
170
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
171
+ """
172
+
173
+ def lr_lambda(current_step):
174
+ if current_step < num_warmup_steps:
175
+ return float(current_step) / float(max(1, num_warmup_steps))
176
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
177
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
178
+
179
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
180
+
181
+
182
+ def get_cosine_with_hard_restarts_schedule_with_warmup(
183
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
184
+ ):
185
+ """
186
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
187
+ initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
188
+ linearly between 0 and the initial lr set in the optimizer.
189
+
190
+ Args:
191
+ optimizer ([`~torch.optim.Optimizer`]):
192
+ The optimizer for which to schedule the learning rate.
193
+ num_warmup_steps (`int`):
194
+ The number of steps for the warmup phase.
195
+ num_training_steps (`int`):
196
+ The total number of training steps.
197
+ num_cycles (`int`, *optional*, defaults to 1):
198
+ The number of hard restarts to use.
199
+ last_epoch (`int`, *optional*, defaults to -1):
200
+ The index of the last epoch when resuming training.
201
+
202
+ Return:
203
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
204
+ """
205
+
206
+ def lr_lambda(current_step):
207
+ if current_step < num_warmup_steps:
208
+ return float(current_step) / float(max(1, num_warmup_steps))
209
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
210
+ if progress >= 1.0:
211
+ return 0.0
212
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
213
+
214
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
215
+
216
+
217
+ def get_polynomial_decay_schedule_with_warmup(
218
+ optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
219
+ ):
220
+ """
221
+ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
222
+ optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
223
+ initial lr set in the optimizer.
224
+
225
+ Args:
226
+ optimizer ([`~torch.optim.Optimizer`]):
227
+ The optimizer for which to schedule the learning rate.
228
+ num_warmup_steps (`int`):
229
+ The number of steps for the warmup phase.
230
+ num_training_steps (`int`):
231
+ The total number of training steps.
232
+ lr_end (`float`, *optional*, defaults to 1e-7):
233
+ The end LR.
234
+ power (`float`, *optional*, defaults to 1.0):
235
+ Power factor.
236
+ last_epoch (`int`, *optional*, defaults to -1):
237
+ The index of the last epoch when resuming training.
238
+
239
+ Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
240
+ implementation at
241
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
242
+
243
+ Return:
244
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
245
+
246
+ """
247
+
248
+ lr_init = optimizer.defaults["lr"]
249
+ if not (lr_init > lr_end):
250
+ raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
251
+
252
+ def lr_lambda(current_step: int):
253
+ if current_step < num_warmup_steps:
254
+ return float(current_step) / float(max(1, num_warmup_steps))
255
+ elif current_step > num_training_steps:
256
+ return lr_end / lr_init # as LambdaLR multiplies by lr_init
257
+ else:
258
+ lr_range = lr_init - lr_end
259
+ decay_steps = num_training_steps - num_warmup_steps
260
+ pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
261
+ decay = lr_range * pct_remaining ** power + lr_end
262
+ return decay / lr_init # as LambdaLR multiplies by lr_init
263
+
264
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
265
+
266
+
267
+ TYPE_TO_SCHEDULER_FUNCTION = {
268
+ SchedulerType.LINEAR: get_linear_schedule_with_warmup,
269
+ SchedulerType.COSINE: get_cosine_schedule_with_warmup,
270
+ SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
271
+ SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
272
+ SchedulerType.CONSTANT: get_constant_schedule,
273
+ SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
274
+ SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
275
+ }
276
+
277
+
278
+ def get_scheduler(
279
+ name: Union[str, SchedulerType],
280
+ optimizer: Optimizer,
281
+ step_rules: Optional[str] = None,
282
+ num_warmup_steps: Optional[int] = None,
283
+ num_training_steps: Optional[int] = None,
284
+ num_cycles: int = 1,
285
+ power: float = 1.0,
286
+ last_epoch: int = -1,
287
+ ):
288
+ """
289
+ Unified API to get any scheduler from its name.
290
+
291
+ Args:
292
+ name (`str` or `SchedulerType`):
293
+ The name of the scheduler to use.
294
+ optimizer (`torch.optim.Optimizer`):
295
+ The optimizer that will be used during training.
296
+ step_rules (`str`, *optional*):
297
+ A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
298
+ num_warmup_steps (`int`, *optional*):
299
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
300
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
301
+ num_training_steps (`int``, *optional*):
302
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
303
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
304
+ num_cycles (`int`, *optional*):
305
+ The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
306
+ power (`float`, *optional*, defaults to 1.0):
307
+ Power factor. See `POLYNOMIAL` scheduler
308
+ last_epoch (`int`, *optional*, defaults to -1):
309
+ The index of the last epoch when resuming training.
310
+ """
311
+ name = SchedulerType(name)
312
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
313
+ if name == SchedulerType.CONSTANT:
314
+ return schedule_func(optimizer, last_epoch=last_epoch)
315
+
316
+ if name == SchedulerType.PIECEWISE_CONSTANT:
317
+ return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)
318
+
319
+ # All other schedulers require `num_warmup_steps`
320
+ if num_warmup_steps is None:
321
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
322
+
323
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
324
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)
325
+
326
+ # All other schedulers require `num_training_steps`
327
+ if num_training_steps is None:
328
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
329
+
330
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
331
+ return schedule_func(
332
+ optimizer,
333
+ num_warmup_steps=num_warmup_steps,
334
+ num_training_steps=num_training_steps,
335
+ num_cycles=num_cycles,
336
+ last_epoch=last_epoch,
337
+ )
338
+
339
+ if name == SchedulerType.POLYNOMIAL:
340
+ return schedule_func(
341
+ optimizer,
342
+ num_warmup_steps=num_warmup_steps,
343
+ num_training_steps=num_training_steps,
344
+ power=power,
345
+ last_epoch=last_epoch,
346
+ )
347
+
348
+ return schedule_func(
349
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch
350
+ )
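`get_scheduler` dispatches by name to the warmup schedulers above and returns a `LambdaLR`. A typical warmup-plus-cosine-decay setup (toy model, illustrative step counts):

```python
import torch
from libs.solver.lr_scheduler import get_scheduler

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=10_000,
)

for step in range(10_000):
    # ... forward / backward here ...
    optimizer.step()
    lr_scheduler.step()  # LambdaLR multiplies the base lr by the scheduled factor
```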