project upload
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- DeOldify/ImageColorizerArtisticTests.ipynb +3 -0
- DeOldify/deoldify/__init__.py +8 -0
- DeOldify/deoldify/_device.py +30 -0
- DeOldify/deoldify/augs.py +29 -0
- DeOldify/deoldify/critics.py +47 -0
- DeOldify/deoldify/dataset.py +46 -0
- DeOldify/deoldify/device_id.py +12 -0
- DeOldify/deoldify/filters.py +124 -0
- DeOldify/deoldify/generators.py +155 -0
- DeOldify/deoldify/layers.py +46 -0
- DeOldify/deoldify/loss.py +136 -0
- DeOldify/deoldify/save.py +29 -0
- DeOldify/deoldify/unet.py +285 -0
- DeOldify/deoldify/visualize.py +487 -0
- DeOldify/environment.yml +23 -0
- DeOldify/fastai/LICENSE +60 -0
- DeOldify/fastai/__init__.py +2 -0
- DeOldify/fastai/basic_data.py +279 -0
- DeOldify/fastai/basic_train.py +629 -0
- DeOldify/fastai/basics.py +29 -0
- DeOldify/fastai/callback.py +396 -0
- DeOldify/fastai/callbacks/__init__.py +11 -0
- DeOldify/fastai/callbacks/csv_logger.py +43 -0
- DeOldify/fastai/callbacks/fp16.py +122 -0
- DeOldify/fastai/callbacks/general_sched.py +46 -0
- DeOldify/fastai/callbacks/hooks.py +200 -0
- DeOldify/fastai/callbacks/loss_metrics.py +34 -0
- DeOldify/fastai/callbacks/lr_finder.py +39 -0
- DeOldify/fastai/callbacks/mem.py +65 -0
- DeOldify/fastai/callbacks/misc.py +12 -0
- DeOldify/fastai/callbacks/mixup.py +67 -0
- DeOldify/fastai/callbacks/mlflow.py +36 -0
- DeOldify/fastai/callbacks/one_cycle.py +58 -0
- DeOldify/fastai/callbacks/oversampling.py +22 -0
- DeOldify/fastai/callbacks/rnn.py +31 -0
- DeOldify/fastai/callbacks/tensorboard.py +427 -0
- DeOldify/fastai/callbacks/tracker.py +153 -0
- DeOldify/fastai/collab.py +106 -0
- DeOldify/fastai/core.py +535 -0
- DeOldify/fastai/data_block.py +804 -0
- DeOldify/fastai/datasets.py +234 -0
- DeOldify/fastai/distributed.py +119 -0
- DeOldify/fastai/gen_doc/__init__.py +1 -0
- DeOldify/fastai/gen_doc/autogen.tpl +10 -0
- DeOldify/fastai/gen_doc/convert2html.py +51 -0
- DeOldify/fastai/gen_doc/core.py +5 -0
- DeOldify/fastai/gen_doc/docstrings.py +142 -0
- DeOldify/fastai/gen_doc/doctest.py +133 -0
- DeOldify/fastai/gen_doc/gen_notebooks.py +350 -0
.gitattributes
CHANGED
@@ -34,3 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
34 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
35 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
36 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
34 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
35 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
36 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
39 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
DeOldify/ImageColorizerArtisticTests.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:48a4d84968ac49079ccddd3834d35df0b3e0e544fc687feb1d7eda4539cd5fb4
|
3 |
+
size 508185940
|
DeOldify/deoldify/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import logging
|
3 |
+
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
|
4 |
+
logging.getLogger().setLevel(logging.INFO)
|
5 |
+
|
6 |
+
from deoldify._device import _Device
|
7 |
+
|
8 |
+
device = _Device()
|
DeOldify/deoldify/_device.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from enum import Enum
|
3 |
+
from .device_id import DeviceId
|
4 |
+
|
5 |
+
#NOTE: This must be called first before any torch imports in order to work properly!
|
6 |
+
|
7 |
+
class DeviceException(Exception):
|
8 |
+
pass
|
9 |
+
|
10 |
+
class _Device:
|
11 |
+
def __init__(self):
|
12 |
+
self.set(DeviceId.CPU)
|
13 |
+
|
14 |
+
def is_gpu(self):
|
15 |
+
''' Returns `True` if the current device is GPU, `False` otherwise. '''
|
16 |
+
return self.current() is not DeviceId.CPU
|
17 |
+
|
18 |
+
def current(self):
|
19 |
+
return self._current_device
|
20 |
+
|
21 |
+
def set(self, device:DeviceId):
|
22 |
+
if device == DeviceId.CPU:
|
23 |
+
os.environ['CUDA_VISIBLE_DEVICES']=''
|
24 |
+
else:
|
25 |
+
os.environ['CUDA_VISIBLE_DEVICES']=str(device.value)
|
26 |
+
import torch
|
27 |
+
torch.backends.cudnn.benchmark=False
|
28 |
+
|
29 |
+
self._current_device = device
|
30 |
+
return device
|
DeOldify/deoldify/augs.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
from fastai.vision.image import TfmPixel
|
4 |
+
|
5 |
+
# Contributed by Rani Horev. Thank you!
|
6 |
+
def _noisify(
|
7 |
+
x, pct_pixels_min: float = 0.001, pct_pixels_max: float = 0.4, noise_range: int = 30
|
8 |
+
):
|
9 |
+
if noise_range > 255 or noise_range < 0:
|
10 |
+
raise Exception("noise_range must be between 0 and 255, inclusively.")
|
11 |
+
|
12 |
+
h, w = x.shape[1:]
|
13 |
+
img_size = h * w
|
14 |
+
mult = 10000.0
|
15 |
+
pct_pixels = (
|
16 |
+
random.randrange(int(pct_pixels_min * mult), int(pct_pixels_max * mult)) / mult
|
17 |
+
)
|
18 |
+
noise_count = int(img_size * pct_pixels)
|
19 |
+
|
20 |
+
for ii in range(noise_count):
|
21 |
+
yy = random.randrange(h)
|
22 |
+
xx = random.randrange(w)
|
23 |
+
noise = random.randrange(-noise_range, noise_range) / 255.0
|
24 |
+
x[:, yy, xx].add_(noise)
|
25 |
+
|
26 |
+
return x
|
27 |
+
|
28 |
+
|
29 |
+
noisify = TfmPixel(_noisify)
|
DeOldify/deoldify/critics.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai.basic_train import Learner
|
2 |
+
from fastai.core import *
|
3 |
+
from fastai.layers import NormType, conv_layer
|
4 |
+
from fastai.torch_core import *
|
5 |
+
from fastai.vision import *
|
6 |
+
from fastai.vision.data import ImageDataBunch
|
7 |
+
from fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand
|
8 |
+
|
9 |
+
_conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)
|
10 |
+
|
11 |
+
|
12 |
+
def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
|
13 |
+
return conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
|
14 |
+
|
15 |
+
|
16 |
+
def custom_gan_critic(
|
17 |
+
n_channels: int = 3, nf: int = 256, n_blocks: int = 3, p: int = 0.15
|
18 |
+
):
|
19 |
+
"Critic to train a `GAN`."
|
20 |
+
layers = [_conv(n_channels, nf, ks=4, stride=2), nn.Dropout2d(p / 2)]
|
21 |
+
for i in range(n_blocks):
|
22 |
+
layers += [
|
23 |
+
_conv(nf, nf, ks=3, stride=1),
|
24 |
+
nn.Dropout2d(p),
|
25 |
+
_conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
|
26 |
+
]
|
27 |
+
nf *= 2
|
28 |
+
layers += [
|
29 |
+
_conv(nf, nf, ks=3, stride=1),
|
30 |
+
_conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
|
31 |
+
Flatten(),
|
32 |
+
]
|
33 |
+
return nn.Sequential(*layers)
|
34 |
+
|
35 |
+
|
36 |
+
def colorize_crit_learner(
|
37 |
+
data: ImageDataBunch,
|
38 |
+
loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()),
|
39 |
+
nf: int = 256,
|
40 |
+
) -> Learner:
|
41 |
+
return Learner(
|
42 |
+
data,
|
43 |
+
custom_gan_critic(nf=nf),
|
44 |
+
metrics=accuracy_thresh_expand,
|
45 |
+
loss_func=loss_critic,
|
46 |
+
wd=1e-3,
|
47 |
+
)
|
DeOldify/deoldify/dataset.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai import *
|
2 |
+
from fastai.core import *
|
3 |
+
from fastai.vision.transform import get_transforms
|
4 |
+
from fastai.vision.data import ImageImageList, ImageDataBunch, imagenet_stats
|
5 |
+
|
6 |
+
|
7 |
+
def get_colorize_data(
|
8 |
+
sz: int,
|
9 |
+
bs: int,
|
10 |
+
crappy_path: Path,
|
11 |
+
good_path: Path,
|
12 |
+
random_seed: int = None,
|
13 |
+
keep_pct: float = 1.0,
|
14 |
+
num_workers: int = 8,
|
15 |
+
stats: tuple = imagenet_stats,
|
16 |
+
xtra_tfms=[],
|
17 |
+
) -> ImageDataBunch:
|
18 |
+
|
19 |
+
src = (
|
20 |
+
ImageImageList.from_folder(crappy_path, convert_mode='RGB')
|
21 |
+
.use_partial_data(sample_pct=keep_pct, seed=random_seed)
|
22 |
+
.split_by_rand_pct(0.1, seed=random_seed)
|
23 |
+
)
|
24 |
+
|
25 |
+
data = (
|
26 |
+
src.label_from_func(lambda x: good_path / x.relative_to(crappy_path))
|
27 |
+
.transform(
|
28 |
+
get_transforms(
|
29 |
+
max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms
|
30 |
+
),
|
31 |
+
size=sz,
|
32 |
+
tfm_y=True,
|
33 |
+
)
|
34 |
+
.databunch(bs=bs, num_workers=num_workers, no_check=True)
|
35 |
+
.normalize(stats, do_y=True)
|
36 |
+
)
|
37 |
+
|
38 |
+
data.c = 3
|
39 |
+
return data
|
40 |
+
|
41 |
+
|
42 |
+
def get_dummy_databunch() -> ImageDataBunch:
|
43 |
+
path = Path('./dummy/')
|
44 |
+
return get_colorize_data(
|
45 |
+
sz=1, bs=1, crappy_path=path, good_path=path, keep_pct=0.001
|
46 |
+
)
|
DeOldify/deoldify/device_id.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import IntEnum
|
2 |
+
|
3 |
+
class DeviceId(IntEnum):
|
4 |
+
GPU0 = 0,
|
5 |
+
GPU1 = 1,
|
6 |
+
GPU2 = 2,
|
7 |
+
GPU3 = 3,
|
8 |
+
GPU4 = 4,
|
9 |
+
GPU5 = 5,
|
10 |
+
GPU6 = 6,
|
11 |
+
GPU7 = 7,
|
12 |
+
CPU = 99
|
DeOldify/deoldify/filters.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai.basic_data import DatasetType
|
2 |
+
from fastai.basic_train import Learner
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
from fastai.core import *
|
5 |
+
from fastai.vision import *
|
6 |
+
from fastai.vision.image import *
|
7 |
+
from fastai.vision.data import *
|
8 |
+
from fastai import *
|
9 |
+
import cv2
|
10 |
+
from PIL import Image as PilImage
|
11 |
+
from deoldify import device as device_settings
|
12 |
+
import logging
|
13 |
+
|
14 |
+
|
15 |
+
class IFilter(ABC):
|
16 |
+
@abstractmethod
|
17 |
+
def filter(
|
18 |
+
self, orig_image: PilImage, filtered_image: PilImage, render_factor: int
|
19 |
+
) -> PilImage:
|
20 |
+
pass
|
21 |
+
|
22 |
+
|
23 |
+
class BaseFilter(IFilter):
|
24 |
+
def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
|
25 |
+
super().__init__()
|
26 |
+
self.learn = learn
|
27 |
+
|
28 |
+
if not device_settings.is_gpu():
|
29 |
+
self.learn.model = self.learn.model.cpu()
|
30 |
+
|
31 |
+
self.device = next(self.learn.model.parameters()).device
|
32 |
+
self.norm, self.denorm = normalize_funcs(*stats)
|
33 |
+
|
34 |
+
def _transform(self, image: PilImage) -> PilImage:
|
35 |
+
return image
|
36 |
+
|
37 |
+
def _scale_to_square(self, orig: PilImage, targ: int) -> PilImage:
|
38 |
+
# a simple stretch to fit a square really makes a big difference in rendering quality/consistency.
|
39 |
+
# I've tried padding to the square as well (reflect, symetric, constant, etc). Not as good!
|
40 |
+
targ_sz = (targ, targ)
|
41 |
+
return orig.resize(targ_sz, resample=PIL.Image.BILINEAR)
|
42 |
+
|
43 |
+
def _get_model_ready_image(self, orig: PilImage, sz: int) -> PilImage:
|
44 |
+
result = self._scale_to_square(orig, sz)
|
45 |
+
result = self._transform(result)
|
46 |
+
return result
|
47 |
+
|
48 |
+
def _model_process(self, orig: PilImage, sz: int) -> PilImage:
|
49 |
+
model_image = self._get_model_ready_image(orig, sz)
|
50 |
+
x = pil2tensor(model_image, np.float32)
|
51 |
+
x = x.to(self.device)
|
52 |
+
x.div_(255)
|
53 |
+
x, y = self.norm((x, x), do_x=True)
|
54 |
+
|
55 |
+
try:
|
56 |
+
result = self.learn.pred_batch(
|
57 |
+
ds_type=DatasetType.Valid, batch=(x[None], y[None]), reconstruct=True
|
58 |
+
)
|
59 |
+
except RuntimeError as rerr:
|
60 |
+
if 'memory' not in str(rerr):
|
61 |
+
raise rerr
|
62 |
+
logging.warn('Warning: render_factor was set too high, and out of memory error resulted. Returning original image.')
|
63 |
+
return model_image
|
64 |
+
|
65 |
+
out = result[0]
|
66 |
+
out = self.denorm(out.px, do_x=False)
|
67 |
+
out = image2np(out * 255).astype(np.uint8)
|
68 |
+
return PilImage.fromarray(out)
|
69 |
+
|
70 |
+
def _unsquare(self, image: PilImage, orig: PilImage) -> PilImage:
|
71 |
+
targ_sz = orig.size
|
72 |
+
image = image.resize(targ_sz, resample=PIL.Image.BILINEAR)
|
73 |
+
return image
|
74 |
+
|
75 |
+
|
76 |
+
class ColorizerFilter(BaseFilter):
|
77 |
+
def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
|
78 |
+
super().__init__(learn=learn, stats=stats)
|
79 |
+
self.render_base = 16
|
80 |
+
|
81 |
+
def filter(
|
82 |
+
self, orig_image: PilImage, filtered_image: PilImage, render_factor: int, post_process: bool = True) -> PilImage:
|
83 |
+
render_sz = render_factor * self.render_base
|
84 |
+
model_image = self._model_process(orig=filtered_image, sz=render_sz)
|
85 |
+
raw_color = self._unsquare(model_image, orig_image)
|
86 |
+
|
87 |
+
if post_process:
|
88 |
+
return self._post_process(raw_color, orig_image)
|
89 |
+
else:
|
90 |
+
return raw_color
|
91 |
+
|
92 |
+
def _transform(self, image: PilImage) -> PilImage:
|
93 |
+
return image.convert('LA').convert('RGB')
|
94 |
+
|
95 |
+
# This takes advantage of the fact that human eyes are much less sensitive to
|
96 |
+
# imperfections in chrominance compared to luminance. This means we can
|
97 |
+
# save a lot on memory and processing in the model, yet get a great high
|
98 |
+
# resolution result at the end. This is primarily intended just for
|
99 |
+
# inference
|
100 |
+
def _post_process(self, raw_color: PilImage, orig: PilImage) -> PilImage:
|
101 |
+
color_np = np.asarray(raw_color)
|
102 |
+
orig_np = np.asarray(orig)
|
103 |
+
color_yuv = cv2.cvtColor(color_np, cv2.COLOR_RGB2YUV)
|
104 |
+
# do a black and white transform first to get better luminance values
|
105 |
+
orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_RGB2YUV)
|
106 |
+
hires = np.copy(orig_yuv)
|
107 |
+
hires[:, :, 1:3] = color_yuv[:, :, 1:3]
|
108 |
+
final = cv2.cvtColor(hires, cv2.COLOR_YUV2RGB)
|
109 |
+
final = PilImage.fromarray(final)
|
110 |
+
return final
|
111 |
+
|
112 |
+
|
113 |
+
class MasterFilter(BaseFilter):
|
114 |
+
def __init__(self, filters: List[IFilter], render_factor: int):
|
115 |
+
self.filters = filters
|
116 |
+
self.render_factor = render_factor
|
117 |
+
|
118 |
+
def filter(
|
119 |
+
self, orig_image: PilImage, filtered_image: PilImage, render_factor: int = None, post_process: bool = True) -> PilImage:
|
120 |
+
render_factor = self.render_factor if render_factor is None else render_factor
|
121 |
+
for filter in self.filters:
|
122 |
+
filtered_image = filter.filter(orig_image, filtered_image, render_factor, post_process)
|
123 |
+
|
124 |
+
return filtered_image
|
DeOldify/deoldify/generators.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai.basic_data import DataBunch
|
2 |
+
from fastai.basic_train import Learner
|
3 |
+
from fastai.layers import NormType
|
4 |
+
from fastai.torch_core import SplitFuncOrIdxList, apply_init, to_device
|
5 |
+
from fastai.vision import *
|
6 |
+
from fastai.vision.learner import cnn_config, create_body
|
7 |
+
from torch import nn
|
8 |
+
from .unet import DynamicUnetWide, DynamicUnetDeep
|
9 |
+
from .dataset import *
|
10 |
+
|
11 |
+
# Weights are implicitly read from ./models/ folder
|
12 |
+
def gen_inference_wide(
|
13 |
+
root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101) -> Learner:
|
14 |
+
data = get_dummy_databunch()
|
15 |
+
learn = gen_learner_wide(
|
16 |
+
data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch
|
17 |
+
)
|
18 |
+
learn.path = root_folder
|
19 |
+
learn.load(weights_name)
|
20 |
+
learn.model.eval()
|
21 |
+
return learn
|
22 |
+
|
23 |
+
|
24 |
+
def gen_learner_wide(
|
25 |
+
data: ImageDataBunch, gen_loss, arch=models.resnet101, nf_factor: int = 2
|
26 |
+
) -> Learner:
|
27 |
+
return unet_learner_wide(
|
28 |
+
data,
|
29 |
+
arch=arch,
|
30 |
+
wd=1e-3,
|
31 |
+
blur=True,
|
32 |
+
norm_type=NormType.Spectral,
|
33 |
+
self_attention=True,
|
34 |
+
y_range=(-3.0, 3.0),
|
35 |
+
loss_func=gen_loss,
|
36 |
+
nf_factor=nf_factor,
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
# The code below is meant to be merged into fastaiv1 ideally
|
41 |
+
def unet_learner_wide(
|
42 |
+
data: DataBunch,
|
43 |
+
arch: Callable,
|
44 |
+
pretrained: bool = True,
|
45 |
+
blur_final: bool = True,
|
46 |
+
norm_type: Optional[NormType] = NormType,
|
47 |
+
split_on: Optional[SplitFuncOrIdxList] = None,
|
48 |
+
blur: bool = False,
|
49 |
+
self_attention: bool = False,
|
50 |
+
y_range: Optional[Tuple[float, float]] = None,
|
51 |
+
last_cross: bool = True,
|
52 |
+
bottle: bool = False,
|
53 |
+
nf_factor: int = 1,
|
54 |
+
**kwargs: Any
|
55 |
+
) -> Learner:
|
56 |
+
"Build Unet learner from `data` and `arch`."
|
57 |
+
meta = cnn_config(arch)
|
58 |
+
body = create_body(arch, pretrained)
|
59 |
+
model = to_device(
|
60 |
+
DynamicUnetWide(
|
61 |
+
body,
|
62 |
+
n_classes=data.c,
|
63 |
+
blur=blur,
|
64 |
+
blur_final=blur_final,
|
65 |
+
self_attention=self_attention,
|
66 |
+
y_range=y_range,
|
67 |
+
norm_type=norm_type,
|
68 |
+
last_cross=last_cross,
|
69 |
+
bottle=bottle,
|
70 |
+
nf_factor=nf_factor,
|
71 |
+
),
|
72 |
+
data.device,
|
73 |
+
)
|
74 |
+
learn = Learner(data, model, **kwargs)
|
75 |
+
learn.split(ifnone(split_on, meta['split']))
|
76 |
+
if pretrained:
|
77 |
+
learn.freeze()
|
78 |
+
apply_init(model[2], nn.init.kaiming_normal_)
|
79 |
+
return learn
|
80 |
+
|
81 |
+
|
82 |
+
# ----------------------------------------------------------------------
|
83 |
+
|
84 |
+
# Weights are implicitly read from ./models/ folder
|
85 |
+
def gen_inference_deep(
|
86 |
+
root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5) -> Learner:
|
87 |
+
data = get_dummy_databunch()
|
88 |
+
learn = gen_learner_deep(
|
89 |
+
data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor
|
90 |
+
)
|
91 |
+
learn.path = root_folder
|
92 |
+
learn.load(weights_name)
|
93 |
+
learn.model.eval()
|
94 |
+
return learn
|
95 |
+
|
96 |
+
|
97 |
+
def gen_learner_deep(
|
98 |
+
data: ImageDataBunch, gen_loss, arch=models.resnet34, nf_factor: float = 1.5
|
99 |
+
) -> Learner:
|
100 |
+
return unet_learner_deep(
|
101 |
+
data,
|
102 |
+
arch,
|
103 |
+
wd=1e-3,
|
104 |
+
blur=True,
|
105 |
+
norm_type=NormType.Spectral,
|
106 |
+
self_attention=True,
|
107 |
+
y_range=(-3.0, 3.0),
|
108 |
+
loss_func=gen_loss,
|
109 |
+
nf_factor=nf_factor,
|
110 |
+
)
|
111 |
+
|
112 |
+
|
113 |
+
# The code below is meant to be merged into fastaiv1 ideally
|
114 |
+
def unet_learner_deep(
|
115 |
+
data: DataBunch,
|
116 |
+
arch: Callable,
|
117 |
+
pretrained: bool = True,
|
118 |
+
blur_final: bool = True,
|
119 |
+
norm_type: Optional[NormType] = NormType,
|
120 |
+
split_on: Optional[SplitFuncOrIdxList] = None,
|
121 |
+
blur: bool = False,
|
122 |
+
self_attention: bool = False,
|
123 |
+
y_range: Optional[Tuple[float, float]] = None,
|
124 |
+
last_cross: bool = True,
|
125 |
+
bottle: bool = False,
|
126 |
+
nf_factor: float = 1.5,
|
127 |
+
**kwargs: Any
|
128 |
+
) -> Learner:
|
129 |
+
"Build Unet learner from `data` and `arch`."
|
130 |
+
meta = cnn_config(arch)
|
131 |
+
body = create_body(arch, pretrained)
|
132 |
+
model = to_device(
|
133 |
+
DynamicUnetDeep(
|
134 |
+
body,
|
135 |
+
n_classes=data.c,
|
136 |
+
blur=blur,
|
137 |
+
blur_final=blur_final,
|
138 |
+
self_attention=self_attention,
|
139 |
+
y_range=y_range,
|
140 |
+
norm_type=norm_type,
|
141 |
+
last_cross=last_cross,
|
142 |
+
bottle=bottle,
|
143 |
+
nf_factor=nf_factor,
|
144 |
+
),
|
145 |
+
data.device,
|
146 |
+
)
|
147 |
+
learn = Learner(data, model, **kwargs)
|
148 |
+
learn.split(ifnone(split_on, meta['split']))
|
149 |
+
if pretrained:
|
150 |
+
learn.freeze()
|
151 |
+
apply_init(model[2], nn.init.kaiming_normal_)
|
152 |
+
return learn
|
153 |
+
|
154 |
+
|
155 |
+
# -----------------------------
|
DeOldify/deoldify/layers.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai.layers import *
|
2 |
+
from fastai.torch_core import *
|
3 |
+
|
4 |
+
|
5 |
+
# The code below is meant to be merged into fastaiv1 ideally
|
6 |
+
|
7 |
+
|
8 |
+
def custom_conv_layer(
|
9 |
+
ni: int,
|
10 |
+
nf: int,
|
11 |
+
ks: int = 3,
|
12 |
+
stride: int = 1,
|
13 |
+
padding: int = None,
|
14 |
+
bias: bool = None,
|
15 |
+
is_1d: bool = False,
|
16 |
+
norm_type: Optional[NormType] = NormType.Batch,
|
17 |
+
use_activ: bool = True,
|
18 |
+
leaky: float = None,
|
19 |
+
transpose: bool = False,
|
20 |
+
init: Callable = nn.init.kaiming_normal_,
|
21 |
+
self_attention: bool = False,
|
22 |
+
extra_bn: bool = False,
|
23 |
+
):
|
24 |
+
"Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
|
25 |
+
if padding is None:
|
26 |
+
padding = (ks - 1) // 2 if not transpose else 0
|
27 |
+
bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn == True
|
28 |
+
if bias is None:
|
29 |
+
bias = not bn
|
30 |
+
conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
|
31 |
+
conv = init_default(
|
32 |
+
conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding),
|
33 |
+
init,
|
34 |
+
)
|
35 |
+
if norm_type == NormType.Weight:
|
36 |
+
conv = weight_norm(conv)
|
37 |
+
elif norm_type == NormType.Spectral:
|
38 |
+
conv = spectral_norm(conv)
|
39 |
+
layers = [conv]
|
40 |
+
if use_activ:
|
41 |
+
layers.append(relu(True, leaky=leaky))
|
42 |
+
if bn:
|
43 |
+
layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
|
44 |
+
if self_attention:
|
45 |
+
layers.append(SelfAttention(nf))
|
46 |
+
return nn.Sequential(*layers)
|
DeOldify/deoldify/loss.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai import *
|
2 |
+
from fastai.core import *
|
3 |
+
from fastai.torch_core import *
|
4 |
+
from fastai.callbacks import hook_outputs
|
5 |
+
import torchvision.models as models
|
6 |
+
|
7 |
+
|
8 |
+
class FeatureLoss(nn.Module):
|
9 |
+
def __init__(self, layer_wgts=[20, 70, 10]):
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
self.m_feat = models.vgg16_bn(True).features.cuda().eval()
|
13 |
+
requires_grad(self.m_feat, False)
|
14 |
+
blocks = [
|
15 |
+
i - 1
|
16 |
+
for i, o in enumerate(children(self.m_feat))
|
17 |
+
if isinstance(o, nn.MaxPool2d)
|
18 |
+
]
|
19 |
+
layer_ids = blocks[2:5]
|
20 |
+
self.loss_features = [self.m_feat[i] for i in layer_ids]
|
21 |
+
self.hooks = hook_outputs(self.loss_features, detach=False)
|
22 |
+
self.wgts = layer_wgts
|
23 |
+
self.metric_names = ['pixel'] + [f'feat_{i}' for i in range(len(layer_ids))]
|
24 |
+
self.base_loss = F.l1_loss
|
25 |
+
|
26 |
+
def _make_features(self, x, clone=False):
|
27 |
+
self.m_feat(x)
|
28 |
+
return [(o.clone() if clone else o) for o in self.hooks.stored]
|
29 |
+
|
30 |
+
def forward(self, input, target):
|
31 |
+
out_feat = self._make_features(target, clone=True)
|
32 |
+
in_feat = self._make_features(input)
|
33 |
+
self.feat_losses = [self.base_loss(input, target)]
|
34 |
+
self.feat_losses += [
|
35 |
+
self.base_loss(f_in, f_out) * w
|
36 |
+
for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
|
37 |
+
]
|
38 |
+
|
39 |
+
self.metrics = dict(zip(self.metric_names, self.feat_losses))
|
40 |
+
return sum(self.feat_losses)
|
41 |
+
|
42 |
+
def __del__(self):
|
43 |
+
self.hooks.remove()
|
44 |
+
|
45 |
+
|
46 |
+
# Refactored code, originally from https://github.com/VinceMarron/style_transfer
|
47 |
+
class WassFeatureLoss(nn.Module):
|
48 |
+
def __init__(self, layer_wgts=[5, 15, 2], wass_wgts=[3.0, 0.7, 0.01]):
|
49 |
+
super().__init__()
|
50 |
+
self.m_feat = models.vgg16_bn(True).features.cuda().eval()
|
51 |
+
requires_grad(self.m_feat, False)
|
52 |
+
blocks = [
|
53 |
+
i - 1
|
54 |
+
for i, o in enumerate(children(self.m_feat))
|
55 |
+
if isinstance(o, nn.MaxPool2d)
|
56 |
+
]
|
57 |
+
layer_ids = blocks[2:5]
|
58 |
+
self.loss_features = [self.m_feat[i] for i in layer_ids]
|
59 |
+
self.hooks = hook_outputs(self.loss_features, detach=False)
|
60 |
+
self.wgts = layer_wgts
|
61 |
+
self.wass_wgts = wass_wgts
|
62 |
+
self.metric_names = (
|
63 |
+
['pixel']
|
64 |
+
+ [f'feat_{i}' for i in range(len(layer_ids))]
|
65 |
+
+ [f'wass_{i}' for i in range(len(layer_ids))]
|
66 |
+
)
|
67 |
+
self.base_loss = F.l1_loss
|
68 |
+
|
69 |
+
def _make_features(self, x, clone=False):
|
70 |
+
self.m_feat(x)
|
71 |
+
return [(o.clone() if clone else o) for o in self.hooks.stored]
|
72 |
+
|
73 |
+
def _calc_2_moments(self, tensor):
|
74 |
+
chans = tensor.shape[1]
|
75 |
+
tensor = tensor.view(1, chans, -1)
|
76 |
+
n = tensor.shape[2]
|
77 |
+
mu = tensor.mean(2)
|
78 |
+
tensor = (tensor - mu[:, :, None]).squeeze(0)
|
79 |
+
# Prevents nasty bug that happens very occassionally- divide by zero. Why such things happen?
|
80 |
+
if n == 0:
|
81 |
+
return None, None
|
82 |
+
cov = torch.mm(tensor, tensor.t()) / float(n)
|
83 |
+
return mu, cov
|
84 |
+
|
85 |
+
def _get_style_vals(self, tensor):
|
86 |
+
mean, cov = self._calc_2_moments(tensor)
|
87 |
+
if mean is None:
|
88 |
+
return None, None, None
|
89 |
+
eigvals, eigvects = torch.symeig(cov, eigenvectors=True)
|
90 |
+
eigroot_mat = torch.diag(torch.sqrt(eigvals.clamp(min=0)))
|
91 |
+
root_cov = torch.mm(torch.mm(eigvects, eigroot_mat), eigvects.t())
|
92 |
+
tr_cov = eigvals.clamp(min=0).sum()
|
93 |
+
return mean, tr_cov, root_cov
|
94 |
+
|
95 |
+
def _calc_l2wass_dist(
|
96 |
+
self, mean_stl, tr_cov_stl, root_cov_stl, mean_synth, cov_synth
|
97 |
+
):
|
98 |
+
tr_cov_synth = torch.symeig(cov_synth, eigenvectors=True)[0].clamp(min=0).sum()
|
99 |
+
mean_diff_squared = (mean_stl - mean_synth).pow(2).sum()
|
100 |
+
cov_prod = torch.mm(torch.mm(root_cov_stl, cov_synth), root_cov_stl)
|
101 |
+
var_overlap = torch.sqrt(
|
102 |
+
torch.symeig(cov_prod, eigenvectors=True)[0].clamp(min=0) + 1e-8
|
103 |
+
).sum()
|
104 |
+
dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2 * var_overlap
|
105 |
+
return dist
|
106 |
+
|
107 |
+
def _single_wass_loss(self, pred, targ):
|
108 |
+
mean_test, tr_cov_test, root_cov_test = targ
|
109 |
+
mean_synth, cov_synth = self._calc_2_moments(pred)
|
110 |
+
loss = self._calc_l2wass_dist(
|
111 |
+
mean_test, tr_cov_test, root_cov_test, mean_synth, cov_synth
|
112 |
+
)
|
113 |
+
return loss
|
114 |
+
|
115 |
+
def forward(self, input, target):
|
116 |
+
out_feat = self._make_features(target, clone=True)
|
117 |
+
in_feat = self._make_features(input)
|
118 |
+
self.feat_losses = [self.base_loss(input, target)]
|
119 |
+
self.feat_losses += [
|
120 |
+
self.base_loss(f_in, f_out) * w
|
121 |
+
for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
|
122 |
+
]
|
123 |
+
|
124 |
+
styles = [self._get_style_vals(i) for i in out_feat]
|
125 |
+
|
126 |
+
if styles[0][0] is not None:
|
127 |
+
self.feat_losses += [
|
128 |
+
self._single_wass_loss(f_pred, f_targ) * w
|
129 |
+
for f_pred, f_targ, w in zip(in_feat, styles, self.wass_wgts)
|
130 |
+
]
|
131 |
+
|
132 |
+
self.metrics = dict(zip(self.metric_names, self.feat_losses))
|
133 |
+
return sum(self.feat_losses)
|
134 |
+
|
135 |
+
def __del__(self):
|
136 |
+
self.hooks.remove()
|
DeOldify/deoldify/save.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai.basic_train import Learner, LearnerCallback
|
2 |
+
from fastai.vision.gan import GANLearner
|
3 |
+
|
4 |
+
|
5 |
+
class GANSaveCallback(LearnerCallback):
|
6 |
+
"""A `LearnerCallback` that saves history of metrics while training `learn` into CSV `filename`."""
|
7 |
+
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
learn: GANLearner,
|
11 |
+
learn_gen: Learner,
|
12 |
+
filename: str,
|
13 |
+
save_iters: int = 1000,
|
14 |
+
):
|
15 |
+
super().__init__(learn)
|
16 |
+
self.learn_gen = learn_gen
|
17 |
+
self.filename = filename
|
18 |
+
self.save_iters = save_iters
|
19 |
+
|
20 |
+
def on_batch_end(self, iteration: int, epoch: int, **kwargs) -> None:
|
21 |
+
if iteration == 0:
|
22 |
+
return
|
23 |
+
|
24 |
+
if iteration % self.save_iters == 0:
|
25 |
+
self._save_gen_learner(iteration=iteration, epoch=epoch)
|
26 |
+
|
27 |
+
def _save_gen_learner(self, iteration: int, epoch: int):
|
28 |
+
filename = '{}_{}_{}'.format(self.filename, epoch, iteration)
|
29 |
+
self.learn_gen.save(filename)
|
DeOldify/deoldify/unet.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai.layers import *
|
2 |
+
from .layers import *
|
3 |
+
from fastai.torch_core import *
|
4 |
+
from fastai.callbacks.hooks import *
|
5 |
+
from fastai.vision import *
|
6 |
+
|
7 |
+
|
8 |
+
# The code below is meant to be merged into fastaiv1 ideally
|
9 |
+
|
10 |
+
__all__ = ['DynamicUnetDeep', 'DynamicUnetWide']
|
11 |
+
|
12 |
+
|
13 |
+
def _get_sfs_idxs(sizes: Sizes) -> List[int]:
|
14 |
+
"Get the indexes of the layers where the size of the activation changes."
|
15 |
+
feature_szs = [size[-1] for size in sizes]
|
16 |
+
sfs_idxs = list(
|
17 |
+
np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0]
|
18 |
+
)
|
19 |
+
if feature_szs[0] != feature_szs[1]:
|
20 |
+
sfs_idxs = [0] + sfs_idxs
|
21 |
+
return sfs_idxs
|
22 |
+
|
23 |
+
|
24 |
+
class CustomPixelShuffle_ICNR(nn.Module):
|
25 |
+
"Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
ni: int,
|
30 |
+
nf: int = None,
|
31 |
+
scale: int = 2,
|
32 |
+
blur: bool = False,
|
33 |
+
leaky: float = None,
|
34 |
+
**kwargs
|
35 |
+
):
|
36 |
+
super().__init__()
|
37 |
+
nf = ifnone(nf, ni)
|
38 |
+
self.conv = custom_conv_layer(
|
39 |
+
ni, nf * (scale ** 2), ks=1, use_activ=False, **kwargs
|
40 |
+
)
|
41 |
+
icnr(self.conv[0].weight)
|
42 |
+
self.shuf = nn.PixelShuffle(scale)
|
43 |
+
# Blurring over (h*w) kernel
|
44 |
+
# "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
|
45 |
+
# - https://arxiv.org/abs/1806.02658
|
46 |
+
self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
|
47 |
+
self.blur = nn.AvgPool2d(2, stride=1)
|
48 |
+
self.relu = relu(True, leaky=leaky)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
x = self.shuf(self.relu(self.conv(x)))
|
52 |
+
return self.blur(self.pad(x)) if self.blur else x
|
53 |
+
|
54 |
+
|
55 |
+
class UnetBlockDeep(nn.Module):
|
56 |
+
"A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
|
57 |
+
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
up_in_c: int,
|
61 |
+
x_in_c: int,
|
62 |
+
hook: Hook,
|
63 |
+
final_div: bool = True,
|
64 |
+
blur: bool = False,
|
65 |
+
leaky: float = None,
|
66 |
+
self_attention: bool = False,
|
67 |
+
nf_factor: float = 1.0,
|
68 |
+
**kwargs
|
69 |
+
):
|
70 |
+
super().__init__()
|
71 |
+
self.hook = hook
|
72 |
+
self.shuf = CustomPixelShuffle_ICNR(
|
73 |
+
up_in_c, up_in_c // 2, blur=blur, leaky=leaky, **kwargs
|
74 |
+
)
|
75 |
+
self.bn = batchnorm_2d(x_in_c)
|
76 |
+
ni = up_in_c // 2 + x_in_c
|
77 |
+
nf = int((ni if final_div else ni // 2) * nf_factor)
|
78 |
+
self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs)
|
79 |
+
self.conv2 = custom_conv_layer(
|
80 |
+
nf, nf, leaky=leaky, self_attention=self_attention, **kwargs
|
81 |
+
)
|
82 |
+
self.relu = relu(leaky=leaky)
|
83 |
+
|
84 |
+
def forward(self, up_in: Tensor) -> Tensor:
|
85 |
+
s = self.hook.stored
|
86 |
+
up_out = self.shuf(up_in)
|
87 |
+
ssh = s.shape[-2:]
|
88 |
+
if ssh != up_out.shape[-2:]:
|
89 |
+
up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
|
90 |
+
cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
|
91 |
+
return self.conv2(self.conv1(cat_x))
|
92 |
+
|
93 |
+
|
94 |
+
class DynamicUnetDeep(SequentialEx):
|
95 |
+
"Create a U-Net from a given architecture."
|
96 |
+
|
97 |
+
def __init__(
|
98 |
+
self,
|
99 |
+
encoder: nn.Module,
|
100 |
+
n_classes: int,
|
101 |
+
blur: bool = False,
|
102 |
+
blur_final=True,
|
103 |
+
self_attention: bool = False,
|
104 |
+
y_range: Optional[Tuple[float, float]] = None,
|
105 |
+
last_cross: bool = True,
|
106 |
+
bottle: bool = False,
|
107 |
+
norm_type: Optional[NormType] = NormType.Batch,
|
108 |
+
nf_factor: float = 1.0,
|
109 |
+
**kwargs
|
110 |
+
):
|
111 |
+
extra_bn = norm_type == NormType.Spectral
|
112 |
+
imsize = (256, 256)
|
113 |
+
sfs_szs = model_sizes(encoder, size=imsize)
|
114 |
+
sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
|
115 |
+
self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
|
116 |
+
x = dummy_eval(encoder, imsize).detach()
|
117 |
+
|
118 |
+
ni = sfs_szs[-1][1]
|
119 |
+
middle_conv = nn.Sequential(
|
120 |
+
custom_conv_layer(
|
121 |
+
ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
|
122 |
+
),
|
123 |
+
custom_conv_layer(
|
124 |
+
ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
|
125 |
+
),
|
126 |
+
).eval()
|
127 |
+
x = middle_conv(x)
|
128 |
+
layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
|
129 |
+
|
130 |
+
for i, idx in enumerate(sfs_idxs):
|
131 |
+
not_final = i != len(sfs_idxs) - 1
|
132 |
+
up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
|
133 |
+
do_blur = blur and (not_final or blur_final)
|
134 |
+
sa = self_attention and (i == len(sfs_idxs) - 3)
|
135 |
+
unet_block = UnetBlockDeep(
|
136 |
+
up_in_c,
|
137 |
+
x_in_c,
|
138 |
+
self.sfs[i],
|
139 |
+
final_div=not_final,
|
140 |
+
blur=blur,
|
141 |
+
self_attention=sa,
|
142 |
+
norm_type=norm_type,
|
143 |
+
extra_bn=extra_bn,
|
144 |
+
nf_factor=nf_factor,
|
145 |
+
**kwargs
|
146 |
+
).eval()
|
147 |
+
layers.append(unet_block)
|
148 |
+
x = unet_block(x)
|
149 |
+
|
150 |
+
ni = x.shape[1]
|
151 |
+
if imsize != sfs_szs[0][-2:]:
|
152 |
+
layers.append(PixelShuffle_ICNR(ni, **kwargs))
|
153 |
+
if last_cross:
|
154 |
+
layers.append(MergeLayer(dense=True))
|
155 |
+
ni += in_channels(encoder)
|
156 |
+
layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
|
157 |
+
layers += [
|
158 |
+
custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
|
159 |
+
]
|
160 |
+
if y_range is not None:
|
161 |
+
layers.append(SigmoidRange(*y_range))
|
162 |
+
super().__init__(*layers)
|
163 |
+
|
164 |
+
def __del__(self):
|
165 |
+
if hasattr(self, "sfs"):
|
166 |
+
self.sfs.remove()
|
167 |
+
|
168 |
+
|
169 |
+
# ------------------------------------------------------
|
170 |
+
class UnetBlockWide(nn.Module):
|
171 |
+
"A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
|
172 |
+
|
173 |
+
def __init__(
|
174 |
+
self,
|
175 |
+
up_in_c: int,
|
176 |
+
x_in_c: int,
|
177 |
+
n_out: int,
|
178 |
+
hook: Hook,
|
179 |
+
final_div: bool = True,
|
180 |
+
blur: bool = False,
|
181 |
+
leaky: float = None,
|
182 |
+
self_attention: bool = False,
|
183 |
+
**kwargs
|
184 |
+
):
|
185 |
+
super().__init__()
|
186 |
+
self.hook = hook
|
187 |
+
up_out = x_out = n_out // 2
|
188 |
+
self.shuf = CustomPixelShuffle_ICNR(
|
189 |
+
up_in_c, up_out, blur=blur, leaky=leaky, **kwargs
|
190 |
+
)
|
191 |
+
self.bn = batchnorm_2d(x_in_c)
|
192 |
+
ni = up_out + x_in_c
|
193 |
+
self.conv = custom_conv_layer(
|
194 |
+
ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs
|
195 |
+
)
|
196 |
+
self.relu = relu(leaky=leaky)
|
197 |
+
|
198 |
+
def forward(self, up_in: Tensor) -> Tensor:
|
199 |
+
s = self.hook.stored
|
200 |
+
up_out = self.shuf(up_in)
|
201 |
+
ssh = s.shape[-2:]
|
202 |
+
if ssh != up_out.shape[-2:]:
|
203 |
+
up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
|
204 |
+
cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
|
205 |
+
return self.conv(cat_x)
|
206 |
+
|
207 |
+
|
208 |
+
class DynamicUnetWide(SequentialEx):
|
209 |
+
"Create a U-Net from a given architecture."
|
210 |
+
|
211 |
+
def __init__(
|
212 |
+
self,
|
213 |
+
encoder: nn.Module,
|
214 |
+
n_classes: int,
|
215 |
+
blur: bool = False,
|
216 |
+
blur_final=True,
|
217 |
+
self_attention: bool = False,
|
218 |
+
y_range: Optional[Tuple[float, float]] = None,
|
219 |
+
last_cross: bool = True,
|
220 |
+
bottle: bool = False,
|
221 |
+
norm_type: Optional[NormType] = NormType.Batch,
|
222 |
+
nf_factor: int = 1,
|
223 |
+
**kwargs
|
224 |
+
):
|
225 |
+
|
226 |
+
nf = 512 * nf_factor
|
227 |
+
extra_bn = norm_type == NormType.Spectral
|
228 |
+
imsize = (256, 256)
|
229 |
+
sfs_szs = model_sizes(encoder, size=imsize)
|
230 |
+
sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
|
231 |
+
self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
|
232 |
+
x = dummy_eval(encoder, imsize).detach()
|
233 |
+
|
234 |
+
ni = sfs_szs[-1][1]
|
235 |
+
middle_conv = nn.Sequential(
|
236 |
+
custom_conv_layer(
|
237 |
+
ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
|
238 |
+
),
|
239 |
+
custom_conv_layer(
|
240 |
+
ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
|
241 |
+
),
|
242 |
+
).eval()
|
243 |
+
x = middle_conv(x)
|
244 |
+
layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
|
245 |
+
|
246 |
+
for i, idx in enumerate(sfs_idxs):
|
247 |
+
not_final = i != len(sfs_idxs) - 1
|
248 |
+
up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
|
249 |
+
do_blur = blur and (not_final or blur_final)
|
250 |
+
sa = self_attention and (i == len(sfs_idxs) - 3)
|
251 |
+
|
252 |
+
n_out = nf if not_final else nf // 2
|
253 |
+
|
254 |
+
unet_block = UnetBlockWide(
|
255 |
+
up_in_c,
|
256 |
+
x_in_c,
|
257 |
+
n_out,
|
258 |
+
self.sfs[i],
|
259 |
+
final_div=not_final,
|
260 |
+
blur=blur,
|
261 |
+
self_attention=sa,
|
262 |
+
norm_type=norm_type,
|
263 |
+
extra_bn=extra_bn,
|
264 |
+
**kwargs
|
265 |
+
).eval()
|
266 |
+
layers.append(unet_block)
|
267 |
+
x = unet_block(x)
|
268 |
+
|
269 |
+
ni = x.shape[1]
|
270 |
+
if imsize != sfs_szs[0][-2:]:
|
271 |
+
layers.append(PixelShuffle_ICNR(ni, **kwargs))
|
272 |
+
if last_cross:
|
273 |
+
layers.append(MergeLayer(dense=True))
|
274 |
+
ni += in_channels(encoder)
|
275 |
+
layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
|
276 |
+
layers += [
|
277 |
+
custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
|
278 |
+
]
|
279 |
+
if y_range is not None:
|
280 |
+
layers.append(SigmoidRange(*y_range))
|
281 |
+
super().__init__(*layers)
|
282 |
+
|
283 |
+
def __del__(self):
|
284 |
+
if hasattr(self, "sfs"):
|
285 |
+
self.sfs.remove()
|
DeOldify/deoldify/visualize.py
ADDED
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai.core import *
|
2 |
+
from fastai.vision import *
|
3 |
+
from matplotlib.axes import Axes
|
4 |
+
from .filters import IFilter, MasterFilter, ColorizerFilter
|
5 |
+
from .generators import gen_inference_deep, gen_inference_wide
|
6 |
+
from PIL import Image
|
7 |
+
import ffmpeg
|
8 |
+
import yt_dlp as youtube_dl
|
9 |
+
import gc
|
10 |
+
import requests
|
11 |
+
from io import BytesIO
|
12 |
+
import base64
|
13 |
+
from IPython import display as ipythondisplay
|
14 |
+
from IPython.display import HTML
|
15 |
+
from IPython.display import Image as ipythonimage
|
16 |
+
import cv2
|
17 |
+
import logging
|
18 |
+
|
19 |
+
# adapted from https://www.pyimagesearch.com/2016/04/25/watermarking-images-with-opencv-and-python/
|
20 |
+
def get_watermarked(pil_image: Image) -> Image:
|
21 |
+
try:
|
22 |
+
image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
|
23 |
+
(h, w) = image.shape[:2]
|
24 |
+
image = np.dstack([image, np.ones((h, w), dtype="uint8") * 255])
|
25 |
+
pct = 0.05
|
26 |
+
full_watermark = cv2.imread(
|
27 |
+
'./resource_images/watermark.png', cv2.IMREAD_UNCHANGED
|
28 |
+
)
|
29 |
+
(fwH, fwW) = full_watermark.shape[:2]
|
30 |
+
wH = int(pct * h)
|
31 |
+
wW = int((pct * h / fwH) * fwW)
|
32 |
+
watermark = cv2.resize(full_watermark, (wH, wW), interpolation=cv2.INTER_AREA)
|
33 |
+
overlay = np.zeros((h, w, 4), dtype="uint8")
|
34 |
+
(wH, wW) = watermark.shape[:2]
|
35 |
+
overlay[h - wH - 10 : h - 10, 10 : 10 + wW] = watermark
|
36 |
+
# blend the two images together using transparent overlays
|
37 |
+
output = image.copy()
|
38 |
+
cv2.addWeighted(overlay, 0.5, output, 1.0, 0, output)
|
39 |
+
rgb_image = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
|
40 |
+
final_image = Image.fromarray(rgb_image)
|
41 |
+
return final_image
|
42 |
+
except:
|
43 |
+
# Don't want this to crash everything, so let's just not watermark the image for now.
|
44 |
+
return pil_image
|
45 |
+
|
46 |
+
|
47 |
+
class ModelImageVisualizer:
|
48 |
+
def __init__(self, filter: IFilter, results_dir: str = None):
|
49 |
+
self.filter = filter
|
50 |
+
self.results_dir = None if results_dir is None else Path(results_dir)
|
51 |
+
self.results_dir.mkdir(parents=True, exist_ok=True)
|
52 |
+
|
53 |
+
def _clean_mem(self):
|
54 |
+
torch.cuda.empty_cache()
|
55 |
+
# gc.collect()
|
56 |
+
|
57 |
+
def _open_pil_image(self, path: Path) -> Image:
|
58 |
+
return PIL.Image.open(path).convert('RGB')
|
59 |
+
|
60 |
+
def _get_image_from_url(self, url: str) -> Image:
|
61 |
+
response = requests.get(url, timeout=30, headers={'user-agent':'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/62.0.3202.94 Safari/537.36'})
|
62 |
+
img = PIL.Image.open(BytesIO(response.content)).convert('RGB')
|
63 |
+
return img
|
64 |
+
|
65 |
+
def plot_transformed_image_from_url(
|
66 |
+
self,
|
67 |
+
url: str,
|
68 |
+
path: str = 'test_images/image.png',
|
69 |
+
results_dir:Path = None,
|
70 |
+
figsize: Tuple[int, int] = (20, 20),
|
71 |
+
render_factor: int = None,
|
72 |
+
|
73 |
+
display_render_factor: bool = False,
|
74 |
+
compare: bool = False,
|
75 |
+
post_process: bool = True,
|
76 |
+
watermarked: bool = True,
|
77 |
+
) -> Path:
|
78 |
+
img = self._get_image_from_url(url)
|
79 |
+
img.save(path)
|
80 |
+
return self.plot_transformed_image(
|
81 |
+
path=path,
|
82 |
+
results_dir=results_dir,
|
83 |
+
figsize=figsize,
|
84 |
+
render_factor=render_factor,
|
85 |
+
display_render_factor=display_render_factor,
|
86 |
+
compare=compare,
|
87 |
+
post_process = post_process,
|
88 |
+
watermarked=watermarked,
|
89 |
+
)
|
90 |
+
|
91 |
+
def plot_transformed_image(
|
92 |
+
self,
|
93 |
+
path: str,
|
94 |
+
results_dir:Path = None,
|
95 |
+
figsize: Tuple[int, int] = (20, 20),
|
96 |
+
render_factor: int = None,
|
97 |
+
display_render_factor: bool = False,
|
98 |
+
compare: bool = False,
|
99 |
+
post_process: bool = True,
|
100 |
+
watermarked: bool = True,
|
101 |
+
) -> Path:
|
102 |
+
path = Path(path)
|
103 |
+
if results_dir is None:
|
104 |
+
results_dir = Path(self.results_dir)
|
105 |
+
result = self.get_transformed_image(
|
106 |
+
path, render_factor, post_process=post_process,watermarked=watermarked
|
107 |
+
)
|
108 |
+
orig = self._open_pil_image(path)
|
109 |
+
if compare:
|
110 |
+
self._plot_comparison(
|
111 |
+
figsize, render_factor, display_render_factor, orig, result
|
112 |
+
)
|
113 |
+
else:
|
114 |
+
self._plot_solo(figsize, render_factor, display_render_factor, result)
|
115 |
+
|
116 |
+
orig.close()
|
117 |
+
result_path = self._save_result_image(path, result, results_dir=results_dir)
|
118 |
+
result.close()
|
119 |
+
return result_path
|
120 |
+
|
121 |
+
def _plot_comparison(
|
122 |
+
self,
|
123 |
+
figsize: Tuple[int, int],
|
124 |
+
render_factor: int,
|
125 |
+
display_render_factor: bool,
|
126 |
+
orig: Image,
|
127 |
+
result: Image,
|
128 |
+
):
|
129 |
+
fig, axes = plt.subplots(1, 2, figsize=figsize)
|
130 |
+
self._plot_image(
|
131 |
+
orig,
|
132 |
+
axes=axes[0],
|
133 |
+
figsize=figsize,
|
134 |
+
render_factor=render_factor,
|
135 |
+
display_render_factor=False,
|
136 |
+
)
|
137 |
+
self._plot_image(
|
138 |
+
result,
|
139 |
+
axes=axes[1],
|
140 |
+
figsize=figsize,
|
141 |
+
render_factor=render_factor,
|
142 |
+
display_render_factor=display_render_factor,
|
143 |
+
)
|
144 |
+
|
145 |
+
def _plot_solo(
|
146 |
+
self,
|
147 |
+
figsize: Tuple[int, int],
|
148 |
+
render_factor: int,
|
149 |
+
display_render_factor: bool,
|
150 |
+
result: Image,
|
151 |
+
):
|
152 |
+
fig, axes = plt.subplots(1, 1, figsize=figsize)
|
153 |
+
self._plot_image(
|
154 |
+
result,
|
155 |
+
axes=axes,
|
156 |
+
figsize=figsize,
|
157 |
+
render_factor=render_factor,
|
158 |
+
display_render_factor=display_render_factor,
|
159 |
+
)
|
160 |
+
|
161 |
+
def _save_result_image(self, source_path: Path, image: Image, results_dir = None) -> Path:
|
162 |
+
if results_dir is None:
|
163 |
+
results_dir = Path(self.results_dir)
|
164 |
+
result_path = results_dir / source_path.name
|
165 |
+
image.save(result_path)
|
166 |
+
return result_path
|
167 |
+
|
168 |
+
def get_transformed_image(
|
169 |
+
self, path: Path, render_factor: int = None, post_process: bool = True,
|
170 |
+
watermarked: bool = True,
|
171 |
+
) -> Image:
|
172 |
+
self._clean_mem()
|
173 |
+
orig_image = self._open_pil_image(path)
|
174 |
+
filtered_image = self.filter.filter(
|
175 |
+
orig_image, orig_image, render_factor=render_factor,post_process=post_process
|
176 |
+
)
|
177 |
+
|
178 |
+
if watermarked:
|
179 |
+
return get_watermarked(filtered_image)
|
180 |
+
|
181 |
+
return filtered_image
|
182 |
+
|
183 |
+
def _plot_image(
|
184 |
+
self,
|
185 |
+
image: Image,
|
186 |
+
render_factor: int,
|
187 |
+
axes: Axes = None,
|
188 |
+
figsize=(20, 20),
|
189 |
+
display_render_factor = False,
|
190 |
+
):
|
191 |
+
if axes is None:
|
192 |
+
_, axes = plt.subplots(figsize=figsize)
|
193 |
+
axes.imshow(np.asarray(image) / 255)
|
194 |
+
axes.axis('off')
|
195 |
+
if render_factor is not None and display_render_factor:
|
196 |
+
plt.text(
|
197 |
+
10,
|
198 |
+
10,
|
199 |
+
'render_factor: ' + str(render_factor),
|
200 |
+
color='white',
|
201 |
+
backgroundcolor='black',
|
202 |
+
)
|
203 |
+
|
204 |
+
def _get_num_rows_columns(self, num_images: int, max_columns: int) -> Tuple[int, int]:
|
205 |
+
columns = min(num_images, max_columns)
|
206 |
+
rows = num_images // columns
|
207 |
+
rows = rows if rows * columns == num_images else rows + 1
|
208 |
+
return rows, columns
|
209 |
+
|
210 |
+
|
211 |
+
class VideoColorizer:
|
212 |
+
def __init__(self, vis: ModelImageVisualizer):
|
213 |
+
self.vis = vis
|
214 |
+
workfolder = Path('./video')
|
215 |
+
self.source_folder = workfolder / "source"
|
216 |
+
self.bwframes_root = workfolder / "bwframes"
|
217 |
+
self.audio_root = workfolder / "audio"
|
218 |
+
self.colorframes_root = workfolder / "colorframes"
|
219 |
+
self.result_folder = workfolder / "result"
|
220 |
+
|
221 |
+
def _purge_images(self, dir):
|
222 |
+
for f in os.listdir(dir):
|
223 |
+
if re.search('.*?\.jpg', f):
|
224 |
+
os.remove(os.path.join(dir, f))
|
225 |
+
|
226 |
+
def _get_ffmpeg_probe(self, path:Path):
|
227 |
+
try:
|
228 |
+
probe = ffmpeg.probe(str(path))
|
229 |
+
return probe
|
230 |
+
except ffmpeg.Error as e:
|
231 |
+
logging.error("ffmpeg error: {0}".format(e), exc_info=True)
|
232 |
+
logging.error('stdout:' + e.stdout.decode('UTF-8'))
|
233 |
+
logging.error('stderr:' + e.stderr.decode('UTF-8'))
|
234 |
+
raise e
|
235 |
+
except Exception as e:
|
236 |
+
logging.error('Failed to instantiate ffmpeg.probe. Details: {0}'.format(e), exc_info=True)
|
237 |
+
raise e
|
238 |
+
|
239 |
+
def _get_fps(self, source_path: Path) -> str:
|
240 |
+
probe = self._get_ffmpeg_probe(source_path)
|
241 |
+
stream_data = next(
|
242 |
+
(stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
|
243 |
+
None,
|
244 |
+
)
|
245 |
+
return stream_data['avg_frame_rate']
|
246 |
+
|
247 |
+
def _download_video_from_url(self, source_url, source_path: Path):
|
248 |
+
if source_path.exists():
|
249 |
+
source_path.unlink()
|
250 |
+
|
251 |
+
ydl_opts = {
|
252 |
+
'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/mp4',
|
253 |
+
'outtmpl': str(source_path),
|
254 |
+
'retries': 30,
|
255 |
+
'fragment-retries': 30
|
256 |
+
}
|
257 |
+
with youtube_dl.YoutubeDL(ydl_opts) as ydl:
|
258 |
+
ydl.download([source_url])
|
259 |
+
|
260 |
+
def _extract_raw_frames(self, source_path: Path):
|
261 |
+
bwframes_folder = self.bwframes_root / (source_path.stem)
|
262 |
+
bwframe_path_template = str(bwframes_folder / '%5d.jpg')
|
263 |
+
bwframes_folder.mkdir(parents=True, exist_ok=True)
|
264 |
+
self._purge_images(bwframes_folder)
|
265 |
+
|
266 |
+
process = (
|
267 |
+
ffmpeg
|
268 |
+
.input(str(source_path))
|
269 |
+
.output(str(bwframe_path_template), format='image2', vcodec='mjpeg', **{'q:v':'0'})
|
270 |
+
.global_args('-hide_banner')
|
271 |
+
.global_args('-nostats')
|
272 |
+
.global_args('-loglevel', 'error')
|
273 |
+
)
|
274 |
+
|
275 |
+
try:
|
276 |
+
process.run()
|
277 |
+
except ffmpeg.Error as e:
|
278 |
+
logging.error("ffmpeg error: {0}".format(e), exc_info=True)
|
279 |
+
logging.error('stdout:' + e.stdout.decode('UTF-8'))
|
280 |
+
logging.error('stderr:' + e.stderr.decode('UTF-8'))
|
281 |
+
raise e
|
282 |
+
except Exception as e:
|
283 |
+
logging.error('Errror while extracting raw frames from source video. Details: {0}'.format(e), exc_info=True)
|
284 |
+
raise e
|
285 |
+
|
286 |
+
def _colorize_raw_frames(
|
287 |
+
self, source_path: Path, render_factor: int = None, post_process: bool = True,
|
288 |
+
watermarked: bool = True,
|
289 |
+
):
|
290 |
+
colorframes_folder = self.colorframes_root / (source_path.stem)
|
291 |
+
colorframes_folder.mkdir(parents=True, exist_ok=True)
|
292 |
+
self._purge_images(colorframes_folder)
|
293 |
+
bwframes_folder = self.bwframes_root / (source_path.stem)
|
294 |
+
|
295 |
+
for img in progress_bar(os.listdir(str(bwframes_folder))):
|
296 |
+
img_path = bwframes_folder / img
|
297 |
+
|
298 |
+
if os.path.isfile(str(img_path)):
|
299 |
+
color_image = self.vis.get_transformed_image(
|
300 |
+
str(img_path), render_factor=render_factor, post_process=post_process,watermarked=watermarked
|
301 |
+
)
|
302 |
+
color_image.save(str(colorframes_folder / img))
|
303 |
+
|
304 |
+
def _build_video(self, source_path: Path) -> Path:
|
305 |
+
colorized_path = self.result_folder / (
|
306 |
+
source_path.name.replace('.mp4', '_no_audio.mp4')
|
307 |
+
)
|
308 |
+
colorframes_folder = self.colorframes_root / (source_path.stem)
|
309 |
+
colorframes_path_template = str(colorframes_folder / '%5d.jpg')
|
310 |
+
colorized_path.parent.mkdir(parents=True, exist_ok=True)
|
311 |
+
if colorized_path.exists():
|
312 |
+
colorized_path.unlink()
|
313 |
+
fps = self._get_fps(source_path)
|
314 |
+
|
315 |
+
process = (
|
316 |
+
ffmpeg
|
317 |
+
.input(str(colorframes_path_template), format='image2', vcodec='mjpeg', framerate=fps)
|
318 |
+
.output(str(colorized_path), crf=17, vcodec='libx264')
|
319 |
+
.global_args('-hide_banner')
|
320 |
+
.global_args('-nostats')
|
321 |
+
.global_args('-loglevel', 'error')
|
322 |
+
)
|
323 |
+
|
324 |
+
try:
|
325 |
+
process.run()
|
326 |
+
except ffmpeg.Error as e:
|
327 |
+
logging.error("ffmpeg error: {0}".format(e), exc_info=True)
|
328 |
+
logging.error('stdout:' + e.stdout.decode('UTF-8'))
|
329 |
+
logging.error('stderr:' + e.stderr.decode('UTF-8'))
|
330 |
+
raise e
|
331 |
+
except Exception as e:
|
332 |
+
logging.error('Errror while building output video. Details: {0}'.format(e), exc_info=True)
|
333 |
+
raise e
|
334 |
+
|
335 |
+
result_path = self.result_folder / source_path.name
|
336 |
+
if result_path.exists():
|
337 |
+
result_path.unlink()
|
338 |
+
# making copy of non-audio version in case adding back audio doesn't apply or fails.
|
339 |
+
shutil.copyfile(str(colorized_path), str(result_path))
|
340 |
+
|
341 |
+
# adding back sound here
|
342 |
+
audio_file = Path(str(source_path).replace('.mp4', '.aac'))
|
343 |
+
if audio_file.exists():
|
344 |
+
audio_file.unlink()
|
345 |
+
|
346 |
+
os.system(
|
347 |
+
'ffmpeg -y -i "'
|
348 |
+
+ str(source_path)
|
349 |
+
+ '" -vn -acodec copy "'
|
350 |
+
+ str(audio_file)
|
351 |
+
+ '"'
|
352 |
+
+ ' -hide_banner'
|
353 |
+
+ ' -nostats'
|
354 |
+
+ ' -loglevel error'
|
355 |
+
)
|
356 |
+
|
357 |
+
if audio_file.exists():
|
358 |
+
os.system(
|
359 |
+
'ffmpeg -y -i "'
|
360 |
+
+ str(colorized_path)
|
361 |
+
+ '" -i "'
|
362 |
+
+ str(audio_file)
|
363 |
+
+ '" -shortest -c:v copy -c:a aac -b:a 256k "'
|
364 |
+
+ str(result_path)
|
365 |
+
+ '"'
|
366 |
+
+ ' -hide_banner'
|
367 |
+
+ ' -nostats'
|
368 |
+
+ ' -loglevel error'
|
369 |
+
)
|
370 |
+
logging.info('Video created here: ' + str(result_path))
|
371 |
+
return result_path
|
372 |
+
|
373 |
+
def colorize_from_url(
|
374 |
+
self,
|
375 |
+
source_url,
|
376 |
+
file_name: str,
|
377 |
+
render_factor: int = None,
|
378 |
+
post_process: bool = True,
|
379 |
+
watermarked: bool = True,
|
380 |
+
|
381 |
+
) -> Path:
|
382 |
+
source_path = self.source_folder / file_name
|
383 |
+
self._download_video_from_url(source_url, source_path)
|
384 |
+
return self._colorize_from_path(
|
385 |
+
source_path, render_factor=render_factor, post_process=post_process,watermarked=watermarked
|
386 |
+
)
|
387 |
+
|
388 |
+
def colorize_from_file_name(
|
389 |
+
self, file_name: str, render_factor: int = None, watermarked: bool = True, post_process: bool = True,
|
390 |
+
) -> Path:
|
391 |
+
source_path = self.source_folder / file_name
|
392 |
+
return self._colorize_from_path(
|
393 |
+
source_path, render_factor=render_factor, post_process=post_process,watermarked=watermarked
|
394 |
+
)
|
395 |
+
|
396 |
+
def _colorize_from_path(
|
397 |
+
self, source_path: Path, render_factor: int = None, watermarked: bool = True, post_process: bool = True
|
398 |
+
) -> Path:
|
399 |
+
if not source_path.exists():
|
400 |
+
raise Exception(
|
401 |
+
'Video at path specfied, ' + str(source_path) + ' could not be found.'
|
402 |
+
)
|
403 |
+
self._extract_raw_frames(source_path)
|
404 |
+
self._colorize_raw_frames(
|
405 |
+
source_path, render_factor=render_factor,post_process=post_process,watermarked=watermarked
|
406 |
+
)
|
407 |
+
return self._build_video(source_path)
|
408 |
+
|
409 |
+
|
410 |
+
def get_video_colorizer(render_factor: int = 21) -> VideoColorizer:
|
411 |
+
return get_stable_video_colorizer(render_factor=render_factor)
|
412 |
+
|
413 |
+
|
414 |
+
def get_artistic_video_colorizer(
|
415 |
+
root_folder: Path = Path('./'),
|
416 |
+
weights_name: str = 'ColorizeArtistic_gen',
|
417 |
+
results_dir='result_images',
|
418 |
+
render_factor: int = 35
|
419 |
+
) -> VideoColorizer:
|
420 |
+
learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
|
421 |
+
filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
|
422 |
+
vis = ModelImageVisualizer(filtr, results_dir=results_dir)
|
423 |
+
return VideoColorizer(vis)
|
424 |
+
|
425 |
+
|
426 |
+
def get_stable_video_colorizer(
|
427 |
+
root_folder: Path = Path('./'),
|
428 |
+
weights_name: str = 'ColorizeVideo_gen',
|
429 |
+
results_dir='result_images',
|
430 |
+
render_factor: int = 21
|
431 |
+
) -> VideoColorizer:
|
432 |
+
learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
|
433 |
+
filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
|
434 |
+
vis = ModelImageVisualizer(filtr, results_dir=results_dir)
|
435 |
+
return VideoColorizer(vis)
|
436 |
+
|
437 |
+
|
438 |
+
def get_image_colorizer(
|
439 |
+
root_folder: Path = Path('./'), render_factor: int = 35, artistic: bool = True
|
440 |
+
) -> ModelImageVisualizer:
|
441 |
+
if artistic:
|
442 |
+
return get_artistic_image_colorizer(root_folder=root_folder, render_factor=render_factor)
|
443 |
+
else:
|
444 |
+
return get_stable_image_colorizer(root_folder=root_folder, render_factor=render_factor)
|
445 |
+
|
446 |
+
|
447 |
+
def get_stable_image_colorizer(
|
448 |
+
root_folder: Path = Path('./'),
|
449 |
+
weights_name: str = 'ColorizeStable_gen',
|
450 |
+
results_dir='result_images',
|
451 |
+
render_factor: int = 35
|
452 |
+
) -> ModelImageVisualizer:
|
453 |
+
learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
|
454 |
+
filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
|
455 |
+
vis = ModelImageVisualizer(filtr, results_dir=results_dir)
|
456 |
+
return vis
|
457 |
+
|
458 |
+
|
459 |
+
def get_artistic_image_colorizer(
|
460 |
+
root_folder: Path = Path('./'),
|
461 |
+
weights_name: str = 'ColorizeArtistic_gen',
|
462 |
+
results_dir='result_images',
|
463 |
+
render_factor: int = 35
|
464 |
+
) -> ModelImageVisualizer:
|
465 |
+
learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
|
466 |
+
filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
|
467 |
+
vis = ModelImageVisualizer(filtr, results_dir=results_dir)
|
468 |
+
return vis
|
469 |
+
|
470 |
+
|
471 |
+
def show_image_in_notebook(image_path: Path):
|
472 |
+
ipythondisplay.display(ipythonimage(str(image_path)))
|
473 |
+
|
474 |
+
|
475 |
+
def show_video_in_notebook(video_path: Path):
|
476 |
+
video = io.open(video_path, 'r+b').read()
|
477 |
+
encoded = base64.b64encode(video)
|
478 |
+
ipythondisplay.display(
|
479 |
+
HTML(
|
480 |
+
data='''<video alt="test" autoplay
|
481 |
+
loop controls style="height: 400px;">
|
482 |
+
<source src="data:video/mp4;base64,{0}" type="video/mp4" />
|
483 |
+
</video>'''.format(
|
484 |
+
encoded.decode('ascii')
|
485 |
+
)
|
486 |
+
)
|
487 |
+
)
|
DeOldify/environment.yml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: deoldify
|
2 |
+
channels:
|
3 |
+
- fastai
|
4 |
+
- conda-forge
|
5 |
+
- defaults
|
6 |
+
dependencies:
|
7 |
+
- pip
|
8 |
+
- fastai=1.0.60
|
9 |
+
- mkl=2024.0
|
10 |
+
- python=3.10
|
11 |
+
- pytorch::pytorch=1.11.0
|
12 |
+
- pytorch::torchvision
|
13 |
+
- pytorch::torchaudio
|
14 |
+
- tensorboardX
|
15 |
+
- jupyterlab
|
16 |
+
- pillow>=9.0.0
|
17 |
+
- ipywidgets
|
18 |
+
- ffmpeg
|
19 |
+
- pip:
|
20 |
+
- ffmpeg-python
|
21 |
+
- opencv-python>=4.2.0.32
|
22 |
+
- wandb
|
23 |
+
- yt-dlp
|
DeOldify/fastai/LICENSE
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License, Version 2.0 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/
|
2 |
+
|
3 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
4 |
+
|
5 |
+
1. Definitions.
|
6 |
+
|
7 |
+
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
8 |
+
|
9 |
+
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
10 |
+
|
11 |
+
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
12 |
+
|
13 |
+
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
14 |
+
|
15 |
+
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
16 |
+
|
17 |
+
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
18 |
+
|
19 |
+
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
20 |
+
|
21 |
+
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
22 |
+
|
23 |
+
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
24 |
+
|
25 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
26 |
+
|
27 |
+
2. Grant of Copyright License.
|
28 |
+
|
29 |
+
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
30 |
+
|
31 |
+
3. Grant of Patent License.
|
32 |
+
|
33 |
+
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
34 |
+
|
35 |
+
4. Redistribution.
|
36 |
+
|
37 |
+
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
38 |
+
|
39 |
+
You must give any other recipients of the Work or Derivative Works a copy of this License; and You must cause any modified files to carry prominent notices stating that You changed the files; and You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
40 |
+
|
41 |
+
5. Submission of Contributions.
|
42 |
+
|
43 |
+
Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
44 |
+
|
45 |
+
6. Trademarks.
|
46 |
+
|
47 |
+
This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
48 |
+
|
49 |
+
7. Disclaimer of Warranty.
|
50 |
+
|
51 |
+
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
52 |
+
|
53 |
+
8. Limitation of Liability.
|
54 |
+
|
55 |
+
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
56 |
+
|
57 |
+
9. Accepting Warranty or Additional Liability.
|
58 |
+
|
59 |
+
While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
60 |
+
|
DeOldify/fastai/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .version import __version__
|
2 |
+
|
DeOldify/fastai/basic_data.py
ADDED
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"`fastai.data` loads and manages datasets with `DataBunch`"
|
2 |
+
from .torch_core import *
|
3 |
+
from torch.utils.data.dataloader import default_collate
|
4 |
+
|
5 |
+
DatasetType = Enum('DatasetType', 'Train Valid Test Single Fix')
|
6 |
+
__all__ = ['DataBunch', 'DeviceDataLoader', 'DatasetType', 'load_data']
|
7 |
+
|
8 |
+
old_dl_init = torch.utils.data.DataLoader.__init__
|
9 |
+
|
10 |
+
def intercept_args(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
|
11 |
+
num_workers=0, collate_fn=default_collate, pin_memory=True, drop_last=False,
|
12 |
+
timeout=0, worker_init_fn=None):
|
13 |
+
self.init_kwargs = {'batch_size':batch_size, 'shuffle':shuffle, 'sampler':sampler, 'batch_sampler':batch_sampler,
|
14 |
+
'num_workers':num_workers, 'collate_fn':collate_fn, 'pin_memory':pin_memory,
|
15 |
+
'drop_last': drop_last, 'timeout':timeout, 'worker_init_fn':worker_init_fn}
|
16 |
+
old_dl_init(self, dataset, **self.init_kwargs)
|
17 |
+
|
18 |
+
torch.utils.data.DataLoader.__init__ = intercept_args
|
19 |
+
|
20 |
+
def DataLoader___getattr__(dl, k:str)->Any: return getattr(dl.dataset, k)
|
21 |
+
DataLoader.__getattr__ = DataLoader___getattr__
|
22 |
+
|
23 |
+
def DataLoader___setstate__(dl, data:Any): dl.__dict__.update(data)
|
24 |
+
DataLoader.__setstate__ = DataLoader___setstate__
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class DeviceDataLoader():
|
28 |
+
"Bind a `DataLoader` to a `torch.device`."
|
29 |
+
dl: DataLoader
|
30 |
+
device: torch.device
|
31 |
+
tfms: List[Callable]=None
|
32 |
+
collate_fn: Callable=data_collate
|
33 |
+
def __post_init__(self):
|
34 |
+
self.dl.collate_fn=self.collate_fn
|
35 |
+
self.tfms = listify(self.tfms)
|
36 |
+
|
37 |
+
def __len__(self)->int: return len(self.dl)
|
38 |
+
def __getattr__(self,k:str)->Any: return getattr(self.dl, k)
|
39 |
+
def __setstate__(self,data:Any): self.__dict__.update(data)
|
40 |
+
|
41 |
+
@property
|
42 |
+
def batch_size(self): return self.dl.batch_size
|
43 |
+
@batch_size.setter
|
44 |
+
def batch_size(self,v):
|
45 |
+
new_kwargs = {**self.dl.init_kwargs, 'batch_size':v, 'collate_fn':self.collate_fn}
|
46 |
+
self.dl = self.dl.__class__(self.dl.dataset, **new_kwargs)
|
47 |
+
if hasattr(self.dl.dataset, 'bs'): self.dl.dataset.bs = v
|
48 |
+
|
49 |
+
@property
|
50 |
+
def num_workers(self): return self.dl.num_workers
|
51 |
+
@num_workers.setter
|
52 |
+
def num_workers(self,v): self.dl.num_workers = v
|
53 |
+
|
54 |
+
def add_tfm(self,tfm:Callable)->None:
|
55 |
+
"Add `tfm` to `self.tfms`."
|
56 |
+
self.tfms.append(tfm)
|
57 |
+
def remove_tfm(self,tfm:Callable)->None:
|
58 |
+
"Remove `tfm` from `self.tfms`."
|
59 |
+
if tfm in self.tfms: self.tfms.remove(tfm)
|
60 |
+
|
61 |
+
def new(self, **kwargs):
|
62 |
+
"Create a new copy of `self` with `kwargs` replacing current values."
|
63 |
+
new_kwargs = {**self.dl.init_kwargs, **kwargs}
|
64 |
+
return DeviceDataLoader(self.dl.__class__(self.dl.dataset, **new_kwargs), self.device, self.tfms,
|
65 |
+
self.collate_fn)
|
66 |
+
|
67 |
+
def proc_batch(self,b:Tensor)->Tensor:
|
68 |
+
"Process batch `b` of `TensorImage`."
|
69 |
+
b = to_device(b, self.device)
|
70 |
+
for f in listify(self.tfms): b = f(b)
|
71 |
+
return b
|
72 |
+
|
73 |
+
def __iter__(self):
|
74 |
+
"Process and returns items from `DataLoader`."
|
75 |
+
for b in self.dl: yield self.proc_batch(b)
|
76 |
+
|
77 |
+
@classmethod
|
78 |
+
def create(cls, dataset:Dataset, bs:int=64, shuffle:bool=False, device:torch.device=defaults.device,
|
79 |
+
tfms:Collection[Callable]=tfms, num_workers:int=defaults.cpus, collate_fn:Callable=data_collate, **kwargs:Any):
|
80 |
+
"Create DeviceDataLoader from `dataset` with `bs` and `shuffle`: process using `num_workers`."
|
81 |
+
return cls(DataLoader(dataset, batch_size=bs, shuffle=shuffle, num_workers=num_workers, **kwargs),
|
82 |
+
device=device, tfms=tfms, collate_fn=collate_fn)
|
83 |
+
|
84 |
+
class DataBunch():
|
85 |
+
"Bind `train_dl`,`valid_dl` and `test_dl` in a data object."
|
86 |
+
|
87 |
+
def __init__(self, train_dl:DataLoader, valid_dl:DataLoader, fix_dl:DataLoader=None, test_dl:Optional[DataLoader]=None,
|
88 |
+
device:torch.device=None, dl_tfms:Optional[Collection[Callable]]=None, path:PathOrStr='.',
|
89 |
+
collate_fn:Callable=data_collate, no_check:bool=False):
|
90 |
+
self.dl_tfms = listify(dl_tfms)
|
91 |
+
self.device = defaults.device if device is None else device
|
92 |
+
assert not isinstance(train_dl,DeviceDataLoader)
|
93 |
+
def _create_dl(dl, **kwargs):
|
94 |
+
if dl is None: return None
|
95 |
+
return DeviceDataLoader(dl, self.device, self.dl_tfms, collate_fn, **kwargs)
|
96 |
+
self.train_dl,self.valid_dl,self.fix_dl,self.test_dl = map(_create_dl, [train_dl,valid_dl,fix_dl,test_dl])
|
97 |
+
if fix_dl is None: self.fix_dl = self.train_dl.new(shuffle=False, drop_last=False)
|
98 |
+
self.single_dl = _create_dl(DataLoader(valid_dl.dataset, batch_size=1, num_workers=0))
|
99 |
+
self.path = Path(path)
|
100 |
+
if not no_check: self.sanity_check()
|
101 |
+
|
102 |
+
def __repr__(self)->str:
|
103 |
+
return f'{self.__class__.__name__};\n\nTrain: {self.train_ds};\n\nValid: {self.valid_ds};\n\nTest: {self.test_ds}'
|
104 |
+
|
105 |
+
@staticmethod
|
106 |
+
def _init_ds(train_ds:Dataset, valid_ds:Dataset, test_ds:Optional[Dataset]=None):
|
107 |
+
# train_ds, but without training tfms
|
108 |
+
fix_ds = valid_ds.new(train_ds.x, train_ds.y) if hasattr(valid_ds,'new') else train_ds
|
109 |
+
return [o for o in (train_ds,valid_ds,fix_ds,test_ds) if o is not None]
|
110 |
+
|
111 |
+
@classmethod
|
112 |
+
def create(cls, train_ds:Dataset, valid_ds:Dataset, test_ds:Optional[Dataset]=None, path:PathOrStr='.', bs:int=64,
|
113 |
+
val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None,
|
114 |
+
device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False, **dl_kwargs)->'DataBunch':
|
115 |
+
"Create a `DataBunch` from `train_ds`, `valid_ds` and maybe `test_ds` with a batch size of `bs`. Passes `**dl_kwargs` to `DataLoader()`"
|
116 |
+
datasets = cls._init_ds(train_ds, valid_ds, test_ds)
|
117 |
+
val_bs = ifnone(val_bs, bs)
|
118 |
+
dls = [DataLoader(d, b, shuffle=s, drop_last=s, num_workers=num_workers, **dl_kwargs) for d,b,s in
|
119 |
+
zip(datasets, (bs,val_bs,val_bs,val_bs), (True,False,False,False)) if d is not None]
|
120 |
+
return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)
|
121 |
+
|
122 |
+
def __getattr__(self,k:int)->Any: return getattr(self.train_dl, k)
|
123 |
+
def __setstate__(self,data:Any): self.__dict__.update(data)
|
124 |
+
|
125 |
+
def dl(self, ds_type:DatasetType=DatasetType.Valid)->DeviceDataLoader:
|
126 |
+
"Returns appropriate `Dataset` for validation, training, or test (`ds_type`)."
|
127 |
+
#TODO: refactor
|
128 |
+
return (self.train_dl if ds_type == DatasetType.Train else
|
129 |
+
self.test_dl if ds_type == DatasetType.Test else
|
130 |
+
self.valid_dl if ds_type == DatasetType.Valid else
|
131 |
+
self.single_dl if ds_type == DatasetType.Single else
|
132 |
+
self.fix_dl)
|
133 |
+
|
134 |
+
@property
|
135 |
+
def dls(self)->List[DeviceDataLoader]:
|
136 |
+
"Returns a list of all DeviceDataLoaders. If you need a specific DeviceDataLoader, access via the relevant property (`train_dl`, `valid_dl`, etc) as the index of DLs in this list is not guaranteed to remain constant."
|
137 |
+
res = [self.train_dl, self.fix_dl, self.single_dl]
|
138 |
+
# Preserve the original ordering of Train, Valid, Fix, Single, Test Data Loaders
|
139 |
+
# (Unknown/not verified as of 1.0.47 whether there are other methods explicitly using DLs their list index)
|
140 |
+
if self.valid_dl: res.insert(1, self.valid_dl)
|
141 |
+
return res if not self.test_dl else res + [self.test_dl]
|
142 |
+
|
143 |
+
def add_tfm(self,tfm:Callable)->None:
|
144 |
+
for dl in self.dls: dl.add_tfm(tfm)
|
145 |
+
|
146 |
+
def remove_tfm(self,tfm:Callable)->None:
|
147 |
+
for dl in self.dls: dl.remove_tfm(tfm)
|
148 |
+
|
149 |
+
def save(self, file:PathLikeOrBinaryStream= 'data_save.pkl')->None:
|
150 |
+
"Save the `DataBunch` in `self.path/file`. `file` can be file-like (file or buffer)"
|
151 |
+
if not getattr(self, 'label_list', False):
|
152 |
+
warn("Serializing the `DataBunch` only works when you created it using the data block API.")
|
153 |
+
return
|
154 |
+
try_save(self.label_list, self.path, file)
|
155 |
+
|
156 |
+
def add_test(self, items:Iterator, label:Any=None, tfms=None, tfm_y=None)->None:
|
157 |
+
"Add the `items` as a test set. Pass along `label` otherwise label them with `EmptyLabel`."
|
158 |
+
self.label_list.add_test(items, label=label, tfms=tfms, tfm_y=tfm_y)
|
159 |
+
vdl = self.valid_dl
|
160 |
+
dl = DataLoader(self.label_list.test, vdl.batch_size, shuffle=False, drop_last=False, num_workers=vdl.num_workers)
|
161 |
+
self.test_dl = DeviceDataLoader(dl, vdl.device, vdl.tfms, vdl.collate_fn)
|
162 |
+
|
163 |
+
def one_batch(self, ds_type:DatasetType=DatasetType.Train, detach:bool=True, denorm:bool=True, cpu:bool=True)->Collection[Tensor]:
|
164 |
+
"Get one batch from the data loader of `ds_type`. Optionally `detach` and `denorm`."
|
165 |
+
dl = self.dl(ds_type)
|
166 |
+
w = self.num_workers
|
167 |
+
self.num_workers = 0
|
168 |
+
try: x,y = next(iter(dl))
|
169 |
+
finally: self.num_workers = w
|
170 |
+
if detach: x,y = to_detach(x,cpu=cpu),to_detach(y,cpu=cpu)
|
171 |
+
norm = getattr(self,'norm',False)
|
172 |
+
if denorm and norm:
|
173 |
+
x = self.denorm(x)
|
174 |
+
if norm.keywords.get('do_y',False): y = self.denorm(y, do_x=True)
|
175 |
+
return x,y
|
176 |
+
|
177 |
+
def one_item(self, item, detach:bool=False, denorm:bool=False, cpu:bool=False):
|
178 |
+
"Get `item` into a batch. Optionally `detach` and `denorm`."
|
179 |
+
ds = self.single_ds
|
180 |
+
with ds.set_item(item):
|
181 |
+
return self.one_batch(ds_type=DatasetType.Single, detach=detach, denorm=denorm, cpu=cpu)
|
182 |
+
|
183 |
+
def show_batch(self, rows:int=5, ds_type:DatasetType=DatasetType.Train, reverse:bool=False, **kwargs)->None:
|
184 |
+
"Show a batch of data in `ds_type` on a few `rows`."
|
185 |
+
x,y = self.one_batch(ds_type, True, True)
|
186 |
+
if reverse: x,y = x.flip(0),y.flip(0)
|
187 |
+
n_items = rows **2 if self.train_ds.x._square_show else rows
|
188 |
+
if self.dl(ds_type).batch_size < n_items: n_items = self.dl(ds_type).batch_size
|
189 |
+
xs = [self.train_ds.x.reconstruct(grab_idx(x, i)) for i in range(n_items)]
|
190 |
+
#TODO: get rid of has_arg if possible
|
191 |
+
if has_arg(self.train_ds.y.reconstruct, 'x'):
|
192 |
+
ys = [self.train_ds.y.reconstruct(grab_idx(y, i), x=x) for i,x in enumerate(xs)]
|
193 |
+
else : ys = [self.train_ds.y.reconstruct(grab_idx(y, i)) for i in range(n_items)]
|
194 |
+
self.train_ds.x.show_xys(xs, ys, **kwargs)
|
195 |
+
|
196 |
+
def export(self, file:PathLikeOrBinaryStream='export.pkl'):
|
197 |
+
"Export the minimal state of `self` for inference in `self.path/file`. `file` can be file-like (file or buffer)"
|
198 |
+
xtra = dict(normalize=self.norm.keywords) if getattr(self, 'norm', False) else {}
|
199 |
+
try_save(self.valid_ds.get_state(**xtra), self.path, file)
|
200 |
+
|
201 |
+
def _grab_dataset(self, dl:DataLoader):
|
202 |
+
ds = dl.dl.dataset
|
203 |
+
while hasattr(ds, 'dataset'): ds = ds.dataset
|
204 |
+
return ds
|
205 |
+
|
206 |
+
@property
|
207 |
+
def train_ds(self)->Dataset: return self._grab_dataset(self.train_dl)
|
208 |
+
@property
|
209 |
+
def valid_ds(self)->Dataset: return self._grab_dataset(self.valid_dl)
|
210 |
+
@property
|
211 |
+
def single_ds(self)->Dataset: return self._grab_dataset(self.single_dl)
|
212 |
+
@property
|
213 |
+
def loss_func(self)->OptLossFunc:
|
214 |
+
return getattr(self.train_ds.y, 'loss_func', F.nll_loss) if hasattr(self.train_ds, 'y') else F.nll_loss
|
215 |
+
|
216 |
+
@property
|
217 |
+
def test_ds(self)->Dataset:
|
218 |
+
return self._grab_dataset(self.test_dl) if self.test_dl is not None else None
|
219 |
+
|
220 |
+
@property
|
221 |
+
def empty_val(self)->bool:
|
222 |
+
if not hasattr(self, 'valid_dl') or self.valid_dl is None: return True
|
223 |
+
if hasattr(self.valid_ds, 'items') and len(self.valid_ds.items) == 0: return True
|
224 |
+
return (len(self.valid_ds) == 0)
|
225 |
+
|
226 |
+
@property
|
227 |
+
def is_empty(self)->bool:
|
228 |
+
return not ((self.train_dl and len(self.train_ds.items) != 0) or
|
229 |
+
(self.valid_dl and len(self.valid_ds.items) != 0) or
|
230 |
+
(self.test_dl and len(self.test_ds.items) != 0))
|
231 |
+
|
232 |
+
@property
|
233 |
+
def batch_size(self): return self.train_dl.batch_size
|
234 |
+
@batch_size.setter
|
235 |
+
def batch_size(self,v):
|
236 |
+
self.train_dl.batch_size,self.valid_dl.batch_size = v,v
|
237 |
+
if self.test_dl is not None: self.test_dl.batch_size = v
|
238 |
+
|
239 |
+
def sanity_check(self):
|
240 |
+
"Check the underlying data in the training set can be properly loaded."
|
241 |
+
final_message = "You can deactivate this warning by passing `no_check=True`."
|
242 |
+
if not hasattr(self.train_ds, 'items') or len(self.train_ds.items) == 0 or not hasattr(self.train_dl, 'batch_sampler'): return
|
243 |
+
if len(self.train_dl) == 0:
|
244 |
+
warn(f"""Your training dataloader is empty, you have only {len(self.train_dl.dataset)} items in your training set.
|
245 |
+
Your batch size is {self.train_dl.batch_size}, you should lower it.""")
|
246 |
+
print(final_message)
|
247 |
+
return
|
248 |
+
idx = next(iter(self.train_dl.batch_sampler))
|
249 |
+
samples,fails = [],[]
|
250 |
+
for i in idx:
|
251 |
+
try: samples.append(self.train_dl.dataset[i])
|
252 |
+
except: fails.append(i)
|
253 |
+
if len(fails) > 0:
|
254 |
+
warn_msg = "There seems to be something wrong with your dataset, for example, in the first batch can't access"
|
255 |
+
if len(fails) == len(idx):
|
256 |
+
warn_msg += f" any element of self.train_ds.\nTried: {show_some(idx)}"
|
257 |
+
else:
|
258 |
+
warn_msg += f" these elements in self.train_ds: {show_some(fails)}"
|
259 |
+
warn(warn_msg)
|
260 |
+
print(final_message)
|
261 |
+
return
|
262 |
+
try: batch = self.collate_fn(samples)
|
263 |
+
except:
|
264 |
+
message = "It's not possible to collate samples of your dataset together in a batch."
|
265 |
+
try:
|
266 |
+
shapes = [[o[i].data.shape for o in samples] for i in range(2)]
|
267 |
+
message += f'\nShapes of the inputs/targets:\n{shapes}'
|
268 |
+
except: pass
|
269 |
+
warn(message)
|
270 |
+
print(final_message)
|
271 |
+
|
272 |
+
def load_data(path:PathOrStr, file:PathLikeOrBinaryStream='data_save.pkl', bs:int=64, val_bs:int=None, num_workers:int=defaults.cpus,
|
273 |
+
dl_tfms:Optional[Collection[Callable]]=None, device:torch.device=None, collate_fn:Callable=data_collate,
|
274 |
+
no_check:bool=False, **kwargs)->DataBunch:
|
275 |
+
"Load a saved `DataBunch` from `path/file`. `file` can be file-like (file or buffer)"
|
276 |
+
source = Path(path)/file if is_pathlike(file) else file
|
277 |
+
ll = torch.load(source, map_location='cpu') if defaults.device == torch.device('cpu') else torch.load(source)
|
278 |
+
return ll.databunch(path=path, bs=bs, val_bs=val_bs, num_workers=num_workers, dl_tfms=dl_tfms, device=device,
|
279 |
+
collate_fn=collate_fn, no_check=no_check, **kwargs)
|
DeOldify/fastai/basic_train.py
ADDED
@@ -0,0 +1,629 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"Provides basic training and validation with `Learner`"
|
2 |
+
from .torch_core import *
|
3 |
+
from .basic_data import *
|
4 |
+
from .callback import *
|
5 |
+
from .data_block import *
|
6 |
+
from .utils.ipython import gpu_mem_restore
|
7 |
+
import inspect
|
8 |
+
from fastprogress.fastprogress import format_time, IN_NOTEBOOK
|
9 |
+
from time import time
|
10 |
+
from fastai.sixel import plot_sixel
|
11 |
+
|
12 |
+
__all__ = ['Learner', 'LearnerCallback', 'Recorder', 'RecordOnCPU', 'fit', 'loss_batch', 'train_epoch', 'validate',
|
13 |
+
'get_preds', 'load_learner']
|
14 |
+
|
15 |
+
defaults.lr = slice(3e-3)
|
16 |
+
defaults.wd = 1e-2
|
17 |
+
defaults.extra_callbacks = None
|
18 |
+
defaults.extra_callback_fns = None
|
19 |
+
|
20 |
+
def loss_batch(model:nn.Module, xb:Tensor, yb:Tensor, loss_func:OptLossFunc=None, opt:OptOptimizer=None,
|
21 |
+
cb_handler:Optional[CallbackHandler]=None, count:[int]=[1], batch_multiplier:int=1)->Tuple[Union[Tensor,int,float,str]]:
|
22 |
+
"Calculate loss and metrics for a batch, call out to callbacks as necessary."
|
23 |
+
cb_handler = ifnone(cb_handler, CallbackHandler())
|
24 |
+
if not is_listy(xb): xb = [xb]
|
25 |
+
if not is_listy(yb): yb = [yb]
|
26 |
+
out = model(*xb)
|
27 |
+
|
28 |
+
if not loss_func: return to_detach(out), yb[0].detach()
|
29 |
+
out = cb_handler.on_loss_begin(out)
|
30 |
+
loss = loss_func(out, *yb)/batch_multiplier
|
31 |
+
count[0]-=1
|
32 |
+
|
33 |
+
if opt is not None:
|
34 |
+
loss,skip_bwd = cb_handler.on_backward_begin(loss)
|
35 |
+
if not skip_bwd: loss.backward()
|
36 |
+
if count[0] == 0:
|
37 |
+
if not cb_handler.on_backward_end(): opt.step()
|
38 |
+
if not cb_handler.on_step_end(): opt.zero_grad()
|
39 |
+
count[0] = batch_multiplier
|
40 |
+
|
41 |
+
return loss.detach().cpu()
|
42 |
+
|
43 |
+
def get_preds(model:nn.Module, dl:DataLoader, pbar:Optional[PBar]=None, cb_handler:Optional[CallbackHandler]=None,
|
44 |
+
activ:nn.Module=None, loss_func:OptLossFunc=None, n_batch:Optional[int]=None) -> List[Tensor]:
|
45 |
+
"Tuple of predictions and targets, and optional losses (if `loss_func`) using `dl`, max batches `n_batch`."
|
46 |
+
res = [torch.cat(o).cpu() for o in
|
47 |
+
zip(*validate(model, dl, cb_handler=cb_handler, pbar=pbar, average=False, n_batch=n_batch))]
|
48 |
+
if loss_func is not None:
|
49 |
+
with NoneReduceOnCPU(loss_func) as lf: res.append(lf(res[0], res[1]))
|
50 |
+
if activ is not None: res[0] = activ(res[0])
|
51 |
+
return res
|
52 |
+
|
53 |
+
def validate(model:nn.Module, dl:DataLoader, loss_func:OptLossFunc=None, cb_handler:Optional[CallbackHandler]=None,
|
54 |
+
pbar:Optional[PBar]=None, average=True, n_batch:Optional[int]=None)->Iterator[Tuple[Union[Tensor,int],...]]:
|
55 |
+
"Calculate `loss_func` of `model` on `dl` in evaluation mode."
|
56 |
+
model.eval()
|
57 |
+
with torch.no_grad():
|
58 |
+
val_losses,nums = [],[]
|
59 |
+
if cb_handler: cb_handler.set_dl(dl)
|
60 |
+
for xb,yb in progress_bar(dl, parent=pbar, leave=(pbar is not None)):
|
61 |
+
if cb_handler: xb, yb = cb_handler.on_batch_begin(xb, yb, train=False)
|
62 |
+
val_loss = loss_batch(model, xb, yb, loss_func, cb_handler=cb_handler)
|
63 |
+
val_losses.append(val_loss)
|
64 |
+
if not is_listy(yb): yb = [yb]
|
65 |
+
nums.append(first_el(yb).shape[0])
|
66 |
+
if cb_handler and cb_handler.on_batch_end(val_losses[-1]): break
|
67 |
+
if n_batch and (len(nums)>=n_batch): break
|
68 |
+
nums = np.array(nums, dtype=np.float32)
|
69 |
+
if average: return (to_np(torch.stack(val_losses)) * nums).sum() / nums.sum()
|
70 |
+
else: return val_losses
|
71 |
+
|
72 |
+
def train_epoch(model:nn.Module, dl:DataLoader, opt:optim.Optimizer, loss_func:LossFunction)->None:
|
73 |
+
"Simple training of `model` for 1 epoch of `dl` using optim `opt` and loss function `loss_func`."
|
74 |
+
model.train()
|
75 |
+
for xb,yb in dl:
|
76 |
+
loss = loss_func(model(xb), yb)
|
77 |
+
loss.backward()
|
78 |
+
opt.step()
|
79 |
+
opt.zero_grad()
|
80 |
+
|
81 |
+
@dataclass
|
82 |
+
class BasicLearner():
|
83 |
+
model:nn.Module
|
84 |
+
loss_func:LossFunction
|
85 |
+
opt:optim.Optimizer
|
86 |
+
data:DataBunch
|
87 |
+
|
88 |
+
def fit(epochs:int, learn:BasicLearner, callbacks:Optional[CallbackList]=None, metrics:OptMetrics=None, batch_multiplier:int=1)->None:
|
89 |
+
"Fit the `model` on `data` and learn using `loss_func` and `opt`."
|
90 |
+
assert len(learn.data.train_dl) != 0, f"""Your training dataloader is empty, can't train a model.
|
91 |
+
Use a smaller batch size (batch size={learn.data.train_dl.batch_size} for {len(learn.data.train_dl.dataset)} elements)."""
|
92 |
+
cb_handler = CallbackHandler(callbacks, metrics)
|
93 |
+
pbar = master_bar(range(epochs))
|
94 |
+
cb_handler.on_train_begin(epochs, pbar=pbar, metrics=metrics)
|
95 |
+
|
96 |
+
exception=False
|
97 |
+
try:
|
98 |
+
for epoch in pbar:
|
99 |
+
learn.model.train()
|
100 |
+
cb_handler.set_dl(learn.data.train_dl)
|
101 |
+
cb_handler.on_epoch_begin()
|
102 |
+
count = [batch_multiplier]
|
103 |
+
for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
|
104 |
+
xb, yb = cb_handler.on_batch_begin(xb, yb)
|
105 |
+
loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler, count=count, batch_multiplier=batch_multiplier)
|
106 |
+
if cb_handler.on_batch_end(loss): break
|
107 |
+
|
108 |
+
if not cb_handler.skip_validate and not learn.data.empty_val:
|
109 |
+
val_loss = validate(learn.model, learn.data.valid_dl, loss_func=learn.loss_func,
|
110 |
+
cb_handler=cb_handler, pbar=pbar)
|
111 |
+
else: val_loss=None
|
112 |
+
if cb_handler.on_epoch_end(val_loss): break
|
113 |
+
except Exception as e:
|
114 |
+
exception = e
|
115 |
+
raise
|
116 |
+
finally: cb_handler.on_train_end(exception)
|
117 |
+
|
118 |
+
loss_func_name2activ = {'cross_entropy_loss': F.softmax, 'nll_loss': torch.exp, 'poisson_nll_loss': torch.exp,
|
119 |
+
'kl_div_loss': torch.exp, 'bce_with_logits_loss': torch.sigmoid, 'cross_entropy': F.softmax,
|
120 |
+
'kl_div': torch.exp, 'binary_cross_entropy_with_logits': torch.sigmoid,
|
121 |
+
}
|
122 |
+
|
123 |
+
def _loss_func_name2activ(name:str, axis:int=-1):
|
124 |
+
res = loss_func_name2activ[name]
|
125 |
+
if res == F.softmax: res = partial(F.softmax, dim=axis)
|
126 |
+
return res
|
127 |
+
|
128 |
+
def _loss_func2activ(loss_func):
|
129 |
+
if getattr(loss_func,'keywords',None):
|
130 |
+
if not loss_func.keywords.get('log_input', True): return
|
131 |
+
axis = getattr(loss_func, 'axis', -1)
|
132 |
+
# flattened loss
|
133 |
+
loss_func = getattr(loss_func, 'func', loss_func)
|
134 |
+
# could have a partial inside flattened loss! Duplicate on purpose.
|
135 |
+
loss_func = getattr(loss_func, 'func', loss_func)
|
136 |
+
cls_name = camel2snake(loss_func.__class__.__name__)
|
137 |
+
if cls_name == 'mix_up_loss':
|
138 |
+
loss_func = loss_func.crit
|
139 |
+
cls_name = camel2snake(loss_func.__class__.__name__)
|
140 |
+
if cls_name in loss_func_name2activ:
|
141 |
+
if cls_name == 'poisson_nll_loss' and (not getattr(loss_func, 'log_input', True)): return
|
142 |
+
return _loss_func_name2activ(cls_name, axis)
|
143 |
+
if getattr(loss_func,'__name__','') in loss_func_name2activ:
|
144 |
+
return _loss_func_name2activ(loss_func.__name__, axis)
|
145 |
+
return noop
|
146 |
+
|
147 |
+
@dataclass
|
148 |
+
class Learner():
|
149 |
+
"Trainer for `model` using `data` to minimize `loss_func` with optimizer `opt_func`."
|
150 |
+
data:DataBunch
|
151 |
+
model:nn.Module
|
152 |
+
opt_func:Callable=AdamW
|
153 |
+
loss_func:Callable=None
|
154 |
+
metrics:Collection[Callable]=None
|
155 |
+
true_wd:bool=True
|
156 |
+
bn_wd:bool=True
|
157 |
+
wd:Floats=defaults.wd
|
158 |
+
train_bn:bool=True
|
159 |
+
path:str = None
|
160 |
+
model_dir:PathOrStr = 'models'
|
161 |
+
callback_fns:Collection[Callable]=None
|
162 |
+
callbacks:Collection[Callback]=field(default_factory=list)
|
163 |
+
layer_groups:Collection[nn.Module]=None
|
164 |
+
add_time:bool=True
|
165 |
+
silent:bool=None
|
166 |
+
def __post_init__(self)->None:
|
167 |
+
"Setup path,metrics, callbacks and ensure model directory exists."
|
168 |
+
self.path = Path(ifnone(self.path, self.data.path))
|
169 |
+
self.model = self.model.to(self.data.device)
|
170 |
+
self.loss_func = self.loss_func or self.data.loss_func
|
171 |
+
self.metrics=listify(self.metrics)
|
172 |
+
if not self.layer_groups: self.layer_groups = [nn.Sequential(*flatten_model(self.model))]
|
173 |
+
self.callbacks = listify(self.callbacks)
|
174 |
+
if self.silent is None: self.silent = defaults.silent
|
175 |
+
self.callback_fns = [partial(Recorder, add_time=self.add_time, silent=self.silent)] + listify(self.callback_fns)
|
176 |
+
|
177 |
+
def init(self, init): apply_init(self.model, init)
|
178 |
+
|
179 |
+
def _test_writeable_path(self):
|
180 |
+
path = self.path/self.model_dir
|
181 |
+
try:
|
182 |
+
path.mkdir(parents=True, exist_ok=True)
|
183 |
+
tmp_file = get_tmp_file(path)
|
184 |
+
except OSError as e:
|
185 |
+
raise Exception(f"{e}\nCan't write to '{path}', set `learn.model_dir` attribute in Learner to a full libpath path that is writable") from None
|
186 |
+
os.remove(tmp_file)
|
187 |
+
|
188 |
+
def lr_range(self, lr:Union[float,slice])->np.ndarray:
|
189 |
+
"Build differential learning rates from `lr`."
|
190 |
+
if not isinstance(lr,slice): return lr
|
191 |
+
if lr.start: res = even_mults(lr.start, lr.stop, len(self.layer_groups))
|
192 |
+
else: res = [lr.stop/10]*(len(self.layer_groups)-1) + [lr.stop]
|
193 |
+
return np.array(res)
|
194 |
+
|
195 |
+
def fit(self, epochs:int, lr:Union[Floats,slice]=defaults.lr,
|
196 |
+
wd:Floats=None, callbacks:Collection[Callback]=None, batch_multiplier:int=1)->None:
|
197 |
+
"Fit the model on this learner with `lr` learning rate, `wd` weight decay for `epochs` with `callbacks`."
|
198 |
+
lr = self.lr_range(lr)
|
199 |
+
if wd is None: wd = self.wd
|
200 |
+
if not getattr(self, 'opt', False): self.create_opt(lr, wd)
|
201 |
+
else: self.opt.lr,self.opt.wd = lr,wd
|
202 |
+
callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
|
203 |
+
if defaults.extra_callbacks is not None: callbacks += defaults.extra_callbacks
|
204 |
+
fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks, batch_multiplier=batch_multiplier)
|
205 |
+
|
206 |
+
def create_opt(self, lr:Floats, wd:Floats=0.)->None:
|
207 |
+
"Create optimizer with `lr` learning rate and `wd` weight decay."
|
208 |
+
self.opt = OptimWrapper.create(self.opt_func, lr, self.layer_groups, wd=wd, true_wd=self.true_wd, bn_wd=self.bn_wd)
|
209 |
+
|
210 |
+
def split(self, split_on:SplitFuncOrIdxList)->None:
|
211 |
+
"Split the model at `split_on`."
|
212 |
+
if isinstance(split_on,Callable): split_on = split_on(self.model)
|
213 |
+
self.layer_groups = split_model(self.model, split_on)
|
214 |
+
return self
|
215 |
+
|
216 |
+
def freeze_to(self, n:int)->None:
|
217 |
+
"Freeze layers up to layer group `n`."
|
218 |
+
for g in self.layer_groups[:n]:
|
219 |
+
for l in g:
|
220 |
+
if not self.train_bn or not isinstance(l, bn_types): requires_grad(l, False)
|
221 |
+
for g in self.layer_groups[n:]: requires_grad(g, True)
|
222 |
+
self.create_opt(defaults.lr)
|
223 |
+
|
224 |
+
def freeze(self)->None:
|
225 |
+
"Freeze up to last layer group."
|
226 |
+
assert(len(self.layer_groups)>1)
|
227 |
+
self.freeze_to(-1)
|
228 |
+
|
229 |
+
def unfreeze(self):
|
230 |
+
"Unfreeze entire model."
|
231 |
+
self.freeze_to(0)
|
232 |
+
|
233 |
+
def export(self, file:PathLikeOrBinaryStream='export.pkl', destroy=False):
|
234 |
+
"Export the state of the `Learner` in `self.path/file`. `file` can be file-like (file or buffer)"
|
235 |
+
if rank_distrib(): return # don't save if slave proc
|
236 |
+
args = ['opt_func', 'loss_func', 'metrics', 'true_wd', 'bn_wd', 'wd', 'train_bn', 'model_dir', 'callback_fns']
|
237 |
+
state = {a:getattr(self,a) for a in args}
|
238 |
+
state['cb_state'] = {cb.__class__:cb.get_state() for cb in self.callbacks}
|
239 |
+
#layer_groups -> need to find a way
|
240 |
+
#TO SEE: do we save model structure and weights separately?
|
241 |
+
with ModelOnCPU(self.model) as m:
|
242 |
+
state['model'] = m
|
243 |
+
xtra = dict(normalize=self.data.norm.keywords) if getattr(self.data, 'norm', False) else {}
|
244 |
+
state['data'] = self.data.valid_ds.get_state(**xtra)
|
245 |
+
state['cls'] = self.__class__
|
246 |
+
try_save(state, self.path, file)
|
247 |
+
if destroy: self.destroy()
|
248 |
+
|
249 |
+
def save(self, file:PathLikeOrBinaryStream=None, return_path:bool=False, with_opt:bool=True):
|
250 |
+
"Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`. `file` can be file-like (file or buffer)"
|
251 |
+
if is_pathlike(file): self._test_writeable_path()
|
252 |
+
if rank_distrib(): return # don't save if slave proc
|
253 |
+
target = self.path/self.model_dir/f'{file}.pth' if is_pathlike(file) else file
|
254 |
+
if not hasattr(self, 'opt'): with_opt=False
|
255 |
+
if not with_opt: state = get_model(self.model).state_dict()
|
256 |
+
else: state = {'model': get_model(self.model).state_dict(), 'opt':self.opt.state_dict()}
|
257 |
+
torch.save(state, target)
|
258 |
+
if return_path: return target
|
259 |
+
|
260 |
+
def dl(self, ds_type:DatasetType=DatasetType.Valid):
|
261 |
+
"Return DataLoader for DatasetType `ds_type`."
|
262 |
+
return self.data.dl(ds_type)
|
263 |
+
|
264 |
+
def load(self, file:PathLikeOrBinaryStream=None, device:torch.device=None, strict:bool=True,
|
265 |
+
with_opt:bool=None, purge:bool=True, remove_module:bool=False):
|
266 |
+
"Load model and optimizer state (if `with_opt`) `file` from `self.model_dir` using `device`. `file` can be file-like (file or buffer)"
|
267 |
+
if purge: self.purge(clear_opt=ifnone(with_opt, False))
|
268 |
+
if device is None: device = self.data.device
|
269 |
+
elif isinstance(device, int): device = torch.device('cuda', device)
|
270 |
+
source = self.path/self.model_dir/f'{file}.pth' if is_pathlike(file) else file
|
271 |
+
state = torch.load(source, map_location=device)
|
272 |
+
if set(state.keys()) == {'model', 'opt'}:
|
273 |
+
model_state = state['model']
|
274 |
+
if remove_module: model_state = remove_module_load(model_state)
|
275 |
+
get_model(self.model).load_state_dict(model_state, strict=strict)
|
276 |
+
if ifnone(with_opt,True):
|
277 |
+
if not hasattr(self, 'opt'): self.create_opt(defaults.lr, self.wd)
|
278 |
+
try: self.opt.load_state_dict(state['opt'])
|
279 |
+
except: pass
|
280 |
+
else:
|
281 |
+
if with_opt: warn("Saved filed doesn't contain an optimizer state.")
|
282 |
+
if remove_module: state = remove_module_load(state)
|
283 |
+
get_model(self.model).load_state_dict(state, strict=strict)
|
284 |
+
del state
|
285 |
+
gc.collect()
|
286 |
+
return self
|
287 |
+
|
288 |
+
def destroy(self):
|
289 |
+
"Free the Learner internals, leaving just an empty shell that consumes no memory"
|
290 |
+
|
291 |
+
class ZombieLearner(Learner):
|
292 |
+
msg = "this object has been destroyed"
|
293 |
+
def __getattr__(self, item): print(ZombieLearner.msg); return None
|
294 |
+
def destroyed(*args, **kwargs): print(ZombieLearner.msg)
|
295 |
+
|
296 |
+
attrs = [k for k in self.__dict__.keys() if not k.startswith("__")]
|
297 |
+
for a in attrs: delattr(self, a)
|
298 |
+
# the instance methods can still be called, but will just give a message
|
299 |
+
methods = [k for k in dir(self) if not k.startswith("__") and inspect.isroutine(getattr(self, k))]
|
300 |
+
for m in methods: setattr(self, m, ZombieLearner.destroyed)
|
301 |
+
self.__class__ = ZombieLearner
|
302 |
+
gc.collect()
|
303 |
+
print("this Learner object self-destroyed - it still exists, but no longer usable")
|
304 |
+
|
305 |
+
def purge(self, clear_opt:bool=True):
|
306 |
+
"Purge the `Learner` of all cached attributes to release some GPU memory."
|
307 |
+
self._test_writeable_path()
|
308 |
+
attrs_all = [k for k in self.__dict__.keys() if not k.startswith("__")]
|
309 |
+
attrs_pkl = ['bn_wd', 'callback_fns', 'layer_groups', 'loss_func', 'metrics', 'model',
|
310 |
+
'model_dir', 'opt_func', 'path', 'train_bn', 'true_wd', 'wd']
|
311 |
+
# +callbacks: get pickled too, but not directly
|
312 |
+
attrs_keep = ['data', 'recorder']
|
313 |
+
attrs_del = list(set(attrs_all) - set(attrs_keep))
|
314 |
+
state = {a:getattr(self, a) for a in attrs_pkl}
|
315 |
+
state['cb_state'] = {cb.__class__:cb.get_state() for cb in self.callbacks}
|
316 |
+
if hasattr(self, 'opt'): state['opt'] = self.opt.get_state()
|
317 |
+
|
318 |
+
tmp_file = get_tmp_file(self.path/self.model_dir)
|
319 |
+
torch.save(state, open(tmp_file, 'wb'))
|
320 |
+
for a in attrs_del: delattr(self, a)
|
321 |
+
gc.collect()
|
322 |
+
state = torch.load(tmp_file)
|
323 |
+
os.remove(tmp_file)
|
324 |
+
|
325 |
+
for a in attrs_pkl: setattr(self, a, state[a])
|
326 |
+
cb_state = state.pop('cb_state')
|
327 |
+
self.callbacks = [load_callback(c,s, self) for c,s in cb_state.items()]
|
328 |
+
if not clear_opt and 'opt' in state:
|
329 |
+
try: self.opt = OptimWrapper.load_with_state_and_layer_group(state['opt'], self.layer_groups)
|
330 |
+
except: warn("Wasn't able to properly load the optimizer state again.")
|
331 |
+
del state
|
332 |
+
gc.collect()
|
333 |
+
return self
|
334 |
+
|
335 |
+
def get_preds(self, ds_type:DatasetType=DatasetType.Valid, with_loss:bool=False, n_batch:Optional[int]=None,
|
336 |
+
pbar:Optional[PBar]=None) -> List[Tensor]:
|
337 |
+
"Return predictions and targets on `ds_type` dataset."
|
338 |
+
lf = self.loss_func if with_loss else None
|
339 |
+
return get_preds(self.model, self.dl(ds_type), cb_handler=CallbackHandler(self.callbacks),
|
340 |
+
activ=_loss_func2activ(self.loss_func), loss_func=lf, n_batch=n_batch, pbar=pbar)
|
341 |
+
|
342 |
+
def pred_batch(self, ds_type:DatasetType=DatasetType.Valid, batch:Tuple=None, reconstruct:bool=False, with_dropout:bool=False) -> List[Tensor]:
|
343 |
+
with torch.no_grad():
|
344 |
+
training = self.model.training
|
345 |
+
self.model.train(False)
|
346 |
+
"Return output of the model on one batch from `ds_type` dataset."
|
347 |
+
if batch is not None: xb,yb = batch
|
348 |
+
else: xb,yb = self.data.one_batch(ds_type, detach=False, denorm=False)
|
349 |
+
cb_handler = CallbackHandler(self.callbacks)
|
350 |
+
xb,yb = cb_handler.on_batch_begin(xb,yb, train=False)
|
351 |
+
if not with_dropout:
|
352 |
+
preds = loss_batch(self.model.eval(), xb, yb, cb_handler=cb_handler)
|
353 |
+
else:
|
354 |
+
preds = loss_batch(self.model.eval().apply(self.apply_dropout), xb, yb, cb_handler=cb_handler)
|
355 |
+
res = _loss_func2activ(self.loss_func)(preds[0])
|
356 |
+
self.model.train(training)
|
357 |
+
if not reconstruct: return res
|
358 |
+
res = res.detach().cpu()
|
359 |
+
ds = self.dl(ds_type).dataset
|
360 |
+
norm = getattr(self.data, 'norm', False)
|
361 |
+
if norm and norm.keywords.get('do_y',False):
|
362 |
+
res = self.data.denorm(res, do_x=True)
|
363 |
+
return [ds.reconstruct(o) for o in res]
|
364 |
+
|
365 |
+
def backward(self, item):
|
366 |
+
"Pass `item` through the model and computes the gradient. Useful if `backward_hooks` are attached."
|
367 |
+
xb,yb = self.data.one_item(item)
|
368 |
+
loss = loss_batch(self.model.eval(), xb, yb, self.loss_func, opt=FakeOptimizer(),
|
369 |
+
cb_handler=CallbackHandler(self.callbacks))
|
370 |
+
return loss
|
371 |
+
|
372 |
+
def predict(self, item:ItemBase, return_x:bool=False, batch_first:bool=True, with_dropout:bool=False, **kwargs):
|
373 |
+
"Return predicted class, label and probabilities for `item`."
|
374 |
+
batch = self.data.one_item(item)
|
375 |
+
res = self.pred_batch(batch=batch, with_dropout=with_dropout)
|
376 |
+
raw_pred,x = grab_idx(res,0,batch_first=batch_first),batch[0]
|
377 |
+
norm = getattr(self.data,'norm',False)
|
378 |
+
if norm:
|
379 |
+
x = self.data.denorm(x)
|
380 |
+
if norm.keywords.get('do_y',False): raw_pred = self.data.denorm(raw_pred)
|
381 |
+
ds = self.data.single_ds
|
382 |
+
pred = ds.y.analyze_pred(raw_pred, **kwargs)
|
383 |
+
x = ds.x.reconstruct(grab_idx(x, 0))
|
384 |
+
y = ds.y.reconstruct(pred, x) if has_arg(ds.y.reconstruct, 'x') else ds.y.reconstruct(pred)
|
385 |
+
return (x, y, pred, raw_pred) if return_x else (y, pred, raw_pred)
|
386 |
+
|
387 |
+
def validate(self, dl=None, callbacks=None, metrics=None):
|
388 |
+
"Validate on `dl` with potential `callbacks` and `metrics`."
|
389 |
+
dl = ifnone(dl, self.data.valid_dl)
|
390 |
+
metrics = ifnone(metrics, self.metrics)
|
391 |
+
cb_handler = CallbackHandler(self.callbacks + ifnone(callbacks, []), metrics)
|
392 |
+
cb_handler.on_epoch_begin()
|
393 |
+
val_metrics = validate(self.model, dl, self.loss_func, cb_handler)
|
394 |
+
cb_handler.on_epoch_end(val_metrics)
|
395 |
+
return cb_handler.state_dict['last_metrics']
|
396 |
+
|
397 |
+
def show_results(self, ds_type=DatasetType.Valid, rows:int=5, **kwargs):
|
398 |
+
"Show `rows` result of predictions on `ds_type` dataset."
|
399 |
+
#TODO: get read of has_arg x and split_kwargs_by_func if possible
|
400 |
+
#TODO: simplify this and refactor with pred_batch(...reconstruct=True)
|
401 |
+
n_items = rows ** 2 if self.data.train_ds.x._square_show_res else rows
|
402 |
+
if self.dl(ds_type).batch_size < n_items: n_items = self.dl(ds_type).batch_size
|
403 |
+
ds = self.dl(ds_type).dataset
|
404 |
+
self.callbacks.append(RecordOnCPU())
|
405 |
+
preds = self.pred_batch(ds_type)
|
406 |
+
*self.callbacks,rec_cpu = self.callbacks
|
407 |
+
x,y = rec_cpu.input,rec_cpu.target
|
408 |
+
norm = getattr(self.data,'norm',False)
|
409 |
+
if norm:
|
410 |
+
x = self.data.denorm(x)
|
411 |
+
if norm.keywords.get('do_y',False):
|
412 |
+
y = self.data.denorm(y, do_x=True)
|
413 |
+
preds = self.data.denorm(preds, do_x=True)
|
414 |
+
analyze_kwargs,kwargs = split_kwargs_by_func(kwargs, ds.y.analyze_pred)
|
415 |
+
preds = [ds.y.analyze_pred(grab_idx(preds, i), **analyze_kwargs) for i in range(n_items)]
|
416 |
+
xs = [ds.x.reconstruct(grab_idx(x, i)) for i in range(n_items)]
|
417 |
+
if has_arg(ds.y.reconstruct, 'x'):
|
418 |
+
ys = [ds.y.reconstruct(grab_idx(y, i), x=x) for i,x in enumerate(xs)]
|
419 |
+
zs = [ds.y.reconstruct(z, x=x) for z,x in zip(preds,xs)]
|
420 |
+
else :
|
421 |
+
ys = [ds.y.reconstruct(grab_idx(y, i)) for i in range(n_items)]
|
422 |
+
zs = [ds.y.reconstruct(z) for z in preds]
|
423 |
+
ds.x.show_xyzs(xs, ys, zs, **kwargs)
|
424 |
+
|
425 |
+
def apply_dropout(self, m):
|
426 |
+
"If a module contains 'dropout' in it's name, it will be switched to .train() mode."
|
427 |
+
if 'dropout' in m.__class__.__name__.lower(): m.train()
|
428 |
+
|
429 |
+
def predict_with_mc_dropout(self, item:ItemBase, with_dropout:bool=True, n_times=10, **kwargs):
|
430 |
+
"Make predictions with dropout turned on for n_times (default 10)."
|
431 |
+
return [self.predict(item, with_dropout=with_dropout) for _ in range(n_times)]
|
432 |
+
|
433 |
+
class RecordOnCPU(Callback):
|
434 |
+
"Store the `input` and `target` going through the model on the CPU."
|
435 |
+
def on_batch_begin(self, last_input,last_target,**kwargs):
|
436 |
+
self.input,self.target = to_cpu(last_input),to_cpu(last_target)
|
437 |
+
|
438 |
+
class LearnerCallback(Callback):
|
439 |
+
"Base class for creating callbacks for a `Learner`."
|
440 |
+
def __init__(self, learn):
|
441 |
+
self._learn = weakref.ref(learn)
|
442 |
+
self.exclude,self.not_min = ['_learn'],[]
|
443 |
+
setattr(self.learn, self.cb_name, self)
|
444 |
+
|
445 |
+
def __getattr__(self,k): return getattr(self.learn, k)
|
446 |
+
def __setstate__(self,data:Any): self.__dict__.update(data)
|
447 |
+
|
448 |
+
@property
|
449 |
+
def learn(self) -> Learner: return self._learn()
|
450 |
+
@learn.setter
|
451 |
+
def learn(self, learn: Learner) -> None: self._learn = weakref.ref(learn)
|
452 |
+
|
453 |
+
@property
|
454 |
+
def cb_name(self): return camel2snake(self.__class__.__name__)
|
455 |
+
|
456 |
+
class Recorder(LearnerCallback):
|
457 |
+
"A `LearnerCallback` that records epoch, loss, opt and metric data during training."
|
458 |
+
_order=-10
|
459 |
+
def __init__(self, learn:Learner, add_time:bool=True, silent:bool=False):
|
460 |
+
super().__init__(learn)
|
461 |
+
self.opt = self.learn.opt
|
462 |
+
self.train_dl = self.learn.data.train_dl
|
463 |
+
self.no_val,self.silent,self.add_time = False,silent,add_time
|
464 |
+
|
465 |
+
def on_train_begin(self, pbar:PBar, metrics_names:Collection[str], **kwargs:Any)->None:
|
466 |
+
"Initialize recording status at beginning of training."
|
467 |
+
self.pbar = pbar
|
468 |
+
self.names = ['epoch', 'train_loss'] if self.no_val else ['epoch', 'train_loss', 'valid_loss']
|
469 |
+
self.metrics_names = metrics_names
|
470 |
+
if hasattr(self, '_added_met_names'): self.metrics_names += self._added_met_names
|
471 |
+
self.names += self.metrics_names
|
472 |
+
if self.add_time: self.names.append('time')
|
473 |
+
if not self.silent: self.pbar.write(self.names, table=True)
|
474 |
+
self.losses,self.val_losses,self.lrs,self.moms,self.metrics,self.nb_batches = [],[],[],[],[],[]
|
475 |
+
|
476 |
+
def on_epoch_begin(self, **kwargs:Any)->None:
|
477 |
+
if self.add_time: self.start_epoch = time()
|
478 |
+
|
479 |
+
def on_batch_begin(self, train, **kwargs:Any)->None:
|
480 |
+
"Record learning rate and momentum at beginning of batch."
|
481 |
+
if train:
|
482 |
+
self.lrs.append(self.opt.lr)
|
483 |
+
self.moms.append(self.opt.mom)
|
484 |
+
|
485 |
+
def on_backward_begin(self, smooth_loss:Tensor, **kwargs:Any)->None:
|
486 |
+
"Record the loss before any other callback has a chance to modify it."
|
487 |
+
self.losses.append(smooth_loss)
|
488 |
+
if self.pbar is not None and hasattr(self.pbar,'child'):
|
489 |
+
self.pbar.child.comment = f'{smooth_loss:.4f}'
|
490 |
+
|
491 |
+
def on_epoch_end(self, epoch:int, num_batch:int, smooth_loss:Tensor,
|
492 |
+
last_metrics=MetricsList, **kwargs:Any)->bool:
|
493 |
+
"Save epoch info: num_batch, smooth_loss, metrics."
|
494 |
+
self.nb_batches.append(num_batch)
|
495 |
+
if last_metrics is not None: self.val_losses.append(last_metrics[0])
|
496 |
+
else: last_metrics = [] if self.no_val else [None]
|
497 |
+
if len(last_metrics) > 1: self.metrics.append(last_metrics[1:])
|
498 |
+
self.format_stats([epoch, smooth_loss] + last_metrics)
|
499 |
+
|
500 |
+
def format_stats(self, stats:TensorOrNumList)->None:
|
501 |
+
"Format stats before printing."
|
502 |
+
str_stats = []
|
503 |
+
for name,stat in zip(self.names,stats):
|
504 |
+
str_stats.append('#na#' if stat is None else str(stat) if isinstance(stat, int) else f'{stat:.6f}')
|
505 |
+
if self.add_time: str_stats.append(format_time(time() - self.start_epoch))
|
506 |
+
if not self.silent: self.pbar.write(str_stats, table=True)
|
507 |
+
|
508 |
+
def add_metric_names(self, names):
|
509 |
+
"Add `names` to the inner metric names."
|
510 |
+
if hasattr(self, '_added_met_names'): self._added_met_names += names
|
511 |
+
else: self._added_met_names = names
|
512 |
+
|
513 |
+
def plot_lr(self, show_moms=False, skip_start:int=0, skip_end:int=0, return_fig:bool=None)->Optional[plt.Figure]:
|
514 |
+
"Plot learning rate, `show_moms` to include momentum."
|
515 |
+
lrs = self._split_list(self.lrs, skip_start, skip_end)
|
516 |
+
iterations = self._split_list(range_of(self.lrs), skip_start, skip_end)
|
517 |
+
if show_moms:
|
518 |
+
moms = self._split_list(self.moms, skip_start, skip_end)
|
519 |
+
fig, axs = plt.subplots(1,2, figsize=(12,4))
|
520 |
+
axs[0].plot(iterations, lrs)
|
521 |
+
axs[0].set_xlabel('Iterations')
|
522 |
+
axs[0].set_ylabel('Learning Rate')
|
523 |
+
axs[1].plot(iterations, moms)
|
524 |
+
axs[1].set_xlabel('Iterations')
|
525 |
+
axs[1].set_ylabel('Momentum')
|
526 |
+
else:
|
527 |
+
fig, ax = plt.subplots()
|
528 |
+
ax.plot(iterations, lrs)
|
529 |
+
ax.set_xlabel('Iterations')
|
530 |
+
ax.set_ylabel('Learning Rate')
|
531 |
+
if ifnone(return_fig, defaults.return_fig): return fig
|
532 |
+
if not IN_NOTEBOOK: plot_sixel(fig)
|
533 |
+
|
534 |
+
@staticmethod
|
535 |
+
def smoothen_by_spline(xs, ys, **kwargs):
|
536 |
+
xs = np.arange(len(ys))
|
537 |
+
spl = scipy.interpolate.UnivariateSpline(xs, ys, **kwargs)
|
538 |
+
ys = spl(xs)
|
539 |
+
return ys
|
540 |
+
|
541 |
+
def plot(self, skip_start:int=10, skip_end:int=5, suggestion:bool=False, return_fig:bool=None,
|
542 |
+
**kwargs)->Optional[plt.Figure]:
|
543 |
+
"Plot learning rate and losses, trimmed between `skip_start` and `skip_end`. Optionally plot and return min gradient"
|
544 |
+
lrs = self._split_list(self.lrs, skip_start, skip_end)
|
545 |
+
losses = self._split_list(self.losses, skip_start, skip_end)
|
546 |
+
losses = [x.item() for x in losses]
|
547 |
+
if 'k' in kwargs: losses = self.smoothen_by_spline(lrs, losses, **kwargs)
|
548 |
+
fig, ax = plt.subplots(1,1)
|
549 |
+
ax.plot(lrs, losses)
|
550 |
+
ax.set_ylabel("Loss")
|
551 |
+
ax.set_xlabel("Learning Rate")
|
552 |
+
ax.set_xscale('log')
|
553 |
+
ax.xaxis.set_major_formatter(plt.FormatStrFormatter('%.0e'))
|
554 |
+
if suggestion:
|
555 |
+
try: mg = (np.gradient(np.array(losses))).argmin()
|
556 |
+
except:
|
557 |
+
print("Failed to compute the gradients, there might not be enough points.")
|
558 |
+
return
|
559 |
+
print(f"Min numerical gradient: {lrs[mg]:.2E}")
|
560 |
+
ax.plot(lrs[mg],losses[mg],markersize=10,marker='o',color='red')
|
561 |
+
self.min_grad_lr = lrs[mg]
|
562 |
+
ml = np.argmin(losses)
|
563 |
+
print(f"Min loss divided by 10: {lrs[ml]/10:.2E}")
|
564 |
+
if ifnone(return_fig, defaults.return_fig): return fig
|
565 |
+
if not IN_NOTEBOOK: plot_sixel(fig)
|
566 |
+
|
567 |
+
def plot_losses(self, skip_start:int=0, skip_end:int=0, return_fig:bool=None)->Optional[plt.Figure]:
|
568 |
+
"Plot training and validation losses."
|
569 |
+
fig, ax = plt.subplots(1,1)
|
570 |
+
losses = self._split_list(self.losses, skip_start, skip_end)
|
571 |
+
iterations = self._split_list(range_of(self.losses), skip_start, skip_end)
|
572 |
+
ax.plot(iterations, losses, label='Train')
|
573 |
+
val_iter = self._split_list_val(np.cumsum(self.nb_batches), skip_start, skip_end)
|
574 |
+
val_losses = self._split_list_val(self.val_losses, skip_start, skip_end)
|
575 |
+
ax.plot(val_iter, val_losses, label='Validation')
|
576 |
+
ax.set_ylabel('Loss')
|
577 |
+
ax.set_xlabel('Batches processed')
|
578 |
+
ax.legend()
|
579 |
+
if ifnone(return_fig, defaults.return_fig): return fig
|
580 |
+
if not IN_NOTEBOOK: plot_sixel(fig)
|
581 |
+
|
582 |
+
def plot_metrics(self, skip_start:int=0, skip_end:int=0, return_fig:bool=None)->Optional[plt.Figure]:
|
583 |
+
"Plot metrics collected during training."
|
584 |
+
assert len(self.metrics) != 0, "There are no metrics to plot."
|
585 |
+
fig, axes = plt.subplots(len(self.metrics[0]),1,figsize=(6, 4*len(self.metrics[0])))
|
586 |
+
val_iter = self._split_list_val(np.cumsum(self.nb_batches), skip_start, skip_end)
|
587 |
+
axes = axes.flatten() if len(self.metrics[0]) != 1 else [axes]
|
588 |
+
for i, ax in enumerate(axes):
|
589 |
+
values = [met[i] for met in self.metrics]
|
590 |
+
values = self._split_list_val(values, skip_start, skip_end)
|
591 |
+
ax.plot(val_iter, values)
|
592 |
+
ax.set_ylabel(str(self.metrics_names[i]))
|
593 |
+
ax.set_xlabel('Batches processed')
|
594 |
+
if ifnone(return_fig, defaults.return_fig): return fig
|
595 |
+
if not IN_NOTEBOOK: plot_sixel(fig)
|
596 |
+
|
597 |
+
def _split_list(self, vals:Collection[float], skip_start:int, skip_end:int):
|
598 |
+
return vals[skip_start:-skip_end] if skip_end > 0 else vals[skip_start:]
|
599 |
+
|
600 |
+
def _split_list_val(self, vals:Collection[float], skip_start:int, skip_end:int):
|
601 |
+
val_iter = np.cumsum(self.nb_batches)
|
602 |
+
start_val = (val_iter - skip_start >= 0).nonzero()[0].min()
|
603 |
+
end_val = (val_iter[-1] - val_iter - skip_end >= 0).nonzero()[0].max()+1
|
604 |
+
return vals[start_val:end_val] if skip_end > 0 else vals[start_val:]
|
605 |
+
|
606 |
+
class FakeOptimizer():
|
607 |
+
def step(self): pass
|
608 |
+
def zero_grad(self): pass
|
609 |
+
|
610 |
+
def load_callback(class_func, state, learn:Learner):
|
611 |
+
init_kwargs, others = split_kwargs_by_func(state, class_func.__init__)
|
612 |
+
res = class_func(learn, **init_kwargs) if issubclass(class_func, LearnerCallback) else class_func(**init_kwargs)
|
613 |
+
for k,v in others.items(): setattr(res, k, v)
|
614 |
+
return res
|
615 |
+
|
616 |
+
def load_learner(path:PathOrStr, file:PathLikeOrBinaryStream='export.pkl', test:ItemList=None, **db_kwargs):
|
617 |
+
"Load a `Learner` object saved with `export_state` in `path/file` with empty data, optionally add `test` and load on `cpu`. `file` can be file-like (file or buffer)"
|
618 |
+
source = Path(path)/file if is_pathlike(file) else file
|
619 |
+
state = torch.load(source, map_location='cpu') if defaults.device == torch.device('cpu') else torch.load(source)
|
620 |
+
model = state.pop('model')
|
621 |
+
src = LabelLists.load_state(path, state.pop('data'))
|
622 |
+
if test is not None: src.add_test(test)
|
623 |
+
data = src.databunch(**db_kwargs)
|
624 |
+
cb_state = state.pop('cb_state')
|
625 |
+
clas_func = state.pop('cls')
|
626 |
+
res = clas_func(data, model, **state)
|
627 |
+
res.callback_fns = state['callback_fns'] #to avoid duplicates
|
628 |
+
res.callbacks = [load_callback(c,s, res) for c,s in cb_state.items()]
|
629 |
+
return res
|
DeOldify/fastai/basics.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .basic_train import *
|
2 |
+
from .callback import *
|
3 |
+
from .core import *
|
4 |
+
from .basic_data import *
|
5 |
+
from .data_block import *
|
6 |
+
from .layers import *
|
7 |
+
from .metrics import *
|
8 |
+
from .torch_core import *
|
9 |
+
from .train import *
|
10 |
+
from .datasets import *
|
11 |
+
from .version import *
|
12 |
+
from . import callbacks
|
13 |
+
|
14 |
+
"""
|
15 |
+
from . import core,torch_core,basic_data,basic_train,callback,data_block,layers,metrics,train,datasets,callbacks
|
16 |
+
|
17 |
+
__all__ = [o for o in dir(core) if not o.startswith('_')]
|
18 |
+
__all__ += [o for o in dir(torch_core) if not o.startswith('_')]
|
19 |
+
__all__ += [*basic_train.__all__, *callback.__all__, 'core', 'torch_core', 'callbacks',
|
20 |
+
*basic_data.__all__, *data_block.__all__, *layers.__all__, *metrics.__all__,
|
21 |
+
*train.__all__, *datasets.__all__, '__version__']
|
22 |
+
"""
|
23 |
+
|
24 |
+
try: from .gen_doc.nbdoc import doc
|
25 |
+
except: pass # Optional if jupyter is present
|
26 |
+
#__all__.append('doc')
|
27 |
+
|
28 |
+
__all__ = [o for o in dir(sys.modules[__name__]) if not o.startswith('_')] + ['__version__']
|
29 |
+
|
DeOldify/fastai/callback.py
ADDED
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"Callbacks provides extensibility to the `basic_train` loop. See `train` for examples of custom callbacks."
|
2 |
+
from .basic_data import *
|
3 |
+
from .torch_core import *
|
4 |
+
import torch.distributed as dist
|
5 |
+
|
6 |
+
__all__ = ['AverageMetric', 'Callback', 'CallbackHandler', 'OptimWrapper', 'SmoothenValue', 'Scheduler', 'annealing_cos', 'CallbackList',
|
7 |
+
'annealing_exp', 'annealing_linear', 'annealing_no', 'annealing_poly']
|
8 |
+
|
9 |
+
class OptimWrapper():
|
10 |
+
"Basic wrapper around `opt` to simplify hyper-parameters changes."
|
11 |
+
def __init__(self, opt:optim.Optimizer, wd:Floats=0., true_wd:bool=False, bn_wd:bool=True):
|
12 |
+
assert not isinstance(opt, OptimWrapper)
|
13 |
+
self.opt,self.true_wd,self.bn_wd = opt,true_wd,bn_wd
|
14 |
+
self.opt_keys = list(self.opt.param_groups[0].keys())
|
15 |
+
self.opt_keys.remove('params')
|
16 |
+
self.read_defaults()
|
17 |
+
self.wd = wd
|
18 |
+
|
19 |
+
@classmethod
|
20 |
+
def create(cls, opt_func:Union[type,Callable], lr:Union[float,Tuple,List], layer_groups:ModuleList, wd:Floats=0.,
|
21 |
+
true_wd:bool=False, bn_wd:bool=True)->optim.Optimizer:
|
22 |
+
"Create an `optim.Optimizer` from `opt_func` with `lr`. Set lr on `layer_groups`."
|
23 |
+
split_params = split_no_wd_params(layer_groups)
|
24 |
+
opt = opt_func([{'params': p, 'lr':0} for p in split_params])
|
25 |
+
opt = cls(opt, wd=wd, true_wd=true_wd, bn_wd=bn_wd)
|
26 |
+
opt.lr,opt.opt_func = listify(lr, layer_groups),opt_func
|
27 |
+
return opt
|
28 |
+
|
29 |
+
def new(self, layer_groups:Collection[nn.Module], split_no_wd:bool=True):
|
30 |
+
"Create a new `OptimWrapper` from `self` with another `layer_groups` but the same hyper-parameters."
|
31 |
+
opt_func = getattr(self, 'opt_func', self.opt.__class__)
|
32 |
+
res = self.create(opt_func, self.lr, layer_groups, wd=self.wd, true_wd=self.true_wd, bn_wd=self.bn_wd)
|
33 |
+
res.mom,res.beta = self.mom,self.beta
|
34 |
+
return res
|
35 |
+
|
36 |
+
def new_with_params(self, param_groups:Collection[Collection[nn.Parameter]]):
|
37 |
+
"Create a new `OptimWrapper` from `self` with another `layer_groups` but the same hyper-parameters."
|
38 |
+
opt_func = getattr(self, 'opt_func', self.opt.__class__)
|
39 |
+
opt = opt_func([{'params': p, 'lr':0} for p in param_groups])
|
40 |
+
opt = self.__class__(opt, wd=self.wd, true_wd=self.true_wd, bn_wd=self.bn_wd)
|
41 |
+
opt.lr,opt.opt_func,opt.mom,opt.beta = self.lr,opt_func,self.mom,self.beta
|
42 |
+
return opt
|
43 |
+
|
44 |
+
def __repr__(self)->str:
|
45 |
+
return f'OptimWrapper over {repr(self.opt)}.\nTrue weight decay: {self.true_wd}'
|
46 |
+
|
47 |
+
#Pytorch optimizer methods
|
48 |
+
def step(self)->None:
|
49 |
+
"Set weight decay and step optimizer."
|
50 |
+
# weight decay outside of optimizer step (AdamW)
|
51 |
+
if self.true_wd:
|
52 |
+
for lr,wd,pg1,pg2 in zip(self._lr,self._wd,self.opt.param_groups[::2],self.opt.param_groups[1::2]):
|
53 |
+
for p in pg1['params']: p.data.mul_(1 - wd*lr)
|
54 |
+
if self.bn_wd:
|
55 |
+
for p in pg2['params']: p.data.mul_(1 - wd*lr)
|
56 |
+
self.set_val('weight_decay', listify(0, self._wd))
|
57 |
+
self.opt.step()
|
58 |
+
|
59 |
+
def zero_grad(self)->None:
|
60 |
+
"Clear optimizer gradients."
|
61 |
+
self.opt.zero_grad()
|
62 |
+
|
63 |
+
#Passthrough to the inner opt.
|
64 |
+
def __getattr__(self, k:str)->Any: return getattr(self.opt, k, None)
|
65 |
+
def __setstate__(self,data:Any): self.__dict__.update(data)
|
66 |
+
|
67 |
+
def clear(self):
|
68 |
+
"Reset the state of the inner optimizer."
|
69 |
+
sd = self.state_dict()
|
70 |
+
sd['state'] = {}
|
71 |
+
self.load_state_dict(sd)
|
72 |
+
|
73 |
+
@property
|
74 |
+
def n_params(self): return sum([len(pg['params']) for pg in self.opt.param_groups])
|
75 |
+
|
76 |
+
#Hyperparameters as properties
|
77 |
+
@property
|
78 |
+
def lr(self)->float: return self._lr[-1]
|
79 |
+
@lr.setter
|
80 |
+
def lr(self, val:float)->None:
|
81 |
+
self._lr = self.set_val('lr', listify(val, self._lr))
|
82 |
+
|
83 |
+
@property
|
84 |
+
def mom(self)->float:return self._mom[-1]
|
85 |
+
@mom.setter
|
86 |
+
def mom(self, val:float)->None:
|
87 |
+
if 'momentum' in self.opt_keys: self.set_val('momentum', listify(val, self._mom))
|
88 |
+
elif 'betas' in self.opt_keys: self.set_val('betas', (listify(val, self._mom), self._beta))
|
89 |
+
self._mom = listify(val, self._mom)
|
90 |
+
|
91 |
+
@property
|
92 |
+
def beta(self)->float: return None if self._beta is None else self._beta[-1]
|
93 |
+
@beta.setter
|
94 |
+
def beta(self, val:float)->None:
|
95 |
+
"Set beta (or alpha as makes sense for given optimizer)."
|
96 |
+
if val is None: return
|
97 |
+
if 'betas' in self.opt_keys: self.set_val('betas', (self._mom, listify(val, self._beta)))
|
98 |
+
elif 'alpha' in self.opt_keys: self.set_val('alpha', listify(val, self._beta))
|
99 |
+
self._beta = listify(val, self._beta)
|
100 |
+
|
101 |
+
@property
|
102 |
+
def wd(self)->float: return self._wd[-1]
|
103 |
+
@wd.setter
|
104 |
+
def wd(self, val:float)->None:
|
105 |
+
"Set weight decay."
|
106 |
+
if not self.true_wd: self.set_val('weight_decay', listify(val, self._wd), bn_groups=self.bn_wd)
|
107 |
+
self._wd = listify(val, self._wd)
|
108 |
+
|
109 |
+
#Helper functions
|
110 |
+
def read_defaults(self)->None:
|
111 |
+
"Read the values inside the optimizer for the hyper-parameters."
|
112 |
+
self._beta = None
|
113 |
+
if 'lr' in self.opt_keys: self._lr = self.read_val('lr')
|
114 |
+
if 'momentum' in self.opt_keys: self._mom = self.read_val('momentum')
|
115 |
+
if 'alpha' in self.opt_keys: self._beta = self.read_val('alpha')
|
116 |
+
if 'betas' in self.opt_keys: self._mom,self._beta = self.read_val('betas')
|
117 |
+
if 'weight_decay' in self.opt_keys: self._wd = self.read_val('weight_decay')
|
118 |
+
reserved_names = ['params', 'lr', 'momentum', 'alpha', 'betas', 'weight_decay']
|
119 |
+
stat_names = [n for n in self.opt_keys if n not in reserved_names]
|
120 |
+
self._stats = {n:self.read_val(n) for n in stat_names}
|
121 |
+
|
122 |
+
def get_stat(self, name:str)->float:
|
123 |
+
if name in ['lr', 'mom', 'beta', 'wd']: return getattr(self, name)
|
124 |
+
else: return self._stats[name][-1]
|
125 |
+
def set_stat(self, name:str, value:Union[float, Collection[float]])->None:
|
126 |
+
if name in ['lr', 'mom', 'beta', 'wd']: setattr(self, name, value)
|
127 |
+
else:
|
128 |
+
val = listify(value, self._stats[name])
|
129 |
+
self.set_val(name, val)
|
130 |
+
self._stats[name] = val
|
131 |
+
|
132 |
+
def set_val(self, key:str, val:Any, bn_groups:bool=True)->Any:
|
133 |
+
"Set `val` inside the optimizer dictionary at `key`."
|
134 |
+
if is_tuple(val): val = [(v1,v2) for v1,v2 in zip(*val)]
|
135 |
+
for v,pg1,pg2 in zip(val,self.opt.param_groups[::2],self.opt.param_groups[1::2]):
|
136 |
+
pg1[key] = v
|
137 |
+
if bn_groups: pg2[key] = v
|
138 |
+
return val
|
139 |
+
|
140 |
+
def read_val(self, key:str) -> Union[List[float],Tuple[List[float],List[float]]]:
|
141 |
+
"Read a hyperparameter `key` in the optimizer dictionary."
|
142 |
+
val = [pg[key] for pg in self.opt.param_groups[::2]]
|
143 |
+
if is_tuple(val[0]): val = [o[0] for o in val], [o[1] for o in val]
|
144 |
+
return val
|
145 |
+
|
146 |
+
def get_state(self):
|
147 |
+
"Return the inner state minus the layer groups."
|
148 |
+
return {'opt_state':self.opt.state_dict(), 'lr':self._lr, 'wd':self._wd, 'beta':self._beta, 'mom':self._mom,
|
149 |
+
'opt_func':self.opt_func, 'true_wd':self.true_wd, 'bn_wd':self.bn_wd}
|
150 |
+
|
151 |
+
@classmethod
|
152 |
+
def load_with_state_and_layer_group(cls, state:dict, layer_groups:Collection[nn.Module]):
|
153 |
+
res = cls.create(state['opt_func'], state['lr'], layer_groups, wd=state['wd'], true_wd=state['true_wd'],
|
154 |
+
bn_wd=state['bn_wd'])
|
155 |
+
res._mom,res._beta = state['mom'],state['beta']
|
156 |
+
res.load_state_dict(state['opt_state'])
|
157 |
+
return res
|
158 |
+
|
159 |
+
class Callback():
|
160 |
+
"Base class for callbacks that want to record values, dynamically change learner params, etc."
|
161 |
+
_order=0
|
162 |
+
def on_train_begin(self, **kwargs:Any)->None:
|
163 |
+
"To initialize constants in the callback."
|
164 |
+
pass
|
165 |
+
def on_epoch_begin(self, **kwargs:Any)->None:
|
166 |
+
"At the beginning of each epoch."
|
167 |
+
pass
|
168 |
+
def on_batch_begin(self, **kwargs:Any)->None:
|
169 |
+
"Set HP before the output and loss are computed."
|
170 |
+
pass
|
171 |
+
def on_loss_begin(self, **kwargs:Any)->None:
|
172 |
+
"Called after forward pass but before loss has been computed."
|
173 |
+
pass
|
174 |
+
def on_backward_begin(self, **kwargs:Any)->None:
|
175 |
+
"Called after the forward pass and the loss has been computed, but before backprop."
|
176 |
+
pass
|
177 |
+
def on_backward_end(self, **kwargs:Any)->None:
|
178 |
+
"Called after backprop but before optimizer step. Useful for true weight decay in AdamW."
|
179 |
+
pass
|
180 |
+
def on_step_end(self, **kwargs:Any)->None:
|
181 |
+
"Called after the step of the optimizer but before the gradients are zeroed."
|
182 |
+
pass
|
183 |
+
def on_batch_end(self, **kwargs:Any)->None:
|
184 |
+
"Called at the end of the batch."
|
185 |
+
pass
|
186 |
+
def on_epoch_end(self, **kwargs:Any)->None:
|
187 |
+
"Called at the end of an epoch."
|
188 |
+
pass
|
189 |
+
def on_train_end(self, **kwargs:Any)->None:
|
190 |
+
"Useful for cleaning up things and saving files/models."
|
191 |
+
pass
|
192 |
+
def jump_to_epoch(self, epoch)->None:
|
193 |
+
"To resume training at `epoch` directly."
|
194 |
+
pass
|
195 |
+
|
196 |
+
def get_state(self, minimal:bool=True):
|
197 |
+
"Return the inner state of the `Callback`, `minimal` or not."
|
198 |
+
to_remove = ['exclude', 'not_min'] + getattr(self, 'exclude', []).copy()
|
199 |
+
if minimal: to_remove += getattr(self, 'not_min', []).copy()
|
200 |
+
return {k:v for k,v in self.__dict__.items() if k not in to_remove}
|
201 |
+
|
202 |
+
def __repr__(self):
|
203 |
+
attrs = func_args(self.__init__)
|
204 |
+
to_remove = getattr(self, 'exclude', [])
|
205 |
+
list_repr = [self.__class__.__name__] + [f'{k}: {getattr(self, k)}' for k in attrs if k != 'self' and k not in to_remove]
|
206 |
+
return '\n'.join(list_repr)
|
207 |
+
|
208 |
+
class SmoothenValue():
|
209 |
+
"Create a smooth moving average for a value (loss, etc) using `beta`."
|
210 |
+
def __init__(self, beta:float):
|
211 |
+
self.beta,self.n,self.mov_avg = beta,0,0
|
212 |
+
|
213 |
+
def add_value(self, val:float)->None:
|
214 |
+
"Add `val` to calculate updated smoothed value."
|
215 |
+
self.n += 1
|
216 |
+
self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val
|
217 |
+
self.smooth = self.mov_avg / (1 - self.beta ** self.n)
|
218 |
+
|
219 |
+
CallbackList = Collection[Callback]
|
220 |
+
|
221 |
+
def _get_init_state(): return {'epoch':0, 'iteration':0, 'num_batch':0, 'skip_validate': False}
|
222 |
+
|
223 |
+
@dataclass
|
224 |
+
class CallbackHandler():
|
225 |
+
"Manage all of the registered `callbacks` and `metrics`, smoothing loss by momentum `beta`."
|
226 |
+
callbacks:CallbackList=None
|
227 |
+
metrics:CallbackList=None
|
228 |
+
beta:float=0.98
|
229 |
+
|
230 |
+
def __post_init__(self)->None:
|
231 |
+
"Initialize smoother and learning stats."
|
232 |
+
self.callbacks = ifnone(self.callbacks, [])
|
233 |
+
self.metrics = ifnone(self.metrics, [])
|
234 |
+
self.metrics = [(met if isinstance(met, Callback) else AverageMetric(met)) for met in self.metrics]
|
235 |
+
self.callbacks = sorted(self.callbacks, key=lambda o: getattr(o, '_order', 0))
|
236 |
+
self.smoothener = SmoothenValue(self.beta)
|
237 |
+
self.state_dict:Dict[str,Union[int,float,Tensor]]=_get_init_state()
|
238 |
+
|
239 |
+
def _call_and_update(self, cb, cb_name, **kwargs)->None:
|
240 |
+
"Call `cb_name` on `cb` and update the inner state."
|
241 |
+
new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
|
242 |
+
for k,v in new.items():
|
243 |
+
if k not in self.state_dict:
|
244 |
+
raise Exception(f"{k} isn't a valid key in the state of the callbacks.")
|
245 |
+
else: self.state_dict[k] = v
|
246 |
+
|
247 |
+
def __call__(self, cb_name, call_mets=True, **kwargs)->None:
|
248 |
+
"Call through to all of the `CallbakHandler` functions."
|
249 |
+
if call_mets:
|
250 |
+
for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)
|
251 |
+
for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)
|
252 |
+
|
253 |
+
def set_dl(self, dl:DataLoader):
|
254 |
+
"Set the current `dl` used."
|
255 |
+
if hasattr(self, 'cb_dl'): self.callbacks.remove(self.cb_dl)
|
256 |
+
if isinstance(dl.dataset, Callback):
|
257 |
+
self.callbacks.append(dl.dataset)
|
258 |
+
self.cb_dl = dl.dataset
|
259 |
+
|
260 |
+
def on_train_begin(self, epochs:int, pbar:PBar, metrics:MetricFuncList)->None:
|
261 |
+
"About to start learning."
|
262 |
+
self.state_dict = _get_init_state()
|
263 |
+
self.state_dict.update(dict(n_epochs=epochs, pbar=pbar, metrics=metrics))
|
264 |
+
names = [(met.name if hasattr(met, 'name') else camel2snake(met.__class__.__name__)) for met in self.metrics]
|
265 |
+
self('train_begin', metrics_names=names)
|
266 |
+
if self.state_dict['epoch'] != 0:
|
267 |
+
self.state_dict['pbar'].first_bar.total -= self.state_dict['epoch']
|
268 |
+
for cb in self.callbacks: cb.jump_to_epoch(self.state_dict['epoch'])
|
269 |
+
|
270 |
+
def on_epoch_begin(self)->None:
|
271 |
+
"Handle new epoch."
|
272 |
+
self.state_dict['num_batch'],self.state_dict['stop_training'] = 0,False
|
273 |
+
self('epoch_begin')
|
274 |
+
|
275 |
+
def on_batch_begin(self, xb:Tensor, yb:Tensor, train:bool=True)->Tuple[Any,Any]:
|
276 |
+
"Handle new batch `xb`,`yb` in `train` or validation."
|
277 |
+
self.state_dict.update(dict(last_input=xb, last_target=yb, train=train,
|
278 |
+
stop_epoch=False, skip_step=False, skip_zero=False, skip_bwd=False))
|
279 |
+
self('batch_begin', mets = not self.state_dict['train'])
|
280 |
+
return self.state_dict['last_input'], self.state_dict['last_target']
|
281 |
+
|
282 |
+
def on_loss_begin(self, out:Tensor)->Any:
|
283 |
+
"Handle start of loss calculation with model output `out`."
|
284 |
+
self.state_dict['last_output'] = out
|
285 |
+
self('loss_begin', call_mets=False)
|
286 |
+
return self.state_dict['last_output']
|
287 |
+
|
288 |
+
def on_backward_begin(self, loss:Tensor)->Tuple[Any,Any]:
|
289 |
+
"Handle gradient calculation on `loss`."
|
290 |
+
self.smoothener.add_value(loss.detach().cpu())
|
291 |
+
self.state_dict['last_loss'], self.state_dict['smooth_loss'] = loss, self.smoothener.smooth
|
292 |
+
self('backward_begin', call_mets=False)
|
293 |
+
return self.state_dict['last_loss'], self.state_dict['skip_bwd']
|
294 |
+
|
295 |
+
def on_backward_end(self)->Any:
|
296 |
+
"Handle end of gradient calculation."
|
297 |
+
self('backward_end', call_mets=False)
|
298 |
+
return self.state_dict['skip_step']
|
299 |
+
|
300 |
+
def on_step_end(self)->Any:
|
301 |
+
"Handle end of optimization step."
|
302 |
+
self('step_end', call_mets=False)
|
303 |
+
return self.state_dict['skip_zero']
|
304 |
+
|
305 |
+
def on_batch_end(self, loss:Tensor)->Any:
|
306 |
+
"Handle end of processing one batch with `loss`."
|
307 |
+
self.state_dict['last_loss'] = loss
|
308 |
+
self('batch_end', call_mets = not self.state_dict['train'])
|
309 |
+
if self.state_dict['train']:
|
310 |
+
self.state_dict['iteration'] += 1
|
311 |
+
self.state_dict['num_batch'] += 1
|
312 |
+
return self.state_dict['stop_epoch']
|
313 |
+
|
314 |
+
def on_epoch_end(self, val_loss:Tensor)->bool:
|
315 |
+
"Epoch is done, process `val_loss`."
|
316 |
+
self.state_dict['last_metrics'] = [val_loss] if val_loss is not None else [None]
|
317 |
+
self('epoch_end', call_mets = val_loss is not None)
|
318 |
+
self.state_dict['epoch'] += 1
|
319 |
+
return self.state_dict['stop_training']
|
320 |
+
|
321 |
+
def on_train_end(self, exception:Union[bool,Exception])->None:
|
322 |
+
"Handle end of training, `exception` is an `Exception` or False if no exceptions during training."
|
323 |
+
self('train_end', exception=exception)
|
324 |
+
|
325 |
+
@property
|
326 |
+
def skip_validate(self): return self.state_dict['skip_validate']
|
327 |
+
|
328 |
+
class AverageMetric(Callback):
|
329 |
+
"Wrap a `func` in a callback for metrics computation."
|
330 |
+
def __init__(self, func):
|
331 |
+
# If func has a __name__ use this one else it should be a partial
|
332 |
+
name = func.__name__ if hasattr(func, '__name__') else func.func.__name__
|
333 |
+
self.func, self.name = func, name
|
334 |
+
self.world = num_distrib()
|
335 |
+
|
336 |
+
def on_epoch_begin(self, **kwargs):
|
337 |
+
"Set the inner value to 0."
|
338 |
+
self.val, self.count = 0.,0
|
339 |
+
|
340 |
+
def on_batch_end(self, last_output, last_target, **kwargs):
|
341 |
+
"Update metric computation with `last_output` and `last_target`."
|
342 |
+
if not is_listy(last_target): last_target=[last_target]
|
343 |
+
self.count += first_el(last_target).size(0)
|
344 |
+
val = self.func(last_output, *last_target)
|
345 |
+
if self.world:
|
346 |
+
val = val.clone()
|
347 |
+
dist.all_reduce(val, op=dist.ReduceOp.SUM)
|
348 |
+
val /= self.world
|
349 |
+
self.val += first_el(last_target).size(0) * val.detach().cpu()
|
350 |
+
|
351 |
+
def on_epoch_end(self, last_metrics, **kwargs):
|
352 |
+
"Set the final result in `last_metrics`."
|
353 |
+
return add_metrics(last_metrics, self.val/self.count)
|
354 |
+
|
355 |
+
def annealing_no(start:Number, end:Number, pct:float)->Number:
|
356 |
+
"No annealing, always return `start`."
|
357 |
+
return start
|
358 |
+
def annealing_linear(start:Number, end:Number, pct:float)->Number:
|
359 |
+
"Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
|
360 |
+
return start + pct * (end-start)
|
361 |
+
def annealing_exp(start:Number, end:Number, pct:float)->Number:
|
362 |
+
"Exponentially anneal from `start` to `end` as pct goes from 0.0 to 1.0."
|
363 |
+
return start * (end/start) ** pct
|
364 |
+
def annealing_cos(start:Number, end:Number, pct:float)->Number:
|
365 |
+
"Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
|
366 |
+
cos_out = np.cos(np.pi * pct) + 1
|
367 |
+
return end + (start-end)/2 * cos_out
|
368 |
+
|
369 |
+
def do_annealing_poly(start:Number, end:Number, pct:float, degree:Number)->Number:
|
370 |
+
"Helper function for `anneal_poly`."
|
371 |
+
return end + (start-end) * (1-pct)**degree
|
372 |
+
def annealing_poly(degree:Number)->Number:
|
373 |
+
"Anneal polynomically from `start` to `end` as pct goes from 0.0 to 1.0."
|
374 |
+
return functools.partial(do_annealing_poly, degree=degree)
|
375 |
+
|
376 |
+
class Scheduler():
|
377 |
+
"Used to \"step\" from start,end (`vals`) over `n_iter` iterations on a schedule defined by `func`"
|
378 |
+
def __init__(self, vals:StartOptEnd, n_iter:int, func:Optional[AnnealFunc]=None):
|
379 |
+
self.start,self.end = (vals[0],vals[1]) if is_tuple(vals) else (vals,0)
|
380 |
+
self.n_iter = max(1,n_iter)
|
381 |
+
if func is None: self.func = annealing_linear if is_tuple(vals) else annealing_no
|
382 |
+
else: self.func = func
|
383 |
+
self.n = 0
|
384 |
+
|
385 |
+
def restart(self): self.n = 0
|
386 |
+
|
387 |
+
def step(self)->Number:
|
388 |
+
"Return next value along annealed schedule."
|
389 |
+
self.n += 1
|
390 |
+
return self.func(self.start, self.end, self.n/self.n_iter)
|
391 |
+
|
392 |
+
@property
|
393 |
+
def is_done(self)->bool:
|
394 |
+
"Return `True` if schedule completed."
|
395 |
+
return self.n >= self.n_iter
|
396 |
+
|
DeOldify/fastai/callbacks/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .lr_finder import *
|
2 |
+
from .one_cycle import *
|
3 |
+
from .fp16 import *
|
4 |
+
from .general_sched import *
|
5 |
+
from .hooks import *
|
6 |
+
from .mixup import *
|
7 |
+
from .rnn import *
|
8 |
+
from .tracker import *
|
9 |
+
from .csv_logger import *
|
10 |
+
from .loss_metrics import *
|
11 |
+
from .oversampling import *
|
DeOldify/fastai/callbacks/csv_logger.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"A `Callback` that saves tracked metrics into a persistent file."
|
2 |
+
#Contribution from devforfu: https://nbviewer.jupyter.org/gist/devforfu/ea0b3fcfe194dad323c3762492b05cae
|
3 |
+
from ..torch_core import *
|
4 |
+
from ..basic_data import DataBunch
|
5 |
+
from ..callback import *
|
6 |
+
from ..basic_train import Learner, LearnerCallback
|
7 |
+
from time import time
|
8 |
+
from fastprogress.fastprogress import format_time
|
9 |
+
|
10 |
+
__all__ = ['CSVLogger']
|
11 |
+
|
12 |
+
class CSVLogger(LearnerCallback):
|
13 |
+
"A `LearnerCallback` that saves history of metrics while training `learn` into CSV `filename`."
|
14 |
+
def __init__(self, learn:Learner, filename: str = 'history', append: bool = False):
|
15 |
+
super().__init__(learn)
|
16 |
+
self.filename,self.path,self.append = filename,self.learn.path/f'{filename}.csv',append
|
17 |
+
self.add_time = True
|
18 |
+
|
19 |
+
def read_logged_file(self):
|
20 |
+
"Read the content of saved file"
|
21 |
+
return pd.read_csv(self.path)
|
22 |
+
|
23 |
+
def on_train_begin(self, **kwargs: Any) -> None:
|
24 |
+
"Prepare file with metric names."
|
25 |
+
self.path.parent.mkdir(parents=True, exist_ok=True)
|
26 |
+
self.file = self.path.open('a') if self.append else self.path.open('w')
|
27 |
+
self.file.write(','.join(self.learn.recorder.names[:(None if self.add_time else -1)]) + '\n')
|
28 |
+
|
29 |
+
def on_epoch_begin(self, **kwargs:Any)->None:
|
30 |
+
if self.add_time: self.start_epoch = time()
|
31 |
+
|
32 |
+
def on_epoch_end(self, epoch: int, smooth_loss: Tensor, last_metrics: MetricsList, **kwargs: Any) -> bool:
|
33 |
+
"Add a line with `epoch` number, `smooth_loss` and `last_metrics`."
|
34 |
+
last_metrics = ifnone(last_metrics, [])
|
35 |
+
stats = [str(stat) if isinstance(stat, int) else '#na#' if stat is None else f'{stat:.6f}'
|
36 |
+
for name, stat in zip(self.learn.recorder.names, [epoch, smooth_loss] + last_metrics)]
|
37 |
+
if self.add_time: stats.append(format_time(time() - self.start_epoch))
|
38 |
+
str_stats = ','.join(stats)
|
39 |
+
self.file.write(str_stats + '\n')
|
40 |
+
|
41 |
+
def on_train_end(self, **kwargs: Any) -> None:
|
42 |
+
"Close the file."
|
43 |
+
self.file.close()
|
DeOldify/fastai/callbacks/fp16.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"Callback support for half precision (fp16) training. Increases training speed."
|
2 |
+
from ..torch_core import *
|
3 |
+
from ..callback import *
|
4 |
+
from ..basic_train import *
|
5 |
+
from torch._utils import _unflatten_dense_tensors
|
6 |
+
from torch.nn.utils import parameters_to_vector
|
7 |
+
|
8 |
+
__all__ = ['MixedPrecision']
|
9 |
+
|
10 |
+
def get_master(layer_groups:ModuleList, flat_master:bool=False) -> Tuple[List[List[Tensor]], List[List[Tensor]]]:
|
11 |
+
"Return two lists, one for the model parameters in FP16 and one for the master parameters in FP32."
|
12 |
+
split_params = split_no_wd_params(layer_groups)
|
13 |
+
model_params = [[param for param in pg if param.requires_grad] for pg in split_params]
|
14 |
+
if flat_master:
|
15 |
+
master_params = []
|
16 |
+
for lg in model_params:
|
17 |
+
if len(lg) !=0 :
|
18 |
+
mp = parameters_to_vector([param.data.float() for param in lg])
|
19 |
+
mp = torch.nn.Parameter(mp, requires_grad=True)
|
20 |
+
if mp.grad is None: mp.grad = mp.new(*mp.size())
|
21 |
+
master_params.append([mp])
|
22 |
+
else: master_params.append([])
|
23 |
+
return model_params, master_params
|
24 |
+
else:
|
25 |
+
master_params = [[param.clone().float().detach() for param in lg] for lg in model_params]
|
26 |
+
for mp in master_params:
|
27 |
+
for param in mp: param.requires_grad = True
|
28 |
+
return model_params, master_params
|
29 |
+
|
30 |
+
def model_g2master_g(model_params:Sequence[Tensor], master_params:Sequence[Tensor], flat_master:bool=False)->None:
|
31 |
+
"Copy the `model_params` gradients to `master_params` for the optimizer step."
|
32 |
+
if flat_master:
|
33 |
+
for model_group,master_group in zip(model_params,master_params):
|
34 |
+
if len(master_group) != 0:
|
35 |
+
if master_group[0].grad is None: master_group[0].grad = master_group[0].data.new(*master_group[0].data.size())
|
36 |
+
master_group[0].grad.data.copy_(parameters_to_vector([p.grad.data.float() for p in model_group]))
|
37 |
+
else:
|
38 |
+
for model_group,master_group in zip(model_params,master_params):
|
39 |
+
for model, master in zip(model_group, master_group):
|
40 |
+
if model.grad is not None:
|
41 |
+
if master.grad is None: master.grad = master.data.new(*master.data.size())
|
42 |
+
master.grad.data.copy_(model.grad.data)
|
43 |
+
else: master.grad = None
|
44 |
+
|
45 |
+
def master2model(model_params:Sequence[Tensor], master_params:Sequence[Tensor], flat_master:bool=False)->None:
|
46 |
+
"Copy `master_params` to `model_params`."
|
47 |
+
if flat_master:
|
48 |
+
for model_group,master_group in zip(model_params,master_params):
|
49 |
+
if len(model_group) != 0:
|
50 |
+
for model, master in zip(model_group, _unflatten_dense_tensors(master_group[0].data, model_group)):
|
51 |
+
model.data.copy_(master)
|
52 |
+
else:
|
53 |
+
for model_group,master_group in zip(model_params,master_params):
|
54 |
+
for model, master in zip(model_group, master_group): model.data.copy_(master.data)
|
55 |
+
|
56 |
+
def grad_overflow(param_group):
|
57 |
+
for group in param_group:
|
58 |
+
for p in group:
|
59 |
+
if p.grad is not None:
|
60 |
+
s = float(p.grad.data.float().sum())
|
61 |
+
if s == float('inf') or s == float('-inf') or s != s: return True
|
62 |
+
return False
|
63 |
+
|
64 |
+
class MixedPrecision(LearnerCallback):
|
65 |
+
_order = 999 #Need to run after things that could call on_backward_begin and change the loss
|
66 |
+
"Callback that handles mixed-precision training."
|
67 |
+
def __init__(self, learn:Learner, loss_scale:float=None, max_noskip:int=1000, dynamic:bool=True, clip:float=None,
|
68 |
+
flat_master:bool=False, max_scale:float=2**24):
|
69 |
+
super().__init__(learn)
|
70 |
+
self.flat_master,self.dynamic,self.max_noskip,self.clip,self.max_scale = flat_master,dynamic,max_noskip,clip,max_scale
|
71 |
+
self.loss_scale = ifnone(loss_scale, 2**16 if dynamic else 512)
|
72 |
+
self.not_min += ['model_params', 'master_params']
|
73 |
+
assert torch.backends.cudnn.enabled, "Mixed precision training requires cudnn."
|
74 |
+
self.opt = None
|
75 |
+
|
76 |
+
def on_train_begin(self, **kwargs:Any)->None:
|
77 |
+
"Prepare the master model."
|
78 |
+
#Get a copy of the model params in FP32
|
79 |
+
self.model_params, self.master_params = get_master(self.learn.layer_groups, self.flat_master)
|
80 |
+
#Changes the optimizer so that the optimization step is done in FP32.
|
81 |
+
new_opt = self.learn.opt.new_with_params(self.master_params)
|
82 |
+
if self.opt is not None:
|
83 |
+
self.opt.lr,self.opt.wd = self.learn.opt.lr,self.learn.opt.wd
|
84 |
+
new_opt.load_state_dict(self.opt)
|
85 |
+
self.learn.opt.opt = new_opt.opt
|
86 |
+
self.noskip = 0
|
87 |
+
|
88 |
+
def on_loss_begin(self, last_output:Tensor, **kwargs:Any) -> Tensor:
|
89 |
+
"Convert half precision output to FP32 to avoid reduction overflow."
|
90 |
+
return {'last_output': to_float(last_output)}
|
91 |
+
|
92 |
+
def on_backward_begin(self, last_loss:Rank0Tensor, **kwargs:Any) -> Rank0Tensor:
|
93 |
+
"Scale gradients up by `self.loss_scale` to prevent underflow."
|
94 |
+
#To avoid gradient underflow, we scale the gradients
|
95 |
+
ret_loss = last_loss * self.loss_scale
|
96 |
+
return {'last_loss': ret_loss}
|
97 |
+
|
98 |
+
def on_backward_end(self, **kwargs:Any)->None:
|
99 |
+
"Convert the gradients back to FP32 and divide them by the scale."
|
100 |
+
if self.dynamic and grad_overflow(self.model_params) and self.loss_scale > 1:
|
101 |
+
self.loss_scale /= 2
|
102 |
+
self.noskip = 0
|
103 |
+
#The step will be skipped since we don't update the master grads so they are all None or zero
|
104 |
+
else:
|
105 |
+
model_g2master_g(self.model_params, self.master_params, self.flat_master)
|
106 |
+
for group in self.master_params:
|
107 |
+
for param in group:
|
108 |
+
if param.grad is not None: param.grad.div_(self.loss_scale)
|
109 |
+
if self.clip is not None:
|
110 |
+
for group in self.master_params: nn.utils.clip_grad_norm_(group, self.clip)
|
111 |
+
if not self.dynamic: return
|
112 |
+
self.noskip += 1
|
113 |
+
if self.noskip >= self.max_noskip and self.loss_scale < self.max_scale:
|
114 |
+
self.loss_scale *= 2
|
115 |
+
self.noskip = 0
|
116 |
+
|
117 |
+
def on_step_end(self, **kwargs:Any)->None:
|
118 |
+
"Update the params from master to model and zero grad."
|
119 |
+
#Zeros the gradients of the model since the optimizer is disconnected.
|
120 |
+
self.learn.model.zero_grad()
|
121 |
+
#Update the params from master to model.
|
122 |
+
master2model(self.model_params, self.master_params, self.flat_master)
|
DeOldify/fastai/callbacks/general_sched.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..core import *
|
2 |
+
from ..callback import *
|
3 |
+
from ..basic_train import Learner, LearnerCallback
|
4 |
+
|
5 |
+
__all__ = ['GeneralScheduler', 'TrainingPhase']
|
6 |
+
|
7 |
+
@dataclass
|
8 |
+
class TrainingPhase():
|
9 |
+
"Schedule hyper-parameters for a phase of `length` iterations."
|
10 |
+
length:int
|
11 |
+
|
12 |
+
def __post_init__(self): self.scheds = dict()
|
13 |
+
def schedule_hp(self, name, vals, anneal=None):
|
14 |
+
"Adds a schedule for `name` between `vals` using `anneal`."
|
15 |
+
self.scheds[name] = Scheduler(vals, self.length, anneal)
|
16 |
+
return self
|
17 |
+
|
18 |
+
class GeneralScheduler(LearnerCallback):
|
19 |
+
"Schedule multiple `TrainingPhase` for a `Learner`."
|
20 |
+
def __init__(self, learn:Learner, phases:Collection[TrainingPhase], start_epoch:int=None):
|
21 |
+
super().__init__(learn)
|
22 |
+
self.phases,self.start_epoch = phases,start_epoch
|
23 |
+
|
24 |
+
def on_train_begin(self, epoch:int, **kwargs:Any)->None:
|
25 |
+
"Initialize the schedulers for training."
|
26 |
+
res = {'epoch':self.start_epoch} if self.start_epoch is not None else None
|
27 |
+
self.start_epoch = ifnone(self.start_epoch, epoch)
|
28 |
+
self.scheds = [p.scheds for p in self.phases]
|
29 |
+
self.opt = self.learn.opt
|
30 |
+
for k,v in self.scheds[0].items():
|
31 |
+
v.restart()
|
32 |
+
self.opt.set_stat(k, v.start)
|
33 |
+
self.idx_s = 0
|
34 |
+
return res
|
35 |
+
|
36 |
+
def jump_to_epoch(self, epoch:int)->None:
|
37 |
+
for _ in range(len(self.learn.data.train_dl) * epoch):
|
38 |
+
self.on_batch_end(True)
|
39 |
+
|
40 |
+
def on_batch_end(self, train, **kwargs:Any)->None:
|
41 |
+
"Take a step in lr,mom sched, start next stepper when the current one is complete."
|
42 |
+
if train:
|
43 |
+
if self.idx_s >= len(self.scheds): return {'stop_training': True, 'stop_epoch': True}
|
44 |
+
sched = self.scheds[self.idx_s]
|
45 |
+
for k,v in sched.items(): self.opt.set_stat(k, v.step())
|
46 |
+
if list(sched.values())[0].is_done: self.idx_s += 1
|
DeOldify/fastai/callbacks/hooks.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"Hooks provide extensibility at the model level."
|
2 |
+
from ..torch_core import *
|
3 |
+
from ..callback import *
|
4 |
+
from ..basic_train import *
|
5 |
+
from ..basic_data import *
|
6 |
+
|
7 |
+
__all__ = ['ActivationStats', 'Hook', 'HookCallback', 'Hooks', 'hook_output', 'hook_outputs',
|
8 |
+
'model_sizes', 'num_features_model', 'model_summary', 'dummy_eval', 'dummy_batch']
|
9 |
+
|
10 |
+
class Hook():
|
11 |
+
"Create a hook on `m` with `hook_func`."
|
12 |
+
def __init__(self, m:nn.Module, hook_func:HookFunc, is_forward:bool=True, detach:bool=True):
|
13 |
+
self.hook_func,self.detach,self.stored = hook_func,detach,None
|
14 |
+
f = m.register_forward_hook if is_forward else m.register_backward_hook
|
15 |
+
self.hook = f(self.hook_fn)
|
16 |
+
self.removed = False
|
17 |
+
|
18 |
+
def hook_fn(self, module:nn.Module, input:Tensors, output:Tensors):
|
19 |
+
"Applies `hook_func` to `module`, `input`, `output`."
|
20 |
+
if self.detach:
|
21 |
+
input = (o.detach() for o in input ) if is_listy(input ) else input.detach()
|
22 |
+
output = (o.detach() for o in output) if is_listy(output) else output.detach()
|
23 |
+
self.stored = self.hook_func(module, input, output)
|
24 |
+
|
25 |
+
def remove(self):
|
26 |
+
"Remove the hook from the model."
|
27 |
+
if not self.removed:
|
28 |
+
self.hook.remove()
|
29 |
+
self.removed=True
|
30 |
+
|
31 |
+
def __enter__(self, *args): return self
|
32 |
+
def __exit__(self, *args): self.remove()
|
33 |
+
|
34 |
+
class Hooks():
|
35 |
+
"Create several hooks on the modules in `ms` with `hook_func`."
|
36 |
+
def __init__(self, ms:Collection[nn.Module], hook_func:HookFunc, is_forward:bool=True, detach:bool=True):
|
37 |
+
self.hooks = [Hook(m, hook_func, is_forward, detach) for m in ms]
|
38 |
+
|
39 |
+
def __getitem__(self,i:int)->Hook: return self.hooks[i]
|
40 |
+
def __len__(self)->int: return len(self.hooks)
|
41 |
+
def __iter__(self): return iter(self.hooks)
|
42 |
+
@property
|
43 |
+
def stored(self): return [o.stored for o in self]
|
44 |
+
|
45 |
+
def remove(self):
|
46 |
+
"Remove the hooks from the model."
|
47 |
+
for h in self.hooks: h.remove()
|
48 |
+
|
49 |
+
def __enter__(self, *args): return self
|
50 |
+
def __exit__ (self, *args): self.remove()
|
51 |
+
|
52 |
+
def _hook_inner(m,i,o): return o if isinstance(o,Tensor) else o if is_listy(o) else list(o)
|
53 |
+
|
54 |
+
def hook_output (module:nn.Module, detach:bool=True, grad:bool=False)->Hook:
|
55 |
+
"Return a `Hook` that stores activations of `module` in `self.stored`"
|
56 |
+
return Hook(module, _hook_inner, detach=detach, is_forward=not grad)
|
57 |
+
|
58 |
+
def hook_outputs(modules:Collection[nn.Module], detach:bool=True, grad:bool=False)->Hooks:
|
59 |
+
"Return `Hooks` that store activations of all `modules` in `self.stored`"
|
60 |
+
return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad)
|
61 |
+
|
62 |
+
class HookCallback(LearnerCallback):
|
63 |
+
"Callback that can be used to register hooks on `modules`. Implement the corresponding function in `self.hook`."
|
64 |
+
def __init__(self, learn:Learner, modules:Sequence[nn.Module]=None, do_remove:bool=True):
|
65 |
+
super().__init__(learn)
|
66 |
+
self.modules,self.do_remove = modules,do_remove
|
67 |
+
|
68 |
+
def on_train_begin(self, **kwargs):
|
69 |
+
"Register the `Hooks` on `self.modules`."
|
70 |
+
if not self.modules:
|
71 |
+
self.modules = [m for m in flatten_model(self.learn.model)
|
72 |
+
if hasattr(m, 'weight')]
|
73 |
+
self.hooks = Hooks(self.modules, self.hook)
|
74 |
+
|
75 |
+
def on_train_end(self, **kwargs):
|
76 |
+
"Remove the `Hooks`."
|
77 |
+
if self.do_remove: self.remove()
|
78 |
+
|
79 |
+
def remove(self):
|
80 |
+
if getattr(self, 'hooks', None): self.hooks.remove()
|
81 |
+
def __del__(self): self.remove()
|
82 |
+
|
83 |
+
class ActivationStats(HookCallback):
|
84 |
+
"Callback that record the mean and std of activations."
|
85 |
+
|
86 |
+
def on_train_begin(self, **kwargs):
|
87 |
+
"Initialize stats."
|
88 |
+
super().on_train_begin(**kwargs)
|
89 |
+
self.stats = []
|
90 |
+
|
91 |
+
def hook(self, m:nn.Module, i:Tensors, o:Tensors)->Tuple[Rank0Tensor,Rank0Tensor]:
|
92 |
+
"Take the mean and std of `o`."
|
93 |
+
return o.mean().item(),o.std().item()
|
94 |
+
def on_batch_end(self, train, **kwargs):
|
95 |
+
"Take the stored results and puts it in `self.stats`"
|
96 |
+
if train: self.stats.append(self.hooks.stored)
|
97 |
+
def on_train_end(self, **kwargs):
|
98 |
+
"Polish the final result."
|
99 |
+
super().on_train_end(**kwargs)
|
100 |
+
self.stats = tensor(self.stats).permute(2,1,0)
|
101 |
+
|
102 |
+
def dummy_batch(m: nn.Module, size:tuple=(64,64))->Tensor:
|
103 |
+
"Create a dummy batch to go through `m` with `size`."
|
104 |
+
ch_in = in_channels(m)
|
105 |
+
return one_param(m).new(1, ch_in, *size).requires_grad_(False).uniform_(-1.,1.)
|
106 |
+
|
107 |
+
def dummy_eval(m:nn.Module, size:tuple=(64,64)):
|
108 |
+
"Pass a `dummy_batch` in evaluation mode in `m` with `size`."
|
109 |
+
m.eval()
|
110 |
+
return m(dummy_batch(m, size))
|
111 |
+
#return m.eval()(dummy_batch(m, size))
|
112 |
+
|
113 |
+
def model_sizes(m:nn.Module, size:tuple=(64,64))->Tuple[Sizes,Tensor,Hooks]:
|
114 |
+
"Pass a dummy input through the model `m` to get the various sizes of activations."
|
115 |
+
with hook_outputs(m) as hooks:
|
116 |
+
x = dummy_eval(m, size)
|
117 |
+
return [o.stored.shape for o in hooks]
|
118 |
+
|
119 |
+
def num_features_model(m:nn.Module)->int:
|
120 |
+
"Return the number of output features for `model`."
|
121 |
+
sz = 64
|
122 |
+
while True:
|
123 |
+
try: return model_sizes(m, size=(sz,sz))[-1][1]
|
124 |
+
except Exception as e:
|
125 |
+
sz *= 2
|
126 |
+
if sz > 2048: raise
|
127 |
+
|
128 |
+
def total_params(m:nn.Module)->int:
|
129 |
+
params, trainable = 0, False
|
130 |
+
if hasattr(m, "weight") and hasattr(m.weight, "size"):
|
131 |
+
params += m.weight.numel()
|
132 |
+
trainable = m.weight.requires_grad
|
133 |
+
if hasattr(m, "bias") and hasattr(m.bias, "size"): params += m.bias.numel()
|
134 |
+
return params, trainable
|
135 |
+
|
136 |
+
def hook_params(modules:Collection[nn.Module])->Hooks:
|
137 |
+
return Hooks(modules, lambda m, i, o: total_params(m))
|
138 |
+
|
139 |
+
def params_size(m: Union[nn.Module,Learner], size: tuple = (3, 64, 64))->Tuple[Sizes, Tensor, Hooks]:
|
140 |
+
"Pass a dummy input through the model to get the various sizes. Returns (res,x,hooks) if `full`"
|
141 |
+
if isinstance(m, Learner):
|
142 |
+
if m.data.is_empty:
|
143 |
+
raise Exception("This is an empty `Learner` and `Learner.summary` requires some data to pass through the model.")
|
144 |
+
ds_type = DatasetType.Train if m.data.train_dl else (DatasetType.Valid if m.data.valid_dl else DatasetType.Test)
|
145 |
+
x = m.data.one_batch(ds_type=ds_type, detach=False, denorm=False)[0]
|
146 |
+
x = [o[:1] for o in x] if is_listy(x) else x[:1]
|
147 |
+
m = m.model
|
148 |
+
elif isinstance(m, nn.Module): x = next(m.parameters()).new(1, *size)
|
149 |
+
else: raise TypeError('You should either pass in a Learner or nn.Module')
|
150 |
+
with hook_outputs(flatten_model(m)) as hook_o:
|
151 |
+
with hook_params(flatten_model(m))as hook_p:
|
152 |
+
x = m.eval()(*x) if is_listy(x) else m.eval()(x)
|
153 |
+
output_size = [((o.stored.shape[1:]) if o.stored is not None else None) for o in hook_o]
|
154 |
+
params = [(o.stored if o.stored is not None else (None,None)) for o in hook_p]
|
155 |
+
params, trainables = map(list,zip(*params))
|
156 |
+
return output_size, params, trainables
|
157 |
+
|
158 |
+
def get_layer_name(layer:nn.Module)->str:
|
159 |
+
return str(layer.__class__).split(".")[-1].split("'")[0]
|
160 |
+
|
161 |
+
def layers_info(m:Collection[nn.Module]) -> Collection[namedtuple]:
|
162 |
+
func = lambda m:list(map(get_layer_name, flatten_model(m)))
|
163 |
+
layers_names = func(m.model) if isinstance(m, Learner) else func(m)
|
164 |
+
layers_sizes, layers_params, layers_trainable = params_size(m)
|
165 |
+
layer_info = namedtuple('Layer_Information', ['Layer', 'OutputSize', 'Params', 'Trainable'])
|
166 |
+
return list(map(layer_info, layers_names, layers_sizes, layers_params, layers_trainable))
|
167 |
+
|
168 |
+
def model_summary(m:Learner, n:int=70):
|
169 |
+
"Print a summary of `m` using a output text width of `n` chars"
|
170 |
+
info = layers_info(m)
|
171 |
+
header = ["Layer (type)", "Output Shape", "Param #", "Trainable"]
|
172 |
+
res = m.model.__class__.__name__ + "\n"
|
173 |
+
res += "=" * n + "\n"
|
174 |
+
res += f"{header[0]:<20} {header[1]:<20} {header[2]:<10} {header[3]:<10}\n"
|
175 |
+
res += "=" * n + "\n"
|
176 |
+
total_params = 0
|
177 |
+
total_trainable_params = 0
|
178 |
+
for layer, size, params, trainable in info:
|
179 |
+
if size is None: continue
|
180 |
+
total_params += int(params)
|
181 |
+
total_trainable_params += int(params) * trainable
|
182 |
+
size, trainable = str(list(size)), str(trainable)
|
183 |
+
res += f"{layer:<20} {size:<20} {int(params):<10,} {trainable:<10}\n"
|
184 |
+
res += "_" * n + "\n"
|
185 |
+
res += f"\nTotal params: {total_params:,}\n"
|
186 |
+
res += f"Total trainable params: {total_trainable_params:,}\n"
|
187 |
+
res += f"Total non-trainable params: {total_params - total_trainable_params:,}\n"
|
188 |
+
|
189 |
+
res += f"Optimized with {str(m.opt_func)[25:-1].replace('>', '')}\n"
|
190 |
+
if m.true_wd: res += f"Using true weight decay as discussed in https://www.fast.ai/2018/07/02/adam-weight-decay/ \n"
|
191 |
+
if "wd" in str(m.opt_func) or "weight_decay" in str(m.opt_func): res += f"\x1b[1;31m Specifying weight decay in the optimizer has no effect, Learner will overwrite \x1b[0m \n"
|
192 |
+
if "lr" in str(m.opt_func) or "learning_rate" in str(m.opt_func): res += f"\x1b[1;31m Specifying lr in the optimizer has no effect, pass it to fit or the defaults.lr will apply \x1b[0m \n"
|
193 |
+
res += f"Loss function : {m.loss_func.__class__.__name__}\n"
|
194 |
+
res += "=" * n + "\n"
|
195 |
+
res += "Callbacks functions applied \n"
|
196 |
+
res += "\n".join([f" {cbs.__class__.__name__}" for cbs in m.callbacks])
|
197 |
+
|
198 |
+
return PrettyString(res)
|
199 |
+
|
200 |
+
Learner.summary = model_summary
|
DeOldify/fastai/callbacks/loss_metrics.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..torch_core import *
|
2 |
+
from ..callback import *
|
3 |
+
from ..basic_train import Learner, LearnerCallback
|
4 |
+
|
5 |
+
__all__ = ['LossMetrics']
|
6 |
+
|
7 |
+
class LossMetrics(LearnerCallback):
|
8 |
+
"Add `loss_func.metrics` to metrics named by `loss_func.metric_names`"
|
9 |
+
_order = -20 #Needs to run before the recorder
|
10 |
+
|
11 |
+
def on_train_begin(self, **kwargs):
|
12 |
+
"Add the metrics names to the `Recorder`."
|
13 |
+
self.names = ifnone(self.learn.loss_func.metric_names, [])
|
14 |
+
if not self.names: warn('LossMetrics requested but no loss_func.metric_names provided')
|
15 |
+
self.learn.recorder.add_metric_names(self.names)
|
16 |
+
|
17 |
+
def on_epoch_begin(self, **kwargs):
|
18 |
+
"Initialize the metrics for this epoch."
|
19 |
+
self.metrics = {name:0. for name in self.names}
|
20 |
+
self.nums = 0
|
21 |
+
|
22 |
+
def on_batch_end(self, last_target, train, **kwargs):
|
23 |
+
"Update the metrics if not `train`"
|
24 |
+
if train: return
|
25 |
+
bs = last_target.size(0)
|
26 |
+
for name in self.names:
|
27 |
+
self.metrics[name] += bs * self.learn.loss_func.metrics[name].detach().cpu()
|
28 |
+
self.nums += bs
|
29 |
+
|
30 |
+
def on_epoch_end(self, last_metrics, **kwargs):
|
31 |
+
"Finish the computation and sends the result to the Recorder."
|
32 |
+
if not self.nums: return
|
33 |
+
metrics = [self.metrics[name]/self.nums for name in self.names]
|
34 |
+
return {'last_metrics': last_metrics+metrics}
|
DeOldify/fastai/callbacks/lr_finder.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"Tools to help find the optimal learning rate for training"
|
2 |
+
from ..torch_core import *
|
3 |
+
from ..basic_data import DataBunch
|
4 |
+
from ..callback import *
|
5 |
+
from ..basic_train import Learner, LearnerCallback
|
6 |
+
|
7 |
+
__all__ = ['LRFinder']
|
8 |
+
|
9 |
+
class LRFinder(LearnerCallback):
|
10 |
+
"Causes `learn` to go on a mock training from `start_lr` to `end_lr` for `num_it` iterations."
|
11 |
+
def __init__(self, learn:Learner, start_lr:float=1e-7, end_lr:float=10, num_it:int=100, stop_div:bool=True):
|
12 |
+
super().__init__(learn)
|
13 |
+
self.data,self.stop_div = learn.data,stop_div
|
14 |
+
self.sched = Scheduler((start_lr, end_lr), num_it, annealing_exp)
|
15 |
+
|
16 |
+
def on_train_begin(self, pbar, **kwargs:Any)->None:
|
17 |
+
"Initialize optimizer and learner hyperparameters."
|
18 |
+
setattr(pbar, 'clean_on_interrupt', True)
|
19 |
+
self.learn.save('tmp')
|
20 |
+
self.opt = self.learn.opt
|
21 |
+
self.opt.lr = self.sched.start
|
22 |
+
self.stop,self.best_loss = False,0.
|
23 |
+
return {'skip_validate': True}
|
24 |
+
|
25 |
+
def on_batch_end(self, iteration:int, smooth_loss:TensorOrNumber, **kwargs:Any)->None:
|
26 |
+
"Determine if loss has runaway and we should stop."
|
27 |
+
if iteration==0 or smooth_loss < self.best_loss: self.best_loss = smooth_loss
|
28 |
+
self.opt.lr = self.sched.step()
|
29 |
+
if self.sched.is_done or (self.stop_div and (smooth_loss > 4*self.best_loss or torch.isnan(smooth_loss))):
|
30 |
+
#We use the smoothed loss to decide on the stopping since it's less shaky.
|
31 |
+
return {'stop_epoch': True, 'stop_training': True}
|
32 |
+
|
33 |
+
def on_train_end(self, **kwargs:Any)->None:
|
34 |
+
"Cleanup learn model weights disturbed during LRFinder exploration."
|
35 |
+
self.learn.load('tmp', purge=False)
|
36 |
+
if hasattr(self.learn.model, 'reset'): self.learn.model.reset()
|
37 |
+
for cb in self.callbacks:
|
38 |
+
if hasattr(cb, 'reset'): cb.reset()
|
39 |
+
print('LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.')
|
DeOldify/fastai/callbacks/mem.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
" Memory profiling callbacks "
|
2 |
+
|
3 |
+
import tracemalloc, threading, torch, time
|
4 |
+
from ..utils.mem import *
|
5 |
+
from ..basic_train import *
|
6 |
+
from ..torch_core import *
|
7 |
+
from ..utils.pynvml_gate import *
|
8 |
+
|
9 |
+
if use_gpu: pynvml = load_pynvml_env()
|
10 |
+
|
11 |
+
class PeakMemMetric(LearnerCallback):
|
12 |
+
"Callback that measures used and peaked general and GPU memory."
|
13 |
+
|
14 |
+
_order=-20 # Needs to run before the recorder
|
15 |
+
|
16 |
+
def __init__(self, learn:Learner):
|
17 |
+
super().__init__(learn)
|
18 |
+
assert torch.cuda.is_available(), "pytorch CUDA is required"
|
19 |
+
preload_pytorch()
|
20 |
+
|
21 |
+
def peak_monitor_start(self):
|
22 |
+
self.peak_monitoring = True
|
23 |
+
|
24 |
+
# start RAM tracing
|
25 |
+
tracemalloc.start()
|
26 |
+
|
27 |
+
# this thread samples RAM usage as long as the current epoch of the fit loop is running
|
28 |
+
peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
|
29 |
+
peak_monitor_thread.daemon = True
|
30 |
+
peak_monitor_thread.start()
|
31 |
+
|
32 |
+
def peak_monitor_stop(self):
|
33 |
+
tracemalloc.stop()
|
34 |
+
self.peak_monitoring = False
|
35 |
+
|
36 |
+
def peak_monitor_func(self):
|
37 |
+
self.gpu_mem_used_peak = -1
|
38 |
+
|
39 |
+
gpu_id = torch.cuda.current_device()
|
40 |
+
gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
|
41 |
+
|
42 |
+
while True:
|
43 |
+
gpu_mem_used = gpu_mem_get_used_fast(gpu_handle)
|
44 |
+
self.gpu_mem_used_peak = max(gpu_mem_used, self.gpu_mem_used_peak)
|
45 |
+
if not self.peak_monitoring: break
|
46 |
+
time.sleep(0.001) # 1msec
|
47 |
+
|
48 |
+
def on_train_begin(self, **kwargs):
|
49 |
+
self.learn.recorder.add_metric_names(['cpu used', 'peak', 'gpu used', 'peak'])
|
50 |
+
|
51 |
+
def on_epoch_begin(self, **kwargs):
|
52 |
+
self.peak_monitor_start()
|
53 |
+
self.gpu_before = gpu_mem_get_used_no_cache()
|
54 |
+
|
55 |
+
def on_epoch_end(self, last_metrics, **kwargs):
|
56 |
+
cpu_used, cpu_peak = list(map(lambda x: int(x/2**20), tracemalloc.get_traced_memory()))
|
57 |
+
self.peak_monitor_stop()
|
58 |
+
gpu_used = gpu_mem_get_used_no_cache() - self.gpu_before
|
59 |
+
gpu_peak = self.gpu_mem_used_peak - self.gpu_before
|
60 |
+
# can be negative, due to unreliable peak monitor thread
|
61 |
+
if gpu_peak < 0: gpu_peak = 0
|
62 |
+
# since we want the overhead only, subtract delta used if it's positive
|
63 |
+
elif gpu_used > 0: gpu_peak -= gpu_used
|
64 |
+
# The numbers are deltas in MBs (beginning of the epoch and the end)
|
65 |
+
return add_metrics(last_metrics, [cpu_used, cpu_peak, gpu_used, gpu_peak])
|
DeOldify/fastai/callbacks/misc.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
" Miscellaneous callbacks "
|
2 |
+
|
3 |
+
from fastai.callback import Callback
|
4 |
+
|
5 |
+
class StopAfterNBatches(Callback):
|
6 |
+
"Stop training after n batches of the first epoch."
|
7 |
+
def __init__(self, n_batches:int=2):
|
8 |
+
self.stop,self.n_batches = False,n_batches-1 # iteration starts from 0
|
9 |
+
|
10 |
+
def on_batch_end(self, iteration, **kwargs):
|
11 |
+
if iteration == self.n_batches:
|
12 |
+
return {'stop_epoch': True, 'stop_training': True, 'skip_validate': True}
|
DeOldify/fastai/callbacks/mixup.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"Implements [mixup](https://arxiv.org/abs/1710.09412) training method"
|
2 |
+
from ..torch_core import *
|
3 |
+
from ..callback import *
|
4 |
+
from ..basic_train import Learner, LearnerCallback
|
5 |
+
|
6 |
+
class MixUpCallback(LearnerCallback):
|
7 |
+
"Callback that creates the mixed-up input and target."
|
8 |
+
def __init__(self, learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True):
|
9 |
+
super().__init__(learn)
|
10 |
+
self.alpha,self.stack_x,self.stack_y = alpha,stack_x,stack_y
|
11 |
+
|
12 |
+
def on_train_begin(self, **kwargs):
|
13 |
+
if self.stack_y: self.learn.loss_func = MixUpLoss(self.learn.loss_func)
|
14 |
+
|
15 |
+
def on_batch_begin(self, last_input, last_target, train, **kwargs):
|
16 |
+
"Applies mixup to `last_input` and `last_target` if `train`."
|
17 |
+
if not train: return
|
18 |
+
lambd = np.random.beta(self.alpha, self.alpha, last_target.size(0))
|
19 |
+
lambd = np.concatenate([lambd[:,None], 1-lambd[:,None]], 1).max(1)
|
20 |
+
lambd = last_input.new(lambd)
|
21 |
+
shuffle = torch.randperm(last_target.size(0)).to(last_input.device)
|
22 |
+
x1, y1 = last_input[shuffle], last_target[shuffle]
|
23 |
+
if self.stack_x:
|
24 |
+
new_input = [last_input, last_input[shuffle], lambd]
|
25 |
+
else:
|
26 |
+
out_shape = [lambd.size(0)] + [1 for _ in range(len(x1.shape) - 1)]
|
27 |
+
new_input = (last_input * lambd.view(out_shape) + x1 * (1-lambd).view(out_shape))
|
28 |
+
if self.stack_y:
|
29 |
+
new_target = torch.cat([last_target[:,None].float(), y1[:,None].float(), lambd[:,None].float()], 1)
|
30 |
+
else:
|
31 |
+
if len(last_target.shape) == 2:
|
32 |
+
lambd = lambd.unsqueeze(1).float()
|
33 |
+
new_target = last_target.float() * lambd + y1.float() * (1-lambd)
|
34 |
+
return {'last_input': new_input, 'last_target': new_target}
|
35 |
+
|
36 |
+
def on_train_end(self, **kwargs):
|
37 |
+
if self.stack_y: self.learn.loss_func = self.learn.loss_func.get_old()
|
38 |
+
|
39 |
+
|
40 |
+
class MixUpLoss(Module):
|
41 |
+
"Adapt the loss function `crit` to go with mixup."
|
42 |
+
|
43 |
+
def __init__(self, crit, reduction='mean'):
|
44 |
+
super().__init__()
|
45 |
+
if hasattr(crit, 'reduction'):
|
46 |
+
self.crit = crit
|
47 |
+
self.old_red = crit.reduction
|
48 |
+
setattr(self.crit, 'reduction', 'none')
|
49 |
+
else:
|
50 |
+
self.crit = partial(crit, reduction='none')
|
51 |
+
self.old_crit = crit
|
52 |
+
self.reduction = reduction
|
53 |
+
|
54 |
+
def forward(self, output, target):
|
55 |
+
if len(target.size()) == 2:
|
56 |
+
loss1, loss2 = self.crit(output,target[:,0].long()), self.crit(output,target[:,1].long())
|
57 |
+
d = (loss1 * target[:,2] + loss2 * (1-target[:,2])).mean()
|
58 |
+
else: d = self.crit(output, target)
|
59 |
+
if self.reduction == 'mean': return d.mean()
|
60 |
+
elif self.reduction == 'sum': return d.sum()
|
61 |
+
return d
|
62 |
+
|
63 |
+
def get_old(self):
|
64 |
+
if hasattr(self, 'old_crit'): return self.old_crit
|
65 |
+
elif hasattr(self, 'old_red'):
|
66 |
+
setattr(self.crit, 'reduction', self.old_red)
|
67 |
+
return self.crit
|
DeOldify/fastai/callbacks/mlflow.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"A `Callback` that saves tracked metrics and notebook file into MLflow server."
|
2 |
+
from ..torch_core import *
|
3 |
+
from ..callback import *
|
4 |
+
from ..basic_train import Learner, LearnerCallback
|
5 |
+
#This is an optional dependency in fastai. Must install separately.
|
6 |
+
try: import mlflow
|
7 |
+
except: print("To use this tracker, please run 'pip install mlflow'")
|
8 |
+
|
9 |
+
class MLFlowTracker(LearnerCallback):
|
10 |
+
"A `TrackerCallback` that tracks the loss and metrics into MLFlow"
|
11 |
+
def __init__(self, learn:Learner, exp_name: str, params: dict, nb_path: str, uri: str = "http://localhost:5000"):
|
12 |
+
super().__init__(learn)
|
13 |
+
self.learn,self.exp_name,self.params,self.nb_path,self.uri = learn,exp_name,params,nb_path,uri
|
14 |
+
self.metrics_names = ['train_loss', 'valid_loss'] + [o.__name__ for o in learn.metrics]
|
15 |
+
|
16 |
+
def on_train_begin(self, **kwargs: Any) -> None:
|
17 |
+
"Prepare MLflow experiment and log params"
|
18 |
+
self.client = mlflow.tracking.MlflowClient(self.uri)
|
19 |
+
exp = self.client.get_experiment_by_name(self.exp_name)
|
20 |
+
self.exp_id = self.client.create_experiment(self.exp_name) if exp is None else exp.experiment_id
|
21 |
+
run = self.client.create_run(experiment_id=self.exp_id)
|
22 |
+
self.run = run.info.run_uuid
|
23 |
+
for k,v in self.params.items():
|
24 |
+
self.client.log_param(run_id=self.run, key=k, value=v)
|
25 |
+
|
26 |
+
def on_epoch_end(self, epoch, **kwargs:Any)->None:
|
27 |
+
"Send loss and metrics values to MLFlow after each epoch"
|
28 |
+
if kwargs['smooth_loss'] is None or kwargs["last_metrics"] is None: return
|
29 |
+
metrics = [kwargs['smooth_loss']] + kwargs["last_metrics"]
|
30 |
+
for name, val in zip(self.metrics_names, metrics):
|
31 |
+
self.client.log_metric(self.run, name, np.float(val), step=epoch)
|
32 |
+
|
33 |
+
def on_train_end(self, **kwargs: Any) -> None:
|
34 |
+
"Store the notebook and stop run"
|
35 |
+
self.client.log_artifact(run_id=self.run, local_path=self.nb_path)
|
36 |
+
self.client.set_terminated(run_id=self.run)
|
DeOldify/fastai/callbacks/one_cycle.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"Supports 1-Cycle style training"
|
2 |
+
from ..core import *
|
3 |
+
from ..callback import *
|
4 |
+
from ..basic_train import Learner,LearnerCallback
|
5 |
+
|
6 |
+
__all__ = ['OneCycleScheduler']
|
7 |
+
|
8 |
+
class OneCycleScheduler(LearnerCallback):
|
9 |
+
"Manage 1-Cycle style training as outlined in Leslie Smith's [paper](https://arxiv.org/pdf/1803.09820.pdf)."
|
10 |
+
def __init__(self, learn:Learner, lr_max:float, moms:Floats=(0.95,0.85), div_factor:float=25., pct_start:float=0.3,
|
11 |
+
final_div:float=None, tot_epochs:int=None, start_epoch:int=None):
|
12 |
+
super().__init__(learn)
|
13 |
+
self.lr_max,self.div_factor,self.pct_start,self.final_div = lr_max,div_factor,pct_start,final_div
|
14 |
+
if self.final_div is None: self.final_div = div_factor*1e4
|
15 |
+
self.moms=tuple(listify(moms,2))
|
16 |
+
if is_listy(self.lr_max): self.lr_max = np.array(self.lr_max)
|
17 |
+
self.start_epoch, self.tot_epochs = start_epoch, tot_epochs
|
18 |
+
|
19 |
+
def steps(self, *steps_cfg:StartOptEnd):
|
20 |
+
"Build anneal schedule for all of the parameters."
|
21 |
+
return [Scheduler(step, n_iter, func=func)
|
22 |
+
for (step,(n_iter,func)) in zip(steps_cfg, self.phases)]
|
23 |
+
|
24 |
+
def on_train_begin(self, n_epochs:int, epoch:int, **kwargs:Any)->None:
|
25 |
+
"Initialize our optimization params based on our annealing schedule."
|
26 |
+
res = {'epoch':self.start_epoch} if self.start_epoch is not None else None
|
27 |
+
self.start_epoch = ifnone(self.start_epoch, epoch)
|
28 |
+
self.tot_epochs = ifnone(self.tot_epochs, n_epochs)
|
29 |
+
n = len(self.learn.data.train_dl) * self.tot_epochs
|
30 |
+
a1 = int(n * self.pct_start)
|
31 |
+
a2 = n-a1
|
32 |
+
self.phases = ((a1, annealing_cos), (a2, annealing_cos))
|
33 |
+
low_lr = self.lr_max/self.div_factor
|
34 |
+
self.lr_scheds = self.steps((low_lr, self.lr_max), (self.lr_max, self.lr_max/self.final_div))
|
35 |
+
self.mom_scheds = self.steps(self.moms, (self.moms[1], self.moms[0]))
|
36 |
+
self.opt = self.learn.opt
|
37 |
+
self.opt.lr,self.opt.mom = self.lr_scheds[0].start,self.mom_scheds[0].start
|
38 |
+
self.idx_s = 0
|
39 |
+
return res
|
40 |
+
|
41 |
+
def jump_to_epoch(self, epoch:int)->None:
|
42 |
+
for _ in range(len(self.learn.data.train_dl) * epoch):
|
43 |
+
self.on_batch_end(True)
|
44 |
+
|
45 |
+
def on_batch_end(self, train, **kwargs:Any)->None:
|
46 |
+
"Take one step forward on the annealing schedule for the optim params."
|
47 |
+
if train:
|
48 |
+
if self.idx_s >= len(self.lr_scheds): return {'stop_training': True, 'stop_epoch': True}
|
49 |
+
self.opt.lr = self.lr_scheds[self.idx_s].step()
|
50 |
+
self.opt.mom = self.mom_scheds[self.idx_s].step()
|
51 |
+
# when the current schedule is complete we move onto the next
|
52 |
+
# schedule. (in 1-cycle there are two schedules)
|
53 |
+
if self.lr_scheds[self.idx_s].is_done:
|
54 |
+
self.idx_s += 1
|
55 |
+
|
56 |
+
def on_epoch_end(self, epoch, **kwargs:Any)->None:
|
57 |
+
"Tell Learner to stop if the cycle is finished."
|
58 |
+
if epoch > self.tot_epochs: return {'stop_training': True}
|
DeOldify/fastai/callbacks/oversampling.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..torch_core import *
|
2 |
+
from ..basic_data import DataBunch
|
3 |
+
from ..callback import *
|
4 |
+
from ..basic_train import Learner,LearnerCallback
|
5 |
+
from torch.utils.data.sampler import WeightedRandomSampler
|
6 |
+
|
7 |
+
__all__ = ['OverSamplingCallback']
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
class OverSamplingCallback(LearnerCallback):
|
12 |
+
def __init__(self,learn:Learner,weights:torch.Tensor=None):
|
13 |
+
super().__init__(learn)
|
14 |
+
self.labels = self.learn.data.train_dl.dataset.y.items
|
15 |
+
_, counts = np.unique(self.labels,return_counts=True)
|
16 |
+
self.weights = (weights if weights is not None else
|
17 |
+
torch.DoubleTensor((1/counts)[self.labels]))
|
18 |
+
self.label_counts = np.bincount([self.learn.data.train_dl.dataset.y[i].data for i in range(len(self.learn.data.train_dl.dataset))])
|
19 |
+
self.total_len_oversample = int(self.learn.data.c*np.max(self.label_counts))
|
20 |
+
|
21 |
+
def on_train_begin(self, **kwargs):
|
22 |
+
self.learn.data.train_dl.dl.batch_sampler = BatchSampler(WeightedRandomSampler(self.weights,self.total_len_oversample), self.learn.data.train_dl.batch_size,False)
|
DeOldify/fastai/callbacks/rnn.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"Regroups lr adjustment to seq_len, AR and TAR"
|
2 |
+
from ..torch_core import *
|
3 |
+
from ..callback import *
|
4 |
+
from ..basic_train import Learner, LearnerCallback
|
5 |
+
|
6 |
+
__all__ = ['RNNTrainer']
|
7 |
+
|
8 |
+
class RNNTrainer(LearnerCallback):
|
9 |
+
"`Callback` that regroups lr adjustment to seq_len, AR and TAR."
|
10 |
+
def __init__(self, learn:Learner, alpha:float=0., beta:float=0.):
|
11 |
+
super().__init__(learn)
|
12 |
+
self.not_min += ['raw_out', 'out']
|
13 |
+
self.alpha,self.beta = alpha,beta
|
14 |
+
|
15 |
+
def on_epoch_begin(self, **kwargs):
|
16 |
+
"Reset the hidden state of the model."
|
17 |
+
self.learn.model.reset()
|
18 |
+
|
19 |
+
def on_loss_begin(self, last_output:Tuple[Tensor,Tensor,Tensor], **kwargs):
|
20 |
+
"Save the extra outputs for later and only returns the true output."
|
21 |
+
self.raw_out,self.out = last_output[1],last_output[2]
|
22 |
+
return {'last_output': last_output[0]}
|
23 |
+
|
24 |
+
def on_backward_begin(self, last_loss:Rank0Tensor, last_input:Tensor, **kwargs):
|
25 |
+
"Apply AR and TAR to `last_loss`."
|
26 |
+
#AR and TAR
|
27 |
+
if self.alpha != 0.: last_loss += self.alpha * self.out[-1].float().pow(2).mean()
|
28 |
+
if self.beta != 0.:
|
29 |
+
h = self.raw_out[-1]
|
30 |
+
if len(h)>1: last_loss += self.beta * (h[:,1:] - h[:,:-1]).float().pow(2).mean()
|
31 |
+
return {'last_loss': last_loss}
|
DeOldify/fastai/callbacks/tensorboard.py
ADDED
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"Provides convenient callbacks for Learners that write model images, metrics/losses, stats and histograms to Tensorboard"
|
2 |
+
from ..basic_train import Learner
|
3 |
+
from ..basic_data import DatasetType, DataBunch
|
4 |
+
from ..vision import Image
|
5 |
+
from ..vision.gan import GANLearner
|
6 |
+
from ..callbacks import LearnerCallback
|
7 |
+
from ..core import *
|
8 |
+
from ..torch_core import *
|
9 |
+
from threading import Thread, Event
|
10 |
+
from time import sleep
|
11 |
+
from queue import Queue
|
12 |
+
import statistics
|
13 |
+
import torchvision.utils as vutils
|
14 |
+
from abc import ABC
|
15 |
+
#This is an optional dependency in fastai. Must install separately.
|
16 |
+
try: from tensorboardX import SummaryWriter
|
17 |
+
except: print("To use this tracker, please run 'pip install tensorboardx'. Also you must have Tensorboard running to see results")
|
18 |
+
|
19 |
+
__all__=['LearnerTensorboardWriter', 'GANTensorboardWriter', 'ImageGenTensorboardWriter']
|
20 |
+
|
21 |
+
#---Example usage (applies to any of the callbacks)---
|
22 |
+
# proj_id = 'Colorize'
|
23 |
+
# tboard_path = Path('data/tensorboard/' + proj_id)
|
24 |
+
# learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=tboard_path, name='GanLearner'))
|
25 |
+
|
26 |
+
class LearnerTensorboardWriter(LearnerCallback):
|
27 |
+
"Broadly useful callback for Learners that writes to Tensorboard. Writes model histograms, losses/metrics, and gradient stats."
|
28 |
+
def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100):
|
29 |
+
super().__init__(learn=learn)
|
30 |
+
self.base_dir,self.name,self.loss_iters,self.hist_iters,self.stats_iters = base_dir,name,loss_iters,hist_iters,stats_iters
|
31 |
+
log_dir = base_dir/name
|
32 |
+
self.tbwriter = SummaryWriter(str(log_dir))
|
33 |
+
self.hist_writer = HistogramTBWriter()
|
34 |
+
self.stats_writer = ModelStatsTBWriter()
|
35 |
+
#self.graph_writer = GraphTBWriter()
|
36 |
+
self.data = None
|
37 |
+
self.metrics_root = '/metrics/'
|
38 |
+
self._update_batches_if_needed()
|
39 |
+
|
40 |
+
def _get_new_batch(self, ds_type:DatasetType)->Collection[Tensor]:
|
41 |
+
"Retrieves new batch of DatasetType, and detaches it."
|
42 |
+
return self.learn.data.one_batch(ds_type=ds_type, detach=True, denorm=False, cpu=False)
|
43 |
+
|
44 |
+
def _update_batches_if_needed(self)->None:
|
45 |
+
"one_batch function is extremely slow with large datasets. This is caching the result as an optimization."
|
46 |
+
if self.learn.data.valid_dl is None: return # Running learning rate finder, so return
|
47 |
+
update_batches = self.data is not self.learn.data
|
48 |
+
if not update_batches: return
|
49 |
+
self.data = self.learn.data
|
50 |
+
self.trn_batch = self._get_new_batch(ds_type=DatasetType.Train)
|
51 |
+
self.val_batch = self._get_new_batch(ds_type=DatasetType.Valid)
|
52 |
+
|
53 |
+
def _write_model_stats(self, iteration:int)->None:
|
54 |
+
"Writes gradient statistics to Tensorboard."
|
55 |
+
self.stats_writer.write(model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)
|
56 |
+
|
57 |
+
def _write_training_loss(self, iteration:int, last_loss:Tensor)->None:
|
58 |
+
"Writes training loss to Tensorboard."
|
59 |
+
scalar_value = to_np(last_loss)
|
60 |
+
tag = self.metrics_root + 'train_loss'
|
61 |
+
self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
|
62 |
+
|
63 |
+
def _write_weight_histograms(self, iteration:int)->None:
|
64 |
+
"Writes model weight histograms to Tensorboard."
|
65 |
+
self.hist_writer.write(model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)
|
66 |
+
|
67 |
+
def _write_scalar(self, name:str, scalar_value, iteration:int)->None:
|
68 |
+
"Writes single scalar value to Tensorboard."
|
69 |
+
tag = self.metrics_root + name
|
70 |
+
self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
|
71 |
+
|
72 |
+
#TODO: Relying on a specific hardcoded start_idx here isn't great. Is there a better solution?
|
73 |
+
def _write_metrics(self, iteration:int, last_metrics:MetricsList, start_idx:int=2)->None:
|
74 |
+
"Writes training metrics to Tensorboard."
|
75 |
+
recorder = self.learn.recorder
|
76 |
+
for i, name in enumerate(recorder.names[start_idx:]):
|
77 |
+
if last_metrics is None or len(last_metrics) < i+1: return
|
78 |
+
scalar_value = last_metrics[i]
|
79 |
+
self._write_scalar(name=name, scalar_value=scalar_value, iteration=iteration)
|
80 |
+
|
81 |
+
def on_train_begin(self, **kwargs: Any) -> None:
|
82 |
+
#self.graph_writer.write(model=self.learn.model, tbwriter=self.tbwriter,
|
83 |
+
#input_to_model=next(iter(self.learn.data.dl(DatasetType.Single)))[0])
|
84 |
+
return
|
85 |
+
|
86 |
+
def on_batch_end(self, last_loss:Tensor, iteration:int, **kwargs)->None:
|
87 |
+
"Callback function that writes batch end appropriate data to Tensorboard."
|
88 |
+
if iteration == 0: return
|
89 |
+
self._update_batches_if_needed()
|
90 |
+
if iteration % self.loss_iters == 0: self._write_training_loss(iteration=iteration, last_loss=last_loss)
|
91 |
+
if iteration % self.hist_iters == 0: self._write_weight_histograms(iteration=iteration)
|
92 |
+
|
93 |
+
# Doing stuff here that requires gradient info, because they get zeroed out afterwards in training loop
|
94 |
+
def on_backward_end(self, iteration:int, **kwargs)->None:
|
95 |
+
"Callback function that writes backward end appropriate data to Tensorboard."
|
96 |
+
if iteration == 0: return
|
97 |
+
self._update_batches_if_needed()
|
98 |
+
if iteration % self.stats_iters == 0: self._write_model_stats(iteration=iteration)
|
99 |
+
|
100 |
+
def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs)->None:
|
101 |
+
"Callback function that writes epoch end appropriate data to Tensorboard."
|
102 |
+
self._write_metrics(iteration=iteration, last_metrics=last_metrics)
|
103 |
+
|
104 |
+
# TODO: We're overriding almost everything here. Seems like a good idea to question that ("is a" vs "has a")
|
105 |
+
class GANTensorboardWriter(LearnerTensorboardWriter):
|
106 |
+
"Callback for GANLearners that writes to Tensorboard. Extends LearnerTensorboardWriter and adds output image writes."
|
107 |
+
def __init__(self, learn:GANLearner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
|
108 |
+
stats_iters:int=100, visual_iters:int=100):
|
109 |
+
super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, hist_iters=hist_iters, stats_iters=stats_iters)
|
110 |
+
self.visual_iters = visual_iters
|
111 |
+
self.img_gen_vis = ImageTBWriter()
|
112 |
+
self.gen_stats_updated = True
|
113 |
+
self.crit_stats_updated = True
|
114 |
+
|
115 |
+
def _write_weight_histograms(self, iteration:int)->None:
|
116 |
+
"Writes model weight histograms to Tensorboard."
|
117 |
+
generator, critic = self.learn.gan_trainer.generator, self.learn.gan_trainer.critic
|
118 |
+
self.hist_writer.write(model=generator, iteration=iteration, tbwriter=self.tbwriter, name='generator')
|
119 |
+
self.hist_writer.write(model=critic, iteration=iteration, tbwriter=self.tbwriter, name='critic')
|
120 |
+
|
121 |
+
def _write_gen_model_stats(self, iteration:int)->None:
|
122 |
+
"Writes gradient statistics for generator to Tensorboard."
|
123 |
+
generator = self.learn.gan_trainer.generator
|
124 |
+
self.stats_writer.write(model=generator, iteration=iteration, tbwriter=self.tbwriter, name='gen_model_stats')
|
125 |
+
self.gen_stats_updated = True
|
126 |
+
|
127 |
+
def _write_critic_model_stats(self, iteration:int)->None:
|
128 |
+
"Writes gradient statistics for critic to Tensorboard."
|
129 |
+
critic = self.learn.gan_trainer.critic
|
130 |
+
self.stats_writer.write(model=critic, iteration=iteration, tbwriter=self.tbwriter, name='crit_model_stats')
|
131 |
+
self.crit_stats_updated = True
|
132 |
+
|
133 |
+
def _write_model_stats(self, iteration:int)->None:
|
134 |
+
"Writes gradient statistics to Tensorboard."
|
135 |
+
# We don't want to write stats when model is not iterated on and hence has zeroed out gradients
|
136 |
+
gen_mode = self.learn.gan_trainer.gen_mode
|
137 |
+
if gen_mode and not self.gen_stats_updated: self._write_gen_model_stats(iteration=iteration)
|
138 |
+
if not gen_mode and not self.crit_stats_updated: self._write_critic_model_stats(iteration=iteration)
|
139 |
+
|
140 |
+
def _write_training_loss(self, iteration:int, last_loss:Tensor)->None:
|
141 |
+
"Writes training loss to Tensorboard."
|
142 |
+
recorder = self.learn.gan_trainer.recorder
|
143 |
+
if len(recorder.losses) == 0: return
|
144 |
+
scalar_value = to_np((recorder.losses[-1:])[0])
|
145 |
+
tag = self.metrics_root + 'train_loss'
|
146 |
+
self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
|
147 |
+
|
148 |
+
def _write_images(self, iteration:int)->None:
|
149 |
+
"Writes model generated, original and real images to Tensorboard."
|
150 |
+
trainer = self.learn.gan_trainer
|
151 |
+
#TODO: Switching gen_mode temporarily seems a bit hacky here. Certainly not a good side-effect. Is there a better way?
|
152 |
+
gen_mode = trainer.gen_mode
|
153 |
+
try:
|
154 |
+
trainer.switch(gen_mode=True)
|
155 |
+
self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
|
156 |
+
iteration=iteration, tbwriter=self.tbwriter)
|
157 |
+
finally: trainer.switch(gen_mode=gen_mode)
|
158 |
+
|
159 |
+
def on_batch_end(self, iteration:int, **kwargs)->None:
|
160 |
+
"Callback function that writes batch end appropriate data to Tensorboard."
|
161 |
+
super().on_batch_end(iteration=iteration, **kwargs)
|
162 |
+
if iteration == 0: return
|
163 |
+
if iteration % self.visual_iters == 0: self._write_images(iteration=iteration)
|
164 |
+
|
165 |
+
def on_backward_end(self, iteration:int, **kwargs)->None:
|
166 |
+
"Callback function that writes backward end appropriate data to Tensorboard."
|
167 |
+
if iteration == 0: return
|
168 |
+
self._update_batches_if_needed()
|
169 |
+
#TODO: This could perhaps be implemented as queues of requests instead but that seemed like overkill.
|
170 |
+
# But I'm not the biggest fan of maintaining these boolean flags either... Review pls.
|
171 |
+
if iteration % self.stats_iters == 0: self.gen_stats_updated, self.crit_stats_updated = False, False
|
172 |
+
if not (self.gen_stats_updated and self.crit_stats_updated): self._write_model_stats(iteration=iteration)
|
173 |
+
|
174 |
+
class ImageGenTensorboardWriter(LearnerTensorboardWriter):
|
175 |
+
"Callback for non-GAN image generating Learners that writes to Tensorboard. Extends LearnerTensorboardWriter and adds output image writes."
|
176 |
+
def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100,
|
177 |
+
visual_iters:int=100):
|
178 |
+
super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, hist_iters=hist_iters,
|
179 |
+
stats_iters=stats_iters)
|
180 |
+
self.visual_iters = visual_iters
|
181 |
+
self.img_gen_vis = ImageTBWriter()
|
182 |
+
|
183 |
+
def _write_images(self, iteration:int)->None:
|
184 |
+
"Writes model generated, original and real images to Tensorboard"
|
185 |
+
self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch, iteration=iteration,
|
186 |
+
tbwriter=self.tbwriter)
|
187 |
+
|
188 |
+
def on_batch_end(self, iteration:int, **kwargs)->None:
|
189 |
+
"Callback function that writes batch end appropriate data to Tensorboard."
|
190 |
+
super().on_batch_end(iteration=iteration, **kwargs)
|
191 |
+
if iteration == 0: return
|
192 |
+
if iteration % self.visual_iters == 0:
|
193 |
+
self._write_images(iteration=iteration)
|
194 |
+
|
195 |
+
class TBWriteRequest(ABC):
|
196 |
+
"A request object for Tensorboard writes. Useful for queuing up and executing asynchronous writes."
|
197 |
+
def __init__(self, tbwriter: SummaryWriter, iteration:int):
|
198 |
+
super().__init__()
|
199 |
+
self.tbwriter = tbwriter
|
200 |
+
self.iteration = iteration
|
201 |
+
|
202 |
+
@abstractmethod
|
203 |
+
def write(self)->None: pass
|
204 |
+
|
205 |
+
# SummaryWriter writes tend to block quite a bit. This gets around that and greatly boosts performance.
|
206 |
+
# Not all tensorboard writes are using this- just the ones that take a long time. Note that the
|
207 |
+
# SummaryWriter does actually use a threadsafe consumer/producer design ultimately to write to Tensorboard,
|
208 |
+
# so writes done outside of this async loop should be fine.
|
209 |
+
class AsyncTBWriter():
|
210 |
+
"Callback for GANLearners that writes to Tensorboard. Extends LearnerTensorboardWriter and adds output image writes."
|
211 |
+
def __init__(self):
|
212 |
+
super().__init__()
|
213 |
+
self.stop_request = Event()
|
214 |
+
self.queue = Queue()
|
215 |
+
self.thread = Thread(target=self._queue_processor, daemon=True)
|
216 |
+
self.thread.start()
|
217 |
+
|
218 |
+
def request_write(self, request: TBWriteRequest)->None:
|
219 |
+
"Queues up an asynchronous write request to Tensorboard."
|
220 |
+
if self.stop_request.isSet(): return
|
221 |
+
self.queue.put(request)
|
222 |
+
|
223 |
+
def _queue_processor(self)->None:
|
224 |
+
"Processes queued up write requests asynchronously to Tensorboard."
|
225 |
+
while not self.stop_request.isSet():
|
226 |
+
while not self.queue.empty():
|
227 |
+
if self.stop_request.isSet(): return
|
228 |
+
request = self.queue.get()
|
229 |
+
request.write()
|
230 |
+
sleep(0.2)
|
231 |
+
|
232 |
+
#Provided this to stop thread explicitly or by context management (with statement) but thread should end on its own
|
233 |
+
# upon program exit, due to being a daemon. So using this is probably unecessary.
|
234 |
+
def close(self)->None:
|
235 |
+
"Stops asynchronous request queue processing thread."
|
236 |
+
self.stop_request.set()
|
237 |
+
self.thread.join()
|
238 |
+
|
239 |
+
# Nothing to do, thread already started. Could start thread here to enforce use of context manager
|
240 |
+
# (but that sounds like a pain and a bit unweildy and unecessary for actual usage)
|
241 |
+
def __enter__(self): pass
|
242 |
+
|
243 |
+
def __exit__(self, exc_type, exc_value, traceback): self.close()
|
244 |
+
|
245 |
+
asyncTBWriter = AsyncTBWriter()
|
246 |
+
|
247 |
+
class ModelImageSet():
|
248 |
+
"Convenience object that holds the original, real(target) and generated versions of a single image fed to a model."
|
249 |
+
@staticmethod
|
250 |
+
def get_list_from_model(learn:Learner, ds_type:DatasetType, batch:Tuple)->[]:
|
251 |
+
"Factory method to convert a batch of model images to a list of ModelImageSet."
|
252 |
+
image_sets = []
|
253 |
+
x,y = batch[0],batch[1]
|
254 |
+
preds=[]
|
255 |
+
preds = learn.pred_batch(ds_type=ds_type, batch=(x,y), reconstruct=True)
|
256 |
+
for orig_px, real_px, gen in zip(x,y,preds):
|
257 |
+
orig, real = Image(px=orig_px), Image(px=real_px)
|
258 |
+
image_set = ModelImageSet(orig=orig, real=real, gen=gen)
|
259 |
+
image_sets.append(image_set)
|
260 |
+
return image_sets
|
261 |
+
|
262 |
+
def __init__(self, orig:Image, real:Image, gen:Image): self.orig, self.real, self.gen = orig, real, gen
|
263 |
+
|
264 |
+
class HistogramTBRequest(TBWriteRequest):
|
265 |
+
"Request object for model histogram writes to Tensorboard."
|
266 |
+
def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):
|
267 |
+
super().__init__(tbwriter=tbwriter, iteration=iteration)
|
268 |
+
self.params = [(name, values.clone().detach().cpu()) for (name, values) in model.named_parameters()]
|
269 |
+
self.name = name
|
270 |
+
|
271 |
+
def _write_histogram(self, param_name:str, values)->None:
|
272 |
+
"Writes single model histogram to Tensorboard."
|
273 |
+
tag = self.name + '/weights/' + param_name
|
274 |
+
self.tbwriter.add_histogram(tag=tag, values=values, global_step=self.iteration)
|
275 |
+
|
276 |
+
def write(self)->None:
|
277 |
+
"Writes model histograms to Tensorboard."
|
278 |
+
for param_name, values in self.params: self._write_histogram(param_name=param_name, values=values)
|
279 |
+
|
280 |
+
#If this isn't done async then this is sloooooow
|
281 |
+
class HistogramTBWriter():
|
282 |
+
"Writes model histograms to Tensorboard."
|
283 |
+
def __init__(self): super().__init__()
|
284 |
+
|
285 |
+
def write(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model')->None:
|
286 |
+
"Writes model histograms to Tensorboard."
|
287 |
+
request = HistogramTBRequest(model=model, iteration=iteration, tbwriter=tbwriter, name=name)
|
288 |
+
asyncTBWriter.request_write(request)
|
289 |
+
|
290 |
+
class ModelStatsTBRequest(TBWriteRequest):
|
291 |
+
"Request object for model gradient statistics writes to Tensorboard."
|
292 |
+
def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):
|
293 |
+
super().__init__(tbwriter=tbwriter, iteration=iteration)
|
294 |
+
self.gradients = [x.grad.clone().detach().cpu() for x in model.parameters() if x.grad is not None]
|
295 |
+
self.name = name
|
296 |
+
|
297 |
+
def _add_gradient_scalar(self, name:str, scalar_value)->None:
|
298 |
+
"Writes a single scalar value for a gradient statistic to Tensorboard."
|
299 |
+
tag = self.name + '/gradients/' + name
|
300 |
+
self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=self.iteration)
|
301 |
+
|
302 |
+
def _write_avg_norm(self, norms:[])->None:
|
303 |
+
"Writes the average norm of the gradients to Tensorboard."
|
304 |
+
avg_norm = sum(norms)/len(self.gradients)
|
305 |
+
self._add_gradient_scalar('avg_norm', scalar_value=avg_norm)
|
306 |
+
|
307 |
+
def _write_median_norm(self, norms:[])->None:
|
308 |
+
"Writes the median norm of the gradients to Tensorboard."
|
309 |
+
median_norm = statistics.median(norms)
|
310 |
+
self._add_gradient_scalar('median_norm', scalar_value=median_norm)
|
311 |
+
|
312 |
+
def _write_max_norm(self, norms:[])->None:
|
313 |
+
"Writes the maximum norm of the gradients to Tensorboard."
|
314 |
+
max_norm = max(norms)
|
315 |
+
self._add_gradient_scalar('max_norm', scalar_value=max_norm)
|
316 |
+
|
317 |
+
def _write_min_norm(self, norms:[])->None:
|
318 |
+
"Writes the minimum norm of the gradients to Tensorboard."
|
319 |
+
min_norm = min(norms)
|
320 |
+
self._add_gradient_scalar('min_norm', scalar_value=min_norm)
|
321 |
+
|
322 |
+
def _write_num_zeros(self)->None:
|
323 |
+
"Writes the number of zeroes in the gradients to Tensorboard."
|
324 |
+
gradient_nps = [to_np(x.data) for x in self.gradients]
|
325 |
+
num_zeros = sum((np.asarray(x) == 0.0).sum() for x in gradient_nps)
|
326 |
+
self._add_gradient_scalar('num_zeros', scalar_value=num_zeros)
|
327 |
+
|
328 |
+
def _write_avg_gradient(self)->None:
|
329 |
+
"Writes the average of the gradients to Tensorboard."
|
330 |
+
avg_gradient = sum(x.data.mean() for x in self.gradients)/len(self.gradients)
|
331 |
+
self._add_gradient_scalar('avg_gradient', scalar_value=avg_gradient)
|
332 |
+
|
333 |
+
def _write_median_gradient(self)->None:
|
334 |
+
"Writes the median of the gradients to Tensorboard."
|
335 |
+
median_gradient = statistics.median(x.data.median() for x in self.gradients)
|
336 |
+
self._add_gradient_scalar('median_gradient', scalar_value=median_gradient)
|
337 |
+
|
338 |
+
def _write_max_gradient(self)->None:
|
339 |
+
"Writes the maximum of the gradients to Tensorboard."
|
340 |
+
max_gradient = max(x.data.max() for x in self.gradients)
|
341 |
+
self._add_gradient_scalar('max_gradient', scalar_value=max_gradient)
|
342 |
+
|
343 |
+
def _write_min_gradient(self)->None:
|
344 |
+
"Writes the minimum of the gradients to Tensorboard."
|
345 |
+
min_gradient = min(x.data.min() for x in self.gradients)
|
346 |
+
self._add_gradient_scalar('min_gradient', scalar_value=min_gradient)
|
347 |
+
|
348 |
+
def write(self)->None:
|
349 |
+
"Writes model gradient statistics to Tensorboard."
|
350 |
+
if len(self.gradients) == 0: return
|
351 |
+
norms = [x.data.norm() for x in self.gradients]
|
352 |
+
self._write_avg_norm(norms=norms)
|
353 |
+
self._write_median_norm(norms=norms)
|
354 |
+
self._write_max_norm(norms=norms)
|
355 |
+
self._write_min_norm(norms=norms)
|
356 |
+
self._write_num_zeros()
|
357 |
+
self._write_avg_gradient()
|
358 |
+
self._write_median_gradient()
|
359 |
+
self._write_max_gradient()
|
360 |
+
self._write_min_gradient()
|
361 |
+
|
362 |
+
class ModelStatsTBWriter():
|
363 |
+
"Writes model gradient statistics to Tensorboard."
|
364 |
+
def write(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model_stats')->None:
|
365 |
+
"Writes model gradient statistics to Tensorboard."
|
366 |
+
request = ModelStatsTBRequest(model=model, iteration=iteration, tbwriter=tbwriter, name=name)
|
367 |
+
asyncTBWriter.request_write(request)
|
368 |
+
|
369 |
+
class ImageTBRequest(TBWriteRequest):
|
370 |
+
"Request object for model image output writes to Tensorboard."
|
371 |
+
def __init__(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType):
|
372 |
+
super().__init__(tbwriter=tbwriter, iteration=iteration)
|
373 |
+
self.image_sets = ModelImageSet.get_list_from_model(learn=learn, batch=batch, ds_type=ds_type)
|
374 |
+
self.ds_type = ds_type
|
375 |
+
|
376 |
+
def _write_images(self, name:str, images:[Tensor])->None:
|
377 |
+
"Writes list of images as tensors to Tensorboard."
|
378 |
+
tag = self.ds_type.name + ' ' + name
|
379 |
+
self.tbwriter.add_image(tag=tag, img_tensor=vutils.make_grid(images, normalize=True), global_step=self.iteration)
|
380 |
+
|
381 |
+
def _get_image_tensors(self)->([Tensor], [Tensor], [Tensor]):
|
382 |
+
"Gets list of image tensors from lists of Image objects, as a tuple of original, generated and real(target) images."
|
383 |
+
orig_images, gen_images, real_images = [], [], []
|
384 |
+
for image_set in self.image_sets:
|
385 |
+
orig_images.append(image_set.orig.px)
|
386 |
+
gen_images.append(image_set.gen.px)
|
387 |
+
real_images.append(image_set.real.px)
|
388 |
+
return orig_images, gen_images, real_images
|
389 |
+
|
390 |
+
def write(self)->None:
|
391 |
+
"Writes original, generated and real(target) images to Tensorboard."
|
392 |
+
orig_images, gen_images, real_images = self._get_image_tensors()
|
393 |
+
self._write_images(name='orig images', images=orig_images)
|
394 |
+
self._write_images(name='gen images', images=gen_images)
|
395 |
+
self._write_images(name='real images', images=real_images)
|
396 |
+
|
397 |
+
#If this isn't done async then this is noticeably slower
|
398 |
+
class ImageTBWriter():
|
399 |
+
"Writes model image output to Tensorboard."
|
400 |
+
def __init__(self): super().__init__()
|
401 |
+
|
402 |
+
def write(self, learn:Learner, trn_batch:Tuple, val_batch:Tuple, iteration:int, tbwriter:SummaryWriter)->None:
|
403 |
+
"Writes training and validation batch images to Tensorboard."
|
404 |
+
self._write_for_dstype(learn=learn, batch=val_batch, iteration=iteration, tbwriter=tbwriter, ds_type=DatasetType.Valid)
|
405 |
+
self._write_for_dstype(learn=learn, batch=trn_batch, iteration=iteration, tbwriter=tbwriter, ds_type=DatasetType.Train)
|
406 |
+
|
407 |
+
def _write_for_dstype(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType)->None:
|
408 |
+
"Writes batch images of specified DatasetType to Tensorboard."
|
409 |
+
request = ImageTBRequest(learn=learn, batch=batch, iteration=iteration, tbwriter=tbwriter, ds_type=ds_type)
|
410 |
+
asyncTBWriter.request_write(request)
|
411 |
+
|
412 |
+
class GraphTBRequest(TBWriteRequest):
|
413 |
+
"Request object for model histogram writes to Tensorboard."
|
414 |
+
def __init__(self, model:nn.Module, tbwriter:SummaryWriter, input_to_model:torch.Tensor):
|
415 |
+
super().__init__(tbwriter=tbwriter, iteration=0)
|
416 |
+
self.model,self.input_to_model = model,input_to_model
|
417 |
+
|
418 |
+
def write(self)->None:
|
419 |
+
"Writes single model graph to Tensorboard."
|
420 |
+
self.tbwriter.add_graph(model=self.model, input_to_model=self.input_to_model)
|
421 |
+
|
422 |
+
class GraphTBWriter():
|
423 |
+
"Writes model network graph to Tensorboard."
|
424 |
+
def write(self, model:nn.Module, tbwriter:SummaryWriter, input_to_model:torch.Tensor)->None:
|
425 |
+
"Writes model graph to Tensorboard."
|
426 |
+
request = GraphTBRequest(model=model, tbwriter=tbwriter, input_to_model=input_to_model)
|
427 |
+
asyncTBWriter.request_write(request)
|
DeOldify/fastai/callbacks/tracker.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contribution from @fredguth, https://github.com/fredguth/fastai_playground.
|
2 |
+
|
3 |
+
from fastai.torch_core import *
|
4 |
+
from fastai.callback import *
|
5 |
+
from fastai.basic_train import *
|
6 |
+
|
7 |
+
__all__ = ['TerminateOnNaNCallback', 'EarlyStoppingCallback', 'SaveModelCallback', 'TrackerCallback',
|
8 |
+
'ReduceLROnPlateauCallback', 'TrackEpochCallback' ]
|
9 |
+
|
10 |
+
class TerminateOnNaNCallback(Callback):
|
11 |
+
"A `Callback` that terminates training if loss is NaN."
|
12 |
+
|
13 |
+
def __init__(self):
|
14 |
+
self.stop = False
|
15 |
+
|
16 |
+
def on_batch_end(self, last_loss, epoch, num_batch, **kwargs:Any)->None:
|
17 |
+
"Test if `last_loss` is NaN and interrupts training."
|
18 |
+
if self.stop: return True #to skip validation after stopping during training
|
19 |
+
if torch.isnan(last_loss):
|
20 |
+
print (f'Epoch/Batch ({epoch}/{num_batch}): Invalid loss, terminating training.')
|
21 |
+
return {'stop_epoch': True, 'stop_training': True, 'skip_validate': True}
|
22 |
+
|
23 |
+
class TrackerCallback(LearnerCallback):
|
24 |
+
"A `LearnerCallback` that keeps track of the best value in `monitor`."
|
25 |
+
def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto'):
|
26 |
+
super().__init__(learn)
|
27 |
+
self.monitor,self.mode = monitor,mode
|
28 |
+
if self.mode not in ['auto', 'min', 'max']:
|
29 |
+
warn(f'{self.__class__} mode {self.mode} is invalid, falling back to "auto" mode.')
|
30 |
+
self.mode = 'auto'
|
31 |
+
mode_dict = {'min': np.less, 'max':np.greater}
|
32 |
+
mode_dict['auto'] = np.less if 'loss' in self.monitor else np.greater
|
33 |
+
self.operator = mode_dict[self.mode]
|
34 |
+
|
35 |
+
def on_train_begin(self, **kwargs:Any)->None:
|
36 |
+
"Initializes the best value."
|
37 |
+
self.best = float('inf') if self.operator == np.less else -float('inf')
|
38 |
+
|
39 |
+
def get_monitor_value(self):
|
40 |
+
"Pick the monitored value."
|
41 |
+
if self.monitor=='trn_loss' and len(self.learn.recorder.losses) == 0: return None
|
42 |
+
elif len(self.learn.recorder.val_losses) == 0: return None
|
43 |
+
values = {'train_loss':self.learn.recorder.losses[-1].cpu().numpy(),
|
44 |
+
'valid_loss':self.learn.recorder.val_losses[-1]}
|
45 |
+
if values['valid_loss'] is None: return
|
46 |
+
if self.learn.recorder.metrics:
|
47 |
+
for m, n in zip(self.learn.recorder.metrics[-1],self.learn.recorder.names[3:-1]):
|
48 |
+
values[n] = m
|
49 |
+
if values.get(self.monitor) is None:
|
50 |
+
warn(f'{self.__class__} conditioned on metric `{self.monitor}` which is not available. Available metrics are: {", ".join(map(str, self.learn.recorder.names[1:-1]))}')
|
51 |
+
return values.get(self.monitor)
|
52 |
+
|
53 |
+
class EarlyStoppingCallback(TrackerCallback):
|
54 |
+
"A `TrackerCallback` that terminates training when monitored quantity stops improving."
|
55 |
+
def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto', min_delta:int=0, patience:int=0):
|
56 |
+
super().__init__(learn, monitor=monitor, mode=mode)
|
57 |
+
self.min_delta,self.patience = min_delta,patience
|
58 |
+
if self.operator == np.less: self.min_delta *= -1
|
59 |
+
|
60 |
+
def on_train_begin(self, **kwargs:Any)->None:
|
61 |
+
"Initialize inner arguments."
|
62 |
+
self.wait = 0
|
63 |
+
super().on_train_begin(**kwargs)
|
64 |
+
|
65 |
+
def on_epoch_end(self, epoch, **kwargs:Any)->None:
|
66 |
+
"Compare the value monitored to its best score and maybe stop training."
|
67 |
+
current = self.get_monitor_value()
|
68 |
+
if current is None: return
|
69 |
+
if self.operator(current - self.min_delta, self.best):
|
70 |
+
self.best,self.wait = current,0
|
71 |
+
else:
|
72 |
+
self.wait += 1
|
73 |
+
if self.wait > self.patience:
|
74 |
+
print(f'Epoch {epoch}: early stopping')
|
75 |
+
return {"stop_training":True}
|
76 |
+
|
77 |
+
class SaveModelCallback(TrackerCallback):
|
78 |
+
"A `TrackerCallback` that saves the model when monitored quantity is best."
|
79 |
+
def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto', every:str='improvement', name:str='bestmodel'):
|
80 |
+
super().__init__(learn, monitor=monitor, mode=mode)
|
81 |
+
self.every,self.name = every,name
|
82 |
+
if self.every not in ['improvement', 'epoch']:
|
83 |
+
warn(f'SaveModel every {self.every} is invalid, falling back to "improvement".')
|
84 |
+
self.every = 'improvement'
|
85 |
+
|
86 |
+
def jump_to_epoch(self, epoch:int)->None:
|
87 |
+
try:
|
88 |
+
self.learn.load(f'{self.name}_{epoch-1}', purge=False)
|
89 |
+
print(f"Loaded {self.name}_{epoch-1}")
|
90 |
+
except: print(f'Model {self.name}_{epoch-1} not found.')
|
91 |
+
|
92 |
+
def on_epoch_end(self, epoch:int, **kwargs:Any)->None:
|
93 |
+
"Compare the value monitored to its best score and maybe save the model."
|
94 |
+
if self.every=="epoch": self.learn.save(f'{self.name}_{epoch}')
|
95 |
+
else: #every="improvement"
|
96 |
+
current = self.get_monitor_value()
|
97 |
+
if current is not None and self.operator(current, self.best):
|
98 |
+
print(f'Better model found at epoch {epoch} with {self.monitor} value: {current}.')
|
99 |
+
self.best = current
|
100 |
+
self.learn.save(f'{self.name}')
|
101 |
+
|
102 |
+
def on_train_end(self, **kwargs):
|
103 |
+
"Load the best model."
|
104 |
+
if self.every=="improvement" and (self.learn.path/f'{self.learn.model_dir}/{self.name}.pth').is_file():
|
105 |
+
self.learn.load(f'{self.name}', purge=False)
|
106 |
+
|
107 |
+
class ReduceLROnPlateauCallback(TrackerCallback):
|
108 |
+
"A `TrackerCallback` that reduces learning rate when a metric has stopped improving."
|
109 |
+
def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto', patience:int=0, factor:float=0.2,
|
110 |
+
min_delta:int=0):
|
111 |
+
super().__init__(learn, monitor=monitor, mode=mode)
|
112 |
+
self.patience,self.factor,self.min_delta = patience,factor,min_delta
|
113 |
+
if self.operator == np.less: self.min_delta *= -1
|
114 |
+
|
115 |
+
def on_train_begin(self, **kwargs:Any)->None:
|
116 |
+
"Initialize inner arguments."
|
117 |
+
self.wait, self.opt = 0, self.learn.opt
|
118 |
+
super().on_train_begin(**kwargs)
|
119 |
+
|
120 |
+
def on_epoch_end(self, epoch, **kwargs:Any)->None:
|
121 |
+
"Compare the value monitored to its best and maybe reduce lr."
|
122 |
+
current = self.get_monitor_value()
|
123 |
+
if current is None: return
|
124 |
+
if self.operator(current - self.min_delta, self.best): self.best,self.wait = current,0
|
125 |
+
else:
|
126 |
+
self.wait += 1
|
127 |
+
if self.wait > self.patience:
|
128 |
+
self.opt.lr *= self.factor
|
129 |
+
self.wait = 0
|
130 |
+
print(f'Epoch {epoch}: reducing lr to {self.opt.lr}')
|
131 |
+
|
132 |
+
|
133 |
+
class TrackEpochCallback(LearnerCallback):
|
134 |
+
_order = -20 #Need to run before fit_one_cycle
|
135 |
+
def __init__(self, learn:Learner, name:str='epoch', epoch_offset:int=None):
|
136 |
+
"Store completed epoch number in `learn.model_dir/name`."
|
137 |
+
super().__init__(learn)
|
138 |
+
learn._test_writeable_path()
|
139 |
+
self.path = learn.path/learn.model_dir/name
|
140 |
+
if epoch_offset is None:
|
141 |
+
if os.path.isfile(self.path):
|
142 |
+
with self.path.open('r') as f:
|
143 |
+
try: self.start_epoch = int(f.read())+1
|
144 |
+
except: self.start_epoch = 0
|
145 |
+
else: self.start_epoch = 0
|
146 |
+
|
147 |
+
def on_train_begin(self, **kwargs:Any):
|
148 |
+
return {'epoch': self.start_epoch}
|
149 |
+
|
150 |
+
def on_epoch_end(self, epoch, **kwargs:Any)->None:
|
151 |
+
with self.path.open('w') as f: f.write(f'{epoch}')
|
152 |
+
|
153 |
+
def restart(self): os.remove(self.path)
|
DeOldify/fastai/collab.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"Module support for Collaborative Filtering"
|
2 |
+
from .tabular import *
|
3 |
+
from . import tabular
|
4 |
+
|
5 |
+
__all__ = [*tabular.__all__, 'EmbeddingDotBias', 'EmbeddingNN', 'collab_learner', 'CollabDataBunch', 'CollabLine',
|
6 |
+
'CollabList', 'CollabLearner']
|
7 |
+
|
8 |
+
class CollabProcessor(TabularProcessor):
|
9 |
+
"Subclass `TabularProcessor for `process_one`."
|
10 |
+
def process_one(self, item):
|
11 |
+
res = super().process_one(item)
|
12 |
+
return CollabLine(res.cats,res.conts,res.classes,res.names)
|
13 |
+
|
14 |
+
class CollabLine(TabularLine):
|
15 |
+
"Base item for collaborative filtering, subclasses `TabularLine`."
|
16 |
+
def __init__(self, cats, conts, classes, names):
|
17 |
+
super().__init__(cats, conts, classes, names)
|
18 |
+
self.data = [self.data[0][0],self.data[0][1]]
|
19 |
+
|
20 |
+
class CollabList(TabularList):
|
21 |
+
"Base `ItemList` for collaborative filtering, subclasses `TabularList`."
|
22 |
+
_item_cls,_label_cls,_processor = CollabLine,FloatList,CollabProcessor
|
23 |
+
|
24 |
+
def reconstruct(self, t:Tensor): return CollabLine(tensor(t), tensor([]), self.classes, self.col_names)
|
25 |
+
|
26 |
+
class EmbeddingNN(TabularModel):
|
27 |
+
"Subclass `TabularModel` to create a NN suitable for collaborative filtering."
|
28 |
+
def __init__(self, emb_szs:ListSizes, layers:Collection[int]=None, ps:Collection[float]=None,
|
29 |
+
emb_drop:float=0., y_range:OptRange=None, use_bn:bool=True, bn_final:bool=False):
|
30 |
+
super().__init__(emb_szs=emb_szs, n_cont=0, out_sz=1, layers=layers, ps=ps, emb_drop=emb_drop, y_range=y_range,
|
31 |
+
use_bn=use_bn, bn_final=bn_final)
|
32 |
+
|
33 |
+
def forward(self, users:LongTensor, items:LongTensor) -> Tensor:
|
34 |
+
return super().forward(torch.stack([users,items], dim=1), None)
|
35 |
+
|
36 |
+
class EmbeddingDotBias(Module):
|
37 |
+
"Base dot model for collaborative filtering."
|
38 |
+
def __init__(self, n_factors:int, n_users:int, n_items:int, y_range:Tuple[float,float]=None):
|
39 |
+
self.y_range = y_range
|
40 |
+
(self.u_weight, self.i_weight, self.u_bias, self.i_bias) = [embedding(*o) for o in [
|
41 |
+
(n_users, n_factors), (n_items, n_factors), (n_users,1), (n_items,1)
|
42 |
+
]]
|
43 |
+
|
44 |
+
def forward(self, users:LongTensor, items:LongTensor) -> Tensor:
|
45 |
+
dot = self.u_weight(users)* self.i_weight(items)
|
46 |
+
res = dot.sum(1) + self.u_bias(users).squeeze() + self.i_bias(items).squeeze()
|
47 |
+
if self.y_range is None: return res
|
48 |
+
return torch.sigmoid(res) * (self.y_range[1]-self.y_range[0]) + self.y_range[0]
|
49 |
+
|
50 |
+
class CollabDataBunch(DataBunch):
|
51 |
+
"Base `DataBunch` for collaborative filtering."
|
52 |
+
@classmethod
|
53 |
+
def from_df(cls, ratings:DataFrame, valid_pct:float=0.2, user_name:Optional[str]=None, item_name:Optional[str]=None,
|
54 |
+
rating_name:Optional[str]=None, test:DataFrame=None, seed:int=None, path:PathOrStr='.', bs:int=64,
|
55 |
+
val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None,
|
56 |
+
device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False) -> 'CollabDataBunch':
|
57 |
+
"Create a `DataBunch` suitable for collaborative filtering from `ratings`."
|
58 |
+
user_name = ifnone(user_name, ratings.columns[0])
|
59 |
+
item_name = ifnone(item_name, ratings.columns[1])
|
60 |
+
rating_name = ifnone(rating_name,ratings.columns[2])
|
61 |
+
cat_names = [user_name,item_name]
|
62 |
+
src = (CollabList.from_df(ratings, cat_names=cat_names, procs=Categorify)
|
63 |
+
.split_by_rand_pct(valid_pct=valid_pct, seed=seed).label_from_df(cols=rating_name))
|
64 |
+
if test is not None: src.add_test(CollabList.from_df(test, cat_names=cat_names))
|
65 |
+
return src.databunch(path=path, bs=bs, val_bs=val_bs, num_workers=num_workers, device=device,
|
66 |
+
collate_fn=collate_fn, no_check=no_check)
|
67 |
+
|
68 |
+
class CollabLearner(Learner):
|
69 |
+
"`Learner` suitable for collaborative filtering."
|
70 |
+
def get_idx(self, arr:Collection, is_item:bool=True):
|
71 |
+
"Fetch item or user (based on `is_item`) for all in `arr`. (Set model to `cpu` and no grad.)"
|
72 |
+
m = self.model.eval().cpu()
|
73 |
+
requires_grad(m,False)
|
74 |
+
u_class,i_class = self.data.train_ds.x.classes.values()
|
75 |
+
classes = i_class if is_item else u_class
|
76 |
+
c2i = {v:k for k,v in enumerate(classes)}
|
77 |
+
try: return tensor([c2i[o] for o in arr])
|
78 |
+
except Exception as e:
|
79 |
+
print(f"""You're trying to access {'an item' if is_item else 'a user'} that isn't in the training data.
|
80 |
+
If it was in your original data, it may have been split such that it's only in the validation set now.""")
|
81 |
+
|
82 |
+
def bias(self, arr:Collection, is_item:bool=True):
|
83 |
+
"Bias for item or user (based on `is_item`) for all in `arr`. (Set model to `cpu` and no grad.)"
|
84 |
+
idx = self.get_idx(arr, is_item)
|
85 |
+
m = self.model
|
86 |
+
layer = m.i_bias if is_item else m.u_bias
|
87 |
+
return layer(idx).squeeze()
|
88 |
+
|
89 |
+
def weight(self, arr:Collection, is_item:bool=True):
|
90 |
+
"Bias for item or user (based on `is_item`) for all in `arr`. (Set model to `cpu` and no grad.)"
|
91 |
+
idx = self.get_idx(arr, is_item)
|
92 |
+
m = self.model
|
93 |
+
layer = m.i_weight if is_item else m.u_weight
|
94 |
+
return layer(idx)
|
95 |
+
|
96 |
+
def collab_learner(data, n_factors:int=None, use_nn:bool=False, emb_szs:Dict[str,int]=None, layers:Collection[int]=None,
|
97 |
+
ps:Collection[float]=None, emb_drop:float=0., y_range:OptRange=None, use_bn:bool=True,
|
98 |
+
bn_final:bool=False, **learn_kwargs)->Learner:
|
99 |
+
"Create a Learner for collaborative filtering on `data`."
|
100 |
+
emb_szs = data.get_emb_szs(ifnone(emb_szs, {}))
|
101 |
+
u,m = data.train_ds.x.classes.values()
|
102 |
+
if use_nn: model = EmbeddingNN(emb_szs=emb_szs, layers=layers, ps=ps, emb_drop=emb_drop, y_range=y_range,
|
103 |
+
use_bn=use_bn, bn_final=bn_final, **learn_kwargs)
|
104 |
+
else: model = EmbeddingDotBias(n_factors, len(u), len(m), y_range=y_range)
|
105 |
+
return CollabLearner(data, model, **learn_kwargs)
|
106 |
+
|
DeOldify/fastai/core.py
ADDED
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"`fastai.core` contains essential util functions to format and split data"
|
2 |
+
from .imports.core import *
|
3 |
+
|
4 |
+
warnings.filterwarnings("ignore", message="numpy.dtype size changed")
|
5 |
+
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")
|
6 |
+
|
7 |
+
AnnealFunc = Callable[[Number, Number, float], Number]
|
8 |
+
ArgStar = Collection[Any]
|
9 |
+
BatchSamples = Collection[Tuple[Collection[int], int]]
|
10 |
+
DataFrameOrChunks = Union[DataFrame, pd.io.parsers.TextFileReader]
|
11 |
+
FilePathList = Collection[Path]
|
12 |
+
Floats = Union[float, Collection[float]]
|
13 |
+
ImgLabel = str
|
14 |
+
ImgLabels = Collection[ImgLabel]
|
15 |
+
IntsOrStrs = Union[int, Collection[int], str, Collection[str]]
|
16 |
+
KeyFunc = Callable[[int], int]
|
17 |
+
KWArgs = Dict[str, Any]
|
18 |
+
ListOrItem = Union[Collection[Any], int, float, str]
|
19 |
+
ListRules = Collection[Callable[[str], str]]
|
20 |
+
ListSizes = Collection[Tuple[int, int]]
|
21 |
+
NPArrayableList = Collection[Union[np.ndarray, list]]
|
22 |
+
NPArrayList = Collection[np.ndarray]
|
23 |
+
NPArrayMask = np.ndarray
|
24 |
+
NPImage = np.ndarray
|
25 |
+
OptDataFrame = Optional[DataFrame]
|
26 |
+
OptListOrItem = Optional[ListOrItem]
|
27 |
+
OptRange = Optional[Tuple[float, float]]
|
28 |
+
OptStrTuple = Optional[Tuple[str, str]]
|
29 |
+
OptStats = Optional[Tuple[np.ndarray, np.ndarray]]
|
30 |
+
PathOrStr = Union[Path, str]
|
31 |
+
PathLikeOrBinaryStream = Union[PathOrStr, BufferedWriter, BytesIO]
|
32 |
+
PBar = Union[MasterBar, ProgressBar]
|
33 |
+
Point = Tuple[float, float]
|
34 |
+
Points = Collection[Point]
|
35 |
+
Sizes = List[List[int]]
|
36 |
+
SplitArrayList = List[Tuple[np.ndarray, np.ndarray]]
|
37 |
+
StartOptEnd = Union[float, Tuple[float, float]]
|
38 |
+
StrList = Collection[str]
|
39 |
+
Tokens = Collection[Collection[str]]
|
40 |
+
OptStrList = Optional[StrList]
|
41 |
+
np.set_printoptions(precision=6, threshold=50, edgeitems=4, linewidth=120)
|
42 |
+
|
43 |
+
|
44 |
+
def num_cpus() -> int:
|
45 |
+
"Get number of cpus"
|
46 |
+
try:
|
47 |
+
return len(os.sched_getaffinity(0))
|
48 |
+
except AttributeError:
|
49 |
+
return os.cpu_count()
|
50 |
+
|
51 |
+
|
52 |
+
_default_cpus = min(16, num_cpus())
|
53 |
+
defaults = SimpleNamespace(
|
54 |
+
cpus=_default_cpus, cmap='viridis', return_fig=False, silent=False)
|
55 |
+
|
56 |
+
|
57 |
+
def is_listy(x: Any) -> bool: return isinstance(x, (tuple, list))
|
58 |
+
def is_tuple(x: Any) -> bool: return isinstance(x, tuple)
|
59 |
+
def is_dict(x: Any) -> bool: return isinstance(x, dict)
|
60 |
+
def is_pathlike(x: Any) -> bool: return isinstance(x, (str, Path))
|
61 |
+
def noop(x): return x
|
62 |
+
|
63 |
+
|
64 |
+
class PrePostInitMeta(type):
|
65 |
+
"A metaclass that calls optional `__pre_init__` and `__post_init__` methods"
|
66 |
+
def __new__(cls, name, bases, dct):
|
67 |
+
x = super().__new__(cls, name, bases, dct)
|
68 |
+
old_init = x.__init__
|
69 |
+
def _pass(self): pass
|
70 |
+
|
71 |
+
@functools.wraps(old_init)
|
72 |
+
def _init(self, *args, **kwargs):
|
73 |
+
self.__pre_init__()
|
74 |
+
old_init(self, *args, **kwargs)
|
75 |
+
self.__post_init__()
|
76 |
+
x.__init__ = _init
|
77 |
+
if not hasattr(x, '__pre_init__'):
|
78 |
+
x.__pre_init__ = _pass
|
79 |
+
if not hasattr(x, '__post_init__'):
|
80 |
+
x.__post_init__ = _pass
|
81 |
+
return x
|
82 |
+
|
83 |
+
|
84 |
+
def chunks(l: Collection, n: int) -> Iterable:
|
85 |
+
"Yield successive `n`-sized chunks from `l`."
|
86 |
+
for i in range(0, len(l), n):
|
87 |
+
yield l[i:i+n]
|
88 |
+
|
89 |
+
|
90 |
+
def recurse(func: Callable, x: Any, *args, **kwargs) -> Any:
|
91 |
+
if is_listy(x):
|
92 |
+
return [recurse(func, o, *args, **kwargs) for o in x]
|
93 |
+
if is_dict(x):
|
94 |
+
return {k: recurse(func, v, *args, **kwargs) for k, v in x.items()}
|
95 |
+
return func(x, *args, **kwargs)
|
96 |
+
|
97 |
+
|
98 |
+
def first_el(x: Any) -> Any:
|
99 |
+
"Recursively get the first element of `x`."
|
100 |
+
if is_listy(x):
|
101 |
+
return first_el(x[0])
|
102 |
+
if is_dict(x):
|
103 |
+
return first_el(x[list(x.keys())[0]])
|
104 |
+
return x
|
105 |
+
|
106 |
+
|
107 |
+
def to_int(b: Any) -> Union[int, List[int]]:
|
108 |
+
"Recursively convert `b` to an int or list/dict of ints; raises exception if not convertible."
|
109 |
+
return recurse(lambda x: int(x), b)
|
110 |
+
|
111 |
+
|
112 |
+
def ifnone(a: Any, b: Any) -> Any:
|
113 |
+
"`a` if `a` is not None, otherwise `b`."
|
114 |
+
return b if a is None else a
|
115 |
+
|
116 |
+
|
117 |
+
def is1d(a: Collection) -> bool:
|
118 |
+
"Return `True` if `a` is one-dimensional"
|
119 |
+
return len(a.shape) == 1 if hasattr(a, 'shape') else len(np.array(a).shape) == 1
|
120 |
+
|
121 |
+
|
122 |
+
def uniqueify(x: Series, sort: bool = False) -> List:
|
123 |
+
"Return sorted unique values of `x`."
|
124 |
+
res = list(OrderedDict.fromkeys(x).keys())
|
125 |
+
if sort:
|
126 |
+
res.sort()
|
127 |
+
return res
|
128 |
+
|
129 |
+
|
130 |
+
def idx_dict(a):
|
131 |
+
"Create a dictionary value to index from `a`."
|
132 |
+
return {v: k for k, v in enumerate(a)}
|
133 |
+
|
134 |
+
|
135 |
+
def find_classes(folder: Path) -> FilePathList:
|
136 |
+
"List of label subdirectories in imagenet-style `folder`."
|
137 |
+
classes = [d for d in folder.iterdir()
|
138 |
+
if d.is_dir() and not d.name.startswith('.')]
|
139 |
+
assert (len(classes) > 0)
|
140 |
+
return sorted(classes, key=lambda d: d.name)
|
141 |
+
|
142 |
+
|
143 |
+
def arrays_split(mask: NPArrayMask, *arrs: NPArrayableList) -> SplitArrayList:
|
144 |
+
"Given `arrs` is [a,b,...] and `mask`index - return[(a[mask],a[~mask]),(b[mask],b[~mask]),...]."
|
145 |
+
assert all([len(arr) == len(arrs[0])
|
146 |
+
for arr in arrs]), 'All arrays should have same length'
|
147 |
+
mask = array(mask)
|
148 |
+
return list(zip(*[(a[mask], a[~mask]) for a in map(np.array, arrs)]))
|
149 |
+
|
150 |
+
|
151 |
+
def random_split(valid_pct: float, *arrs: NPArrayableList) -> SplitArrayList:
|
152 |
+
"Randomly split `arrs` with `valid_pct` ratio. good for creating validation set."
|
153 |
+
assert (valid_pct >= 0 and valid_pct <=
|
154 |
+
1), 'Validation set percentage should be between 0 and 1'
|
155 |
+
is_train = np.random.uniform(size=(len(arrs[0]),)) > valid_pct
|
156 |
+
return arrays_split(is_train, *arrs)
|
157 |
+
|
158 |
+
|
159 |
+
def listify(p: OptListOrItem = None, q: OptListOrItem = None):
|
160 |
+
"Make `p` listy and the same length as `q`."
|
161 |
+
if p is None:
|
162 |
+
p = []
|
163 |
+
elif isinstance(p, str):
|
164 |
+
p = [p]
|
165 |
+
elif not isinstance(p, Iterable):
|
166 |
+
p = [p]
|
167 |
+
# Rank 0 tensors in PyTorch are Iterable but don't have a length.
|
168 |
+
else:
|
169 |
+
try:
|
170 |
+
a = len(p)
|
171 |
+
except:
|
172 |
+
p = [p]
|
173 |
+
n = q if type(q) == int else len(p) if q is None else len(q)
|
174 |
+
if len(p) == 1:
|
175 |
+
p = p * n
|
176 |
+
assert len(p) == n, f'List len mismatch ({len(p)} vs {n})'
|
177 |
+
return list(p)
|
178 |
+
|
179 |
+
|
180 |
+
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
|
181 |
+
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
|
182 |
+
|
183 |
+
|
184 |
+
def camel2snake(name: str) -> str:
|
185 |
+
"Change `name` from camel to snake style."
|
186 |
+
s1 = re.sub(_camel_re1, r'\1_\2', name)
|
187 |
+
return re.sub(_camel_re2, r'\1_\2', s1).lower()
|
188 |
+
|
189 |
+
|
190 |
+
def even_mults(start: float, stop: float, n: int) -> np.ndarray:
|
191 |
+
"Build log-stepped array from `start` to `stop` in `n` steps."
|
192 |
+
mult = stop/start
|
193 |
+
step = mult**(1/(n-1))
|
194 |
+
return np.array([start*(step**i) for i in range(n)])
|
195 |
+
|
196 |
+
|
197 |
+
def extract_kwargs(names: Collection[str], kwargs: KWArgs):
|
198 |
+
"Extract the keys in `names` from the `kwargs`."
|
199 |
+
new_kwargs = {}
|
200 |
+
for arg_name in names:
|
201 |
+
if arg_name in kwargs:
|
202 |
+
arg_val = kwargs.pop(arg_name)
|
203 |
+
new_kwargs[arg_name] = arg_val
|
204 |
+
return new_kwargs, kwargs
|
205 |
+
|
206 |
+
|
207 |
+
def partition(a: Collection, sz: int) -> List[Collection]:
|
208 |
+
"Split iterables `a` in equal parts of size `sz`"
|
209 |
+
return [a[i:i+sz] for i in range(0, len(a), sz)]
|
210 |
+
|
211 |
+
|
212 |
+
def partition_by_cores(a: Collection, n_cpus: int) -> List[Collection]:
|
213 |
+
"Split data in `a` equally among `n_cpus` cores"
|
214 |
+
return partition(a, len(a)//n_cpus + 1)
|
215 |
+
|
216 |
+
|
217 |
+
def series2cat(df: DataFrame, *col_names):
|
218 |
+
"Categorifies the columns `col_names` in `df`."
|
219 |
+
for c in listify(col_names):
|
220 |
+
df[c] = df[c].astype('category').cat.as_ordered()
|
221 |
+
|
222 |
+
|
223 |
+
TfmList = Union[Callable, Collection[Callable]]
|
224 |
+
|
225 |
+
|
226 |
+
class ItemBase():
|
227 |
+
"Base item type in the fastai library."
|
228 |
+
|
229 |
+
def __init__(self, data: Any): self.data = self.obj = data
|
230 |
+
def __repr__(self) -> str: return f'{self.__class__.__name__} {str(self)}'
|
231 |
+
|
232 |
+
def show(self, ax: plt.Axes, **kwargs):
|
233 |
+
"Subclass this method if you want to customize the way this `ItemBase` is shown on `ax`."
|
234 |
+
ax.set_title(str(self))
|
235 |
+
|
236 |
+
def apply_tfms(self, tfms: Collection, **kwargs):
|
237 |
+
"Subclass this method if you want to apply data augmentation with `tfms` to this `ItemBase`."
|
238 |
+
if tfms:
|
239 |
+
raise Exception(
|
240 |
+
f"Not implemented: you can't apply transforms to this type of item ({self.__class__.__name__})")
|
241 |
+
return self
|
242 |
+
|
243 |
+
def __eq__(self, other): return recurse_eq(self.data, other.data)
|
244 |
+
|
245 |
+
|
246 |
+
def recurse_eq(arr1, arr2):
|
247 |
+
if is_listy(arr1):
|
248 |
+
return is_listy(arr2) and len(arr1) == len(arr2) and np.all([recurse_eq(x, y) for x, y in zip(arr1, arr2)])
|
249 |
+
else:
|
250 |
+
return np.all(np.atleast_1d(arr1 == arr2))
|
251 |
+
|
252 |
+
|
253 |
+
def download_url(url: str, dest: str, overwrite: bool = False, pbar: ProgressBar = None,
|
254 |
+
show_progress=True, chunk_size=1024*1024, timeout=4, retries=5) -> None:
|
255 |
+
"Download `url` to `dest` unless it exists and not `overwrite`."
|
256 |
+
if os.path.exists(dest) and not overwrite:
|
257 |
+
return
|
258 |
+
|
259 |
+
s = requests.Session()
|
260 |
+
s.mount('http://', requests.adapters.HTTPAdapter(max_retries=retries))
|
261 |
+
u = s.get(url, stream=True, timeout=timeout)
|
262 |
+
try:
|
263 |
+
file_size = int(u.headers["Content-Length"])
|
264 |
+
except:
|
265 |
+
show_progress = False
|
266 |
+
|
267 |
+
with open(dest, 'wb') as f:
|
268 |
+
nbytes = 0
|
269 |
+
if show_progress:
|
270 |
+
pbar = progress_bar(
|
271 |
+
range(file_size), auto_update=False, leave=False, parent=pbar)
|
272 |
+
try:
|
273 |
+
for chunk in u.iter_content(chunk_size=chunk_size):
|
274 |
+
nbytes += len(chunk)
|
275 |
+
if show_progress:
|
276 |
+
pbar.update(nbytes)
|
277 |
+
f.write(chunk)
|
278 |
+
except requests.exceptions.ConnectionError as e:
|
279 |
+
fname = url.split('/')[-1]
|
280 |
+
from fastai.datasets import Config
|
281 |
+
data_dir = Config().data_path()
|
282 |
+
timeout_txt = (f'\n Download of {url} has failed after {retries} retries\n'
|
283 |
+
f' Fix the download manually:\n'
|
284 |
+
f'$ mkdir -p {data_dir}\n'
|
285 |
+
f'$ cd {data_dir}\n'
|
286 |
+
f'$ wget -c {url}\n'
|
287 |
+
f'$ tar -zxvf {fname}\n\n'
|
288 |
+
f'And re-run your code once the download is successful\n')
|
289 |
+
print(timeout_txt)
|
290 |
+
import sys
|
291 |
+
sys.exit(1)
|
292 |
+
|
293 |
+
|
294 |
+
def range_of(x):
|
295 |
+
"Create a range from 0 to `len(x)`."
|
296 |
+
return list(range(len(x)))
|
297 |
+
|
298 |
+
|
299 |
+
def arange_of(x):
|
300 |
+
"Same as `range_of` but returns an array."
|
301 |
+
return np.arange(len(x))
|
302 |
+
|
303 |
+
|
304 |
+
Path.ls = lambda x: list(x.iterdir())
|
305 |
+
|
306 |
+
|
307 |
+
def join_path(fname: PathOrStr, path: PathOrStr = '.') -> Path:
|
308 |
+
"Return `Path(path)/Path(fname)`, `path` defaults to current dir."
|
309 |
+
return Path(path)/Path(fname)
|
310 |
+
|
311 |
+
|
312 |
+
def join_paths(fnames: FilePathList, path: PathOrStr = '.') -> Collection[Path]:
|
313 |
+
"Join `path` to every file name in `fnames`."
|
314 |
+
path = Path(path)
|
315 |
+
return [join_path(o, path) for o in fnames]
|
316 |
+
|
317 |
+
|
318 |
+
def loadtxt_str(path: PathOrStr) -> np.ndarray:
|
319 |
+
"Return `ndarray` of `str` of lines of text from `path`."
|
320 |
+
with open(path, 'r') as f:
|
321 |
+
lines = f.readlines()
|
322 |
+
return np.array([l.strip() for l in lines])
|
323 |
+
|
324 |
+
|
325 |
+
def save_texts(fname: PathOrStr, texts: Collection[str]):
|
326 |
+
"Save in `fname` the content of `texts`."
|
327 |
+
with open(fname, 'w') as f:
|
328 |
+
for t in texts:
|
329 |
+
f.write(f'{t}\n')
|
330 |
+
|
331 |
+
|
332 |
+
def df_names_to_idx(names: IntsOrStrs, df: DataFrame):
|
333 |
+
"Return the column indexes of `names` in `df`."
|
334 |
+
if not is_listy(names):
|
335 |
+
names = [names]
|
336 |
+
if isinstance(names[0], int):
|
337 |
+
return names
|
338 |
+
return [df.columns.get_loc(c) for c in names]
|
339 |
+
|
340 |
+
|
341 |
+
def one_hot(x: Collection[int], c: int):
|
342 |
+
"One-hot encode `x` with `c` classes."
|
343 |
+
res = np.zeros((c,), np.float32)
|
344 |
+
res[listify(x)] = 1.
|
345 |
+
return res
|
346 |
+
|
347 |
+
|
348 |
+
def index_row(a: Union[Collection, pd.DataFrame, pd.Series], idxs: Collection[int]) -> Any:
|
349 |
+
"Return the slice of `a` corresponding to `idxs`."
|
350 |
+
if a is None:
|
351 |
+
return a
|
352 |
+
if isinstance(a, (pd.DataFrame, pd.Series)):
|
353 |
+
res = a.iloc[idxs]
|
354 |
+
if isinstance(res, (pd.DataFrame, pd.Series)):
|
355 |
+
return res.copy()
|
356 |
+
return res
|
357 |
+
return a[idxs]
|
358 |
+
|
359 |
+
|
360 |
+
def func_args(func) -> bool:
|
361 |
+
"Return the arguments of `func`."
|
362 |
+
code = func.__code__
|
363 |
+
return code.co_varnames[:code.co_argcount]
|
364 |
+
|
365 |
+
|
366 |
+
def has_arg(func, arg) -> bool:
|
367 |
+
"Check if `func` accepts `arg`."
|
368 |
+
return arg in func_args(func)
|
369 |
+
|
370 |
+
|
371 |
+
def split_kwargs_by_func(kwargs, func):
|
372 |
+
"Split `kwargs` between those expected by `func` and the others."
|
373 |
+
args = func_args(func)
|
374 |
+
func_kwargs = {a: kwargs.pop(a) for a in args if a in kwargs}
|
375 |
+
return func_kwargs, kwargs
|
376 |
+
|
377 |
+
|
378 |
+
def array(a, dtype: type = None, **kwargs) -> np.ndarray:
|
379 |
+
"Same as `np.array` but also handles generators. `kwargs` are passed to `np.array` with `dtype`."
|
380 |
+
if not isinstance(a, collections.abc.Sized) and not getattr(a, '__array_interface__', False):
|
381 |
+
a = list(a)
|
382 |
+
if np.int_ == np.int32 and dtype is None and is_listy(a) and len(a) and isinstance(a[0], int):
|
383 |
+
dtype = np.int64
|
384 |
+
return np.array(a, dtype=dtype, **kwargs)
|
385 |
+
|
386 |
+
|
387 |
+
class EmptyLabel(ItemBase):
|
388 |
+
"Should be used for a dummy label."
|
389 |
+
|
390 |
+
def __init__(self): self.obj, self.data = 0, 0
|
391 |
+
def __str__(self): return ''
|
392 |
+
def __hash__(self): return hash(str(self))
|
393 |
+
|
394 |
+
|
395 |
+
class Category(ItemBase):
|
396 |
+
"Basic class for single classification labels."
|
397 |
+
|
398 |
+
def __init__(self, data, obj): self.data, self.obj = data, obj
|
399 |
+
def __int__(self): return int(self.data)
|
400 |
+
def __str__(self): return str(self.obj)
|
401 |
+
def __hash__(self): return hash(str(self))
|
402 |
+
|
403 |
+
|
404 |
+
class MultiCategory(ItemBase):
|
405 |
+
"Basic class for multi-classification labels."
|
406 |
+
|
407 |
+
def __init__(self, data, obj,
|
408 |
+
raw): self.data, self.obj, self.raw = data, obj, raw
|
409 |
+
|
410 |
+
def __str__(self): return ';'.join([str(o) for o in self.obj])
|
411 |
+
def __hash__(self): return hash(str(self))
|
412 |
+
|
413 |
+
|
414 |
+
class FloatItem(ItemBase):
|
415 |
+
"Basic class for float items."
|
416 |
+
|
417 |
+
def __init__(self, obj): self.data, self.obj = np.array(
|
418 |
+
obj).astype(np.float32), obj
|
419 |
+
|
420 |
+
def __str__(self): return str(self.obj)
|
421 |
+
def __hash__(self): return hash(str(self))
|
422 |
+
|
423 |
+
|
424 |
+
def _treat_html(o: str) -> str:
|
425 |
+
o = str(o)
|
426 |
+
to_replace = {'\n': '\\n', '<': '<', '>': '>', '&': '&'}
|
427 |
+
for k, v in to_replace.items():
|
428 |
+
o = o.replace(k, v)
|
429 |
+
return o
|
430 |
+
|
431 |
+
|
432 |
+
def text2html_table(items: Collection[Collection[str]]) -> str:
|
433 |
+
"Put the texts in `items` in an HTML table, `widths` are the widths of the columns in %."
|
434 |
+
html_code = f"""<table border="1" class="dataframe">"""
|
435 |
+
html_code += f""" <thead>\n <tr style="text-align: right;">\n"""
|
436 |
+
for i in items[0]:
|
437 |
+
html_code += f" <th>{_treat_html(i)}</th>"
|
438 |
+
html_code += f" </tr>\n </thead>\n <tbody>"
|
439 |
+
html_code += " <tbody>"
|
440 |
+
for line in items[1:]:
|
441 |
+
html_code += " <tr>"
|
442 |
+
for i in line:
|
443 |
+
html_code += f" <td>{_treat_html(i)}</td>"
|
444 |
+
html_code += " </tr>"
|
445 |
+
html_code += " </tbody>\n</table>"
|
446 |
+
return html_code
|
447 |
+
|
448 |
+
|
449 |
+
def parallel(func, arr: Collection, max_workers: int = None, leave=False):
|
450 |
+
"Call `func` on every element of `arr` in parallel using `max_workers`."
|
451 |
+
max_workers = ifnone(max_workers, defaults.cpus)
|
452 |
+
if max_workers < 2:
|
453 |
+
results = [func(o, i) for i, o in progress_bar(
|
454 |
+
enumerate(arr), total=len(arr), leave=leave)]
|
455 |
+
else:
|
456 |
+
with ProcessPoolExecutor(max_workers=max_workers) as ex:
|
457 |
+
futures = [ex.submit(func, o, i) for i, o in enumerate(arr)]
|
458 |
+
results = []
|
459 |
+
for f in progress_bar(concurrent.futures.as_completed(futures), total=len(arr), leave=leave):
|
460 |
+
results.append(f.result())
|
461 |
+
if any([o is not None for o in results]):
|
462 |
+
return results
|
463 |
+
|
464 |
+
|
465 |
+
def subplots(rows: int, cols: int, imgsize: int = 4, figsize: Optional[Tuple[int, int]] = None, title=None, **kwargs):
|
466 |
+
"Like `plt.subplots` but with consistent axs shape, `kwargs` passed to `fig.suptitle` with `title`"
|
467 |
+
figsize = ifnone(figsize, (imgsize*cols, imgsize*rows))
|
468 |
+
fig, axs = plt.subplots(rows, cols, figsize=figsize)
|
469 |
+
if rows == cols == 1:
|
470 |
+
axs = [[axs]] # subplots(1,1) returns Axes, not [Axes]
|
471 |
+
elif (rows == 1 and cols != 1) or (cols == 1 and rows != 1):
|
472 |
+
axs = [axs]
|
473 |
+
if title is not None:
|
474 |
+
fig.suptitle(title, **kwargs)
|
475 |
+
return array(axs)
|
476 |
+
|
477 |
+
|
478 |
+
def show_some(items: Collection, n_max: int = 5, sep: str = ','):
|
479 |
+
"Return the representation of the first `n_max` elements in `items`."
|
480 |
+
if items is None or len(items) == 0:
|
481 |
+
return ''
|
482 |
+
res = sep.join([f'{o}' for o in items[:n_max]])
|
483 |
+
if len(items) > n_max:
|
484 |
+
res += '...'
|
485 |
+
return res
|
486 |
+
|
487 |
+
|
488 |
+
def get_tmp_file(dir=None):
|
489 |
+
"Create and return a tmp filename, optionally at a specific path. `os.remove` when done with it."
|
490 |
+
with tempfile.NamedTemporaryFile(delete=False, dir=dir) as f:
|
491 |
+
return f.name
|
492 |
+
|
493 |
+
|
494 |
+
def compose(funcs: List[Callable]) -> Callable:
|
495 |
+
"Compose `funcs`"
|
496 |
+
def compose_(funcs, x, *args, **kwargs):
|
497 |
+
for f in listify(funcs):
|
498 |
+
x = f(x, *args, **kwargs)
|
499 |
+
return x
|
500 |
+
return partial(compose_, funcs)
|
501 |
+
|
502 |
+
|
503 |
+
class PrettyString(str):
|
504 |
+
"Little hack to get strings to show properly in Jupyter."
|
505 |
+
|
506 |
+
def __repr__(self): return self
|
507 |
+
|
508 |
+
|
509 |
+
def float_or_x(x):
|
510 |
+
"Tries to convert to float, returns x if it can't"
|
511 |
+
try:
|
512 |
+
return float(x)
|
513 |
+
except:
|
514 |
+
return x
|
515 |
+
|
516 |
+
|
517 |
+
def bunzip(fn: PathOrStr):
|
518 |
+
"bunzip `fn`, raising exception if output already exists"
|
519 |
+
fn = Path(fn)
|
520 |
+
assert fn.exists(), f"{fn} doesn't exist"
|
521 |
+
out_fn = fn.with_suffix('')
|
522 |
+
assert not out_fn.exists(), f"{out_fn} already exists"
|
523 |
+
with bz2.BZ2File(fn, 'rb') as src, out_fn.open('wb') as dst:
|
524 |
+
for d in iter(lambda: src.read(1024*1024), b''): dst.write(d)
|
525 |
+
|
526 |
+
|
527 |
+
@contextmanager
|
528 |
+
def working_directory(path: PathOrStr):
|
529 |
+
"Change working directory to `path` and return to previous on exit."
|
530 |
+
prev_cwd = Path.cwd()
|
531 |
+
os.chdir(path)
|
532 |
+
try:
|
533 |
+
yield
|
534 |
+
finally:
|
535 |
+
os.chdir(prev_cwd)
|
DeOldify/fastai/data_block.py
ADDED
@@ -0,0 +1,804 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .torch_core import *
|
2 |
+
from .basic_data import *
|
3 |
+
from .layers import *
|
4 |
+
from numbers import Integral
|
5 |
+
|
6 |
+
__all__ = ['ItemList', 'CategoryList', 'MultiCategoryList', 'MultiCategoryProcessor', 'LabelList', 'ItemLists', 'get_files',
|
7 |
+
'PreProcessor', 'LabelLists', 'FloatList', 'CategoryProcessor', 'EmptyLabelList', 'MixedItem', 'MixedProcessor',
|
8 |
+
'MixedItemList']
|
9 |
+
|
10 |
+
def _decode(df):
|
11 |
+
return np.array([[df.columns[i] for i,t in enumerate(x) if t==1] for x in df.values], dtype=np.object)
|
12 |
+
|
13 |
+
def _maybe_squeeze(arr): return (arr if is1d(arr) else np.squeeze(arr))
|
14 |
+
|
15 |
+
def _path_to_same_str(p_fn):
|
16 |
+
"path -> str, but same on nt+posix, for alpha-sort only"
|
17 |
+
s_fn = str(p_fn)
|
18 |
+
s_fn = s_fn.replace('\\','.')
|
19 |
+
s_fn = s_fn.replace('/','.')
|
20 |
+
return s_fn
|
21 |
+
|
22 |
+
def _get_files(parent, p, f, extensions):
|
23 |
+
p = Path(p)#.relative_to(parent)
|
24 |
+
if isinstance(extensions,str): extensions = [extensions]
|
25 |
+
low_extensions = [e.lower() for e in extensions] if extensions is not None else None
|
26 |
+
res = [p/o for o in f if not o.startswith('.')
|
27 |
+
and (extensions is None or f'.{o.split(".")[-1].lower()}' in low_extensions)]
|
28 |
+
return res
|
29 |
+
|
30 |
+
def get_files(path:PathOrStr, extensions:Collection[str]=None, recurse:bool=False,
|
31 |
+
include:Optional[Collection[str]]=None, presort:bool=False)->FilePathList:
|
32 |
+
"Return list of files in `path` that have a suffix in `extensions`; optionally `recurse`."
|
33 |
+
if recurse:
|
34 |
+
res = []
|
35 |
+
for i,(p,d,f) in enumerate(os.walk(path)):
|
36 |
+
# skip hidden dirs
|
37 |
+
if include is not None and i==0: d[:] = [o for o in d if o in include]
|
38 |
+
else: d[:] = [o for o in d if not o.startswith('.')]
|
39 |
+
res += _get_files(path, p, f, extensions)
|
40 |
+
if presort: res = sorted(res, key=lambda p: _path_to_same_str(p), reverse=False)
|
41 |
+
return res
|
42 |
+
else:
|
43 |
+
f = [o.name for o in os.scandir(path) if o.is_file()]
|
44 |
+
res = _get_files(path, path, f, extensions)
|
45 |
+
if presort: res = sorted(res, key=lambda p: _path_to_same_str(p), reverse=False)
|
46 |
+
return res
|
47 |
+
|
48 |
+
class PreProcessor():
|
49 |
+
"Basic class for a processor that will be applied to items at the end of the data block API."
|
50 |
+
def __init__(self, ds:Collection=None): self.ref_ds = ds
|
51 |
+
def process_one(self, item:Any): return item
|
52 |
+
def process(self, ds:Collection): ds.items = array([self.process_one(item) for item in ds.items])
|
53 |
+
|
54 |
+
PreProcessors = Union[PreProcessor, Collection[PreProcessor]]
|
55 |
+
fastai_types[PreProcessors] = 'PreProcessors'
|
56 |
+
|
57 |
+
class ItemList():
|
58 |
+
"A collection of items with `__len__` and `__getitem__` with `ndarray` indexing semantics."
|
59 |
+
_bunch,_processor,_label_cls,_square_show,_square_show_res = DataBunch,None,None,False,False
|
60 |
+
|
61 |
+
def __init__(self, items:Iterator, path:PathOrStr='.', label_cls:Callable=None, inner_df:Any=None,
|
62 |
+
processor:PreProcessors=None, x:'ItemList'=None, ignore_empty:bool=False):
|
63 |
+
self.path = Path(path)
|
64 |
+
self.num_parts = len(self.path.parts)
|
65 |
+
self.items,self.x,self.ignore_empty = items,x,ignore_empty
|
66 |
+
if not isinstance(self.items,np.ndarray): self.items = array(self.items, dtype=object)
|
67 |
+
self.label_cls,self.inner_df,self.processor = ifnone(label_cls,self._label_cls),inner_df,processor
|
68 |
+
self._label_list,self._split = LabelList,ItemLists
|
69 |
+
self.copy_new = ['x', 'label_cls', 'path']
|
70 |
+
|
71 |
+
def __len__(self)->int: return len(self.items) or 1
|
72 |
+
def get(self, i)->Any:
|
73 |
+
"Subclass if you want to customize how to create item `i` from `self.items`."
|
74 |
+
return self.items[i]
|
75 |
+
def __repr__(self)->str:
|
76 |
+
items = [self[i] for i in range(min(5,len(self.items)))]
|
77 |
+
return f'{self.__class__.__name__} ({len(self.items)} items)\n{show_some(items)}\nPath: {self.path}'
|
78 |
+
|
79 |
+
def process(self, processor:PreProcessors=None):
|
80 |
+
"Apply `processor` or `self.processor` to `self`."
|
81 |
+
if processor is not None: self.processor = processor
|
82 |
+
self.processor = listify(self.processor)
|
83 |
+
for p in self.processor: p.process(self)
|
84 |
+
return self
|
85 |
+
|
86 |
+
def process_one(self, item:ItemBase, processor:PreProcessors=None):
|
87 |
+
"Apply `processor` or `self.processor` to `item`."
|
88 |
+
if processor is not None: self.processor = processor
|
89 |
+
self.processor = listify(self.processor)
|
90 |
+
for p in self.processor: item = p.process_one(item)
|
91 |
+
return item
|
92 |
+
|
93 |
+
def analyze_pred(self, pred:Tensor):
|
94 |
+
"Called on `pred` before `reconstruct` for additional preprocessing."
|
95 |
+
return pred
|
96 |
+
|
97 |
+
def reconstruct(self, t:Tensor, x:Tensor=None):
|
98 |
+
"Reconstruct one of the underlying item for its data `t`."
|
99 |
+
return self[0].reconstruct(t,x) if has_arg(self[0].reconstruct, 'x') else self[0].reconstruct(t)
|
100 |
+
|
101 |
+
def new(self, items:Iterator, processor:PreProcessors=None, **kwargs)->'ItemList':
|
102 |
+
"Create a new `ItemList` from `items`, keeping the same attributes."
|
103 |
+
processor = ifnone(processor, self.processor)
|
104 |
+
copy_d = {o:getattr(self,o) for o in self.copy_new}
|
105 |
+
kwargs = {**copy_d, **kwargs}
|
106 |
+
return self.__class__(items=items, processor=processor, **kwargs)
|
107 |
+
|
108 |
+
def add(self, items:'ItemList'):
|
109 |
+
self.items = np.concatenate([self.items, items.items], 0)
|
110 |
+
if self.inner_df is not None and items.inner_df is not None:
|
111 |
+
self.inner_df = pd.concat([self.inner_df, items.inner_df])
|
112 |
+
else: self.inner_df = self.inner_df or items.inner_df
|
113 |
+
return self
|
114 |
+
|
115 |
+
def __getitem__(self,idxs:int)->Any:
|
116 |
+
"returns a single item based if `idxs` is an integer or a new `ItemList` object if `idxs` is a range."
|
117 |
+
idxs = try_int(idxs)
|
118 |
+
if isinstance(idxs, Integral): return self.get(idxs)
|
119 |
+
else: return self.new(self.items[idxs], inner_df=index_row(self.inner_df, idxs))
|
120 |
+
|
121 |
+
@classmethod
|
122 |
+
def from_folder(cls, path:PathOrStr, extensions:Collection[str]=None, recurse:bool=True,
|
123 |
+
include:Optional[Collection[str]]=None, processor:PreProcessors=None, presort:Optional[bool]=False, **kwargs)->'ItemList':
|
124 |
+
"""Create an `ItemList` in `path` from the filenames that have a suffix in `extensions`.
|
125 |
+
`recurse` determines if we search subfolders."""
|
126 |
+
path = Path(path)
|
127 |
+
return cls(get_files(path, extensions, recurse=recurse, include=include, presort=presort), path=path, processor=processor, **kwargs)
|
128 |
+
|
129 |
+
@classmethod
|
130 |
+
def from_df(cls, df:DataFrame, path:PathOrStr='.', cols:IntsOrStrs=0, processor:PreProcessors=None, **kwargs)->'ItemList':
|
131 |
+
"Create an `ItemList` in `path` from the inputs in the `cols` of `df`."
|
132 |
+
inputs = df.iloc[:,df_names_to_idx(cols, df)]
|
133 |
+
assert not inputs.isna().any().any(), f"You have NaN values in column(s) {cols} of your dataframe, please fix it."
|
134 |
+
res = cls(items=_maybe_squeeze(inputs.values), path=path, inner_df=df, processor=processor, **kwargs)
|
135 |
+
return res
|
136 |
+
|
137 |
+
@classmethod
|
138 |
+
def from_csv(cls, path:PathOrStr, csv_name:str, cols:IntsOrStrs=0, delimiter:str=None, header:str='infer',
|
139 |
+
processor:PreProcessors=None, **kwargs)->'ItemList':
|
140 |
+
"""Create an `ItemList` in `path` from the inputs in the `cols` of `path/csv_name`"""
|
141 |
+
df = pd.read_csv(Path(path)/csv_name, delimiter=delimiter, header=header)
|
142 |
+
return cls.from_df(df, path=path, cols=cols, processor=processor, **kwargs)
|
143 |
+
|
144 |
+
def _relative_item_path(self, i): return self.items[i].relative_to(self.path)
|
145 |
+
def _relative_item_paths(self): return [self._relative_item_path(i) for i in range_of(self.items)]
|
146 |
+
|
147 |
+
def use_partial_data(self, sample_pct:float=0.01, seed:int=None)->'ItemList':
|
148 |
+
"Use only a sample of `sample_pct`of the full dataset and an optional `seed`."
|
149 |
+
if seed is not None: np.random.seed(seed)
|
150 |
+
rand_idx = np.random.permutation(range_of(self))
|
151 |
+
cut = int(sample_pct * len(self))
|
152 |
+
return self[rand_idx[:cut]]
|
153 |
+
|
154 |
+
def to_text(self, fn:str):
|
155 |
+
"Save `self.items` to `fn` in `self.path`."
|
156 |
+
with open(self.path/fn, 'w') as f: f.writelines([f'{o}\n' for o in self._relative_item_paths()])
|
157 |
+
|
158 |
+
def filter_by_func(self, func:Callable)->'ItemList':
|
159 |
+
"Only keep elements for which `func` returns `True`."
|
160 |
+
self.items = array([o for o in self.items if func(o)])
|
161 |
+
return self
|
162 |
+
|
163 |
+
def filter_by_folder(self, include=None, exclude=None):
|
164 |
+
"Only keep filenames in `include` folder or reject the ones in `exclude`."
|
165 |
+
include,exclude = listify(include),listify(exclude)
|
166 |
+
def _inner(o):
|
167 |
+
if isinstance(o, Path): n = o.relative_to(self.path).parts[0]
|
168 |
+
else: n = o.split(os.path.sep)[len(str(self.path).split(os.path.sep))]
|
169 |
+
if include and not n in include: return False
|
170 |
+
if exclude and n in exclude: return False
|
171 |
+
return True
|
172 |
+
return self.filter_by_func(_inner)
|
173 |
+
|
174 |
+
def filter_by_rand(self, p:float, seed:int=None):
|
175 |
+
"Keep random sample of `items` with probability `p` and an optional `seed`."
|
176 |
+
if seed is not None: set_all_seed(seed)
|
177 |
+
return self.filter_by_func(lambda o: rand_bool(p))
|
178 |
+
|
179 |
+
def no_split(self):
|
180 |
+
warn("`no_split` is deprecated, please use `split_none`.")
|
181 |
+
return self.split_none()
|
182 |
+
|
183 |
+
def split_none(self):
|
184 |
+
"Don't split the data and create an empty validation set."
|
185 |
+
val = self[[]]
|
186 |
+
val.ignore_empty = True
|
187 |
+
return self._split(self.path, self, val)
|
188 |
+
|
189 |
+
def split_by_list(self, train, valid):
|
190 |
+
"Split the data between `train` and `valid`."
|
191 |
+
return self._split(self.path, train, valid)
|
192 |
+
|
193 |
+
def split_by_idxs(self, train_idx, valid_idx):
|
194 |
+
"Split the data between `train_idx` and `valid_idx`."
|
195 |
+
return self.split_by_list(self[train_idx], self[valid_idx])
|
196 |
+
|
197 |
+
def split_by_idx(self, valid_idx:Collection[int])->'ItemLists':
|
198 |
+
"Split the data according to the indexes in `valid_idx`."
|
199 |
+
#train_idx = [i for i in range_of(self.items) if i not in valid_idx]
|
200 |
+
train_idx = np.setdiff1d(arange_of(self.items), valid_idx)
|
201 |
+
return self.split_by_idxs(train_idx, valid_idx)
|
202 |
+
|
203 |
+
def _get_by_folder(self, name):
|
204 |
+
return [i for i in range_of(self) if (self.items[i].parts[self.num_parts] if isinstance(self.items[i], Path)
|
205 |
+
else self.items[i].split(os.path.sep)[0]) == name ]
|
206 |
+
|
207 |
+
def split_by_folder(self, train:str='train', valid:str='valid')->'ItemLists':
|
208 |
+
"Split the data depending on the folder (`train` or `valid`) in which the filenames are."
|
209 |
+
return self.split_by_idxs(self._get_by_folder(train), self._get_by_folder(valid))
|
210 |
+
|
211 |
+
def random_split_by_pct(self, valid_pct:float=0.2, seed:int=None):
|
212 |
+
warn("`random_split_by_pct` is deprecated, please use `split_by_rand_pct`.")
|
213 |
+
return self.split_by_rand_pct(valid_pct=valid_pct, seed=seed)
|
214 |
+
|
215 |
+
def split_by_rand_pct(self, valid_pct:float=0.2, seed:int=None)->'ItemLists':
|
216 |
+
"Split the items randomly by putting `valid_pct` in the validation set, optional `seed` can be passed."
|
217 |
+
if valid_pct==0.: return self.split_none()
|
218 |
+
if seed is not None: np.random.seed(seed)
|
219 |
+
rand_idx = np.random.permutation(range_of(self))
|
220 |
+
cut = int(valid_pct * len(self))
|
221 |
+
return self.split_by_idx(rand_idx[:cut])
|
222 |
+
|
223 |
+
def split_subsets(self, train_size:float, valid_size:float, seed=None) -> 'ItemLists':
|
224 |
+
"Split the items into train set with size `train_size * n` and valid set with size `valid_size * n`."
|
225 |
+
assert 0 < train_size < 1
|
226 |
+
assert 0 < valid_size < 1
|
227 |
+
assert train_size + valid_size <= 1.
|
228 |
+
if seed is not None: np.random.seed(seed)
|
229 |
+
n = len(self.items)
|
230 |
+
rand_idx = np.random.permutation(range(n))
|
231 |
+
train_cut, valid_cut = int(train_size * n), int(valid_size * n)
|
232 |
+
return self.split_by_idxs(rand_idx[:train_cut], rand_idx[-valid_cut:])
|
233 |
+
|
234 |
+
def split_by_valid_func(self, func:Callable)->'ItemLists':
|
235 |
+
"Split the data by result of `func` (which returns `True` for validation set)."
|
236 |
+
valid_idx = [i for i,o in enumerate(self.items) if func(o)]
|
237 |
+
return self.split_by_idx(valid_idx)
|
238 |
+
|
239 |
+
def split_by_files(self, valid_names:'ItemList')->'ItemLists':
|
240 |
+
"Split the data by using the names in `valid_names` for validation."
|
241 |
+
if isinstance(self.items[0], Path): return self.split_by_valid_func(lambda o: o.name in valid_names)
|
242 |
+
else: return self.split_by_valid_func(lambda o: os.path.basename(o) in valid_names)
|
243 |
+
|
244 |
+
def split_by_fname_file(self, fname:PathOrStr, path:PathOrStr=None)->'ItemLists':
|
245 |
+
"Split the data by using the names in `fname` for the validation set. `path` will override `self.path`."
|
246 |
+
path = Path(ifnone(path, self.path))
|
247 |
+
valid_names = loadtxt_str(path/fname)
|
248 |
+
return self.split_by_files(valid_names)
|
249 |
+
|
250 |
+
def split_from_df(self, col:IntsOrStrs=2):
|
251 |
+
"Split the data from the `col` in the dataframe in `self.inner_df`."
|
252 |
+
valid_idx = np.where(self.inner_df.iloc[:,df_names_to_idx(col, self.inner_df)])[0]
|
253 |
+
return self.split_by_idx(valid_idx)
|
254 |
+
|
255 |
+
def get_label_cls(self, labels, label_cls:Callable=None, label_delim:str=None, **kwargs):
|
256 |
+
"Return `label_cls` or guess one from the first element of `labels`."
|
257 |
+
if label_cls is not None: return label_cls
|
258 |
+
if self.label_cls is not None: return self.label_cls
|
259 |
+
if label_delim is not None: return MultiCategoryList
|
260 |
+
it = index_row(labels,0)
|
261 |
+
if isinstance(it, (float, np.float32)): return FloatList
|
262 |
+
if isinstance(try_int(it), (str, Integral)): return CategoryList
|
263 |
+
if isinstance(it, Collection): return MultiCategoryList
|
264 |
+
return ItemList #self.__class__
|
265 |
+
|
266 |
+
def _label_from_list(self, labels:Iterator, label_cls:Callable=None, from_item_lists:bool=False, **kwargs)->'LabelList':
|
267 |
+
"Label `self.items` with `labels`."
|
268 |
+
if not from_item_lists:
|
269 |
+
raise Exception("Your data isn't split, if you don't want a validation set, please use `split_none`.")
|
270 |
+
labels = array(labels, dtype=object)
|
271 |
+
label_cls = self.get_label_cls(labels, label_cls=label_cls, **kwargs)
|
272 |
+
y = label_cls(labels, path=self.path, **kwargs)
|
273 |
+
res = self._label_list(x=self, y=y)
|
274 |
+
return res
|
275 |
+
|
276 |
+
def label_from_df(self, cols:IntsOrStrs=1, label_cls:Callable=None, **kwargs):
|
277 |
+
"Label `self.items` from the values in `cols` in `self.inner_df`."
|
278 |
+
labels = self.inner_df.iloc[:,df_names_to_idx(cols, self.inner_df)]
|
279 |
+
assert labels.isna().sum().sum() == 0, f"You have NaN values in column(s) {cols} of your dataframe, please fix it."
|
280 |
+
if is_listy(cols) and len(cols) > 1 and (label_cls is None or label_cls == MultiCategoryList):
|
281 |
+
new_kwargs,label_cls = dict(one_hot=True, classes= cols),MultiCategoryList
|
282 |
+
kwargs = {**new_kwargs, **kwargs}
|
283 |
+
return self._label_from_list(_maybe_squeeze(labels), label_cls=label_cls, **kwargs)
|
284 |
+
|
285 |
+
def label_const(self, const:Any=0, label_cls:Callable=None, **kwargs)->'LabelList':
|
286 |
+
"Label every item with `const`."
|
287 |
+
return self.label_from_func(func=lambda o: const, label_cls=label_cls, **kwargs)
|
288 |
+
|
289 |
+
def label_empty(self, **kwargs):
|
290 |
+
"Label every item with an `EmptyLabel`."
|
291 |
+
kwargs['label_cls'] = EmptyLabelList
|
292 |
+
return self.label_from_func(func=lambda o: 0., **kwargs)
|
293 |
+
|
294 |
+
def label_from_func(self, func:Callable, label_cls:Callable=None, **kwargs)->'LabelList':
|
295 |
+
"Apply `func` to every input to get its label."
|
296 |
+
return self._label_from_list([func(o) for o in self.items], label_cls=label_cls, **kwargs)
|
297 |
+
|
298 |
+
def label_from_folder(self, label_cls:Callable=None, **kwargs)->'LabelList':
|
299 |
+
"Give a label to each filename depending on its folder."
|
300 |
+
return self.label_from_func(func=lambda o: (o.parts if isinstance(o, Path) else o.split(os.path.sep))[-2],
|
301 |
+
label_cls=label_cls, **kwargs)
|
302 |
+
|
303 |
+
def label_from_re(self, pat:str, full_path:bool=False, label_cls:Callable=None, **kwargs)->'LabelList':
|
304 |
+
"Apply the re in `pat` to determine the label of every filename. If `full_path`, search in the full name."
|
305 |
+
pat = re.compile(pat)
|
306 |
+
def _inner(o):
|
307 |
+
s = str((os.path.join(self.path,o) if full_path else o).as_posix())
|
308 |
+
res = pat.search(s)
|
309 |
+
assert res,f'Failed to find "{pat}" in "{s}"'
|
310 |
+
return res.group(1)
|
311 |
+
return self.label_from_func(_inner, label_cls=label_cls, **kwargs)
|
312 |
+
|
313 |
+
def databunch(self, **kwargs):
|
314 |
+
"To throw a clear error message when the data wasn't split and labeled."
|
315 |
+
raise Exception("Your data is neither split nor labeled, can't turn it into a `DataBunch` yet.")
|
316 |
+
|
317 |
+
class EmptyLabelList(ItemList):
|
318 |
+
"Basic `ItemList` for dummy labels."
|
319 |
+
def get(self, i): return EmptyLabel()
|
320 |
+
def reconstruct(self, t:Tensor, x:Tensor=None):
|
321 |
+
if len(t.size()) == 0: return EmptyLabel()
|
322 |
+
return self.x.reconstruct(t,x) if has_arg(self.x.reconstruct, 'x') else self.x.reconstruct(t)
|
323 |
+
|
324 |
+
class CategoryProcessor(PreProcessor):
|
325 |
+
"`PreProcessor` that create `classes` from `ds.items` and handle the mapping."
|
326 |
+
def __init__(self, ds:ItemList):
|
327 |
+
self.create_classes(ds.classes)
|
328 |
+
self.state_attrs,self.warns = ['classes'],[]
|
329 |
+
|
330 |
+
def create_classes(self, classes):
|
331 |
+
self.classes = classes
|
332 |
+
if classes is not None: self.c2i = {v:k for k,v in enumerate(classes)}
|
333 |
+
|
334 |
+
def generate_classes(self, items):
|
335 |
+
"Generate classes from `items` by taking the sorted unique values."
|
336 |
+
return uniqueify(items, sort=True)
|
337 |
+
|
338 |
+
def process_one(self,item):
|
339 |
+
if isinstance(item, EmptyLabel): return item
|
340 |
+
res = self.c2i.get(item,None)
|
341 |
+
if res is None: self.warns.append(str(item))
|
342 |
+
return res
|
343 |
+
|
344 |
+
def process(self, ds):
|
345 |
+
if self.classes is None: self.create_classes(self.generate_classes(ds.items))
|
346 |
+
ds.classes = self.classes
|
347 |
+
ds.c2i = self.c2i
|
348 |
+
super().process(ds)
|
349 |
+
|
350 |
+
def __getstate__(self): return {n:getattr(self,n) for n in self.state_attrs}
|
351 |
+
def __setstate__(self, state:dict):
|
352 |
+
self.create_classes(state['classes'])
|
353 |
+
self.state_attrs = state.keys()
|
354 |
+
for n in state.keys():
|
355 |
+
if n!='classes': setattr(self, n, state[n])
|
356 |
+
|
357 |
+
class CategoryListBase(ItemList):
|
358 |
+
"Basic `ItemList` for classification."
|
359 |
+
def __init__(self, items:Iterator, classes:Collection=None, **kwargs):
|
360 |
+
self.classes=classes
|
361 |
+
self.filter_missing_y = True
|
362 |
+
super().__init__(items, **kwargs)
|
363 |
+
self.copy_new.append('classes')
|
364 |
+
|
365 |
+
@property
|
366 |
+
def c(self): return len(self.classes)
|
367 |
+
|
368 |
+
class CategoryList(CategoryListBase):
|
369 |
+
"Basic `ItemList` for single classification labels."
|
370 |
+
_processor=CategoryProcessor
|
371 |
+
def __init__(self, items:Iterator, classes:Collection=None, label_delim:str=None, **kwargs):
|
372 |
+
super().__init__(items, classes=classes, **kwargs)
|
373 |
+
self.loss_func = CrossEntropyFlat()
|
374 |
+
|
375 |
+
def get(self, i):
|
376 |
+
o = self.items[i]
|
377 |
+
if o is None: return None
|
378 |
+
return Category(o, self.classes[o])
|
379 |
+
|
380 |
+
def analyze_pred(self, pred, thresh:float=0.5): return pred.argmax()
|
381 |
+
|
382 |
+
def reconstruct(self, t):
|
383 |
+
return Category(t, self.classes[t])
|
384 |
+
|
385 |
+
class MultiCategoryProcessor(CategoryProcessor):
|
386 |
+
"`PreProcessor` that create `classes` from `ds.items` and handle the mapping."
|
387 |
+
def __init__(self, ds:ItemList, one_hot:bool=False):
|
388 |
+
super().__init__(ds)
|
389 |
+
self.one_hot = one_hot
|
390 |
+
self.state_attrs.append('one_hot')
|
391 |
+
|
392 |
+
def process_one(self,item):
|
393 |
+
if self.one_hot or isinstance(item, EmptyLabel): return item
|
394 |
+
res = [super(MultiCategoryProcessor, self).process_one(o) for o in item]
|
395 |
+
return [r for r in res if r is not None]
|
396 |
+
|
397 |
+
def generate_classes(self, items):
|
398 |
+
"Generate classes from `items` by taking the sorted unique values."
|
399 |
+
classes = set()
|
400 |
+
for c in items: classes = classes.union(set(c))
|
401 |
+
classes = list(classes)
|
402 |
+
classes.sort()
|
403 |
+
return classes
|
404 |
+
|
405 |
+
class MultiCategoryList(CategoryListBase):
|
406 |
+
"Basic `ItemList` for multi-classification labels."
|
407 |
+
_processor=MultiCategoryProcessor
|
408 |
+
def __init__(self, items:Iterator, classes:Collection=None, label_delim:str=None, one_hot:bool=False, **kwargs):
|
409 |
+
if label_delim is not None: items = array(csv.reader(items.astype(str), delimiter=label_delim))
|
410 |
+
super().__init__(items, classes=classes, **kwargs)
|
411 |
+
if one_hot:
|
412 |
+
assert classes is not None, "Please provide class names with `classes=...`"
|
413 |
+
self.processor = [MultiCategoryProcessor(self, one_hot=True)]
|
414 |
+
self.loss_func = BCEWithLogitsFlat()
|
415 |
+
self.one_hot = one_hot
|
416 |
+
self.copy_new += ['one_hot']
|
417 |
+
|
418 |
+
def get(self, i):
|
419 |
+
o = self.items[i]
|
420 |
+
if o is None: return None
|
421 |
+
if self.one_hot: return self.reconstruct(o.astype(np.float32))
|
422 |
+
return MultiCategory(one_hot(o, self.c), [self.classes[p] for p in o], o)
|
423 |
+
|
424 |
+
def analyze_pred(self, pred, thresh:float=0.5):
|
425 |
+
return (pred >= thresh).float()
|
426 |
+
|
427 |
+
def reconstruct(self, t):
|
428 |
+
o = [i for i in range(self.c) if t[i] == 1.]
|
429 |
+
return MultiCategory(t, [self.classes[p] for p in o], o)
|
430 |
+
|
431 |
+
class FloatList(ItemList):
|
432 |
+
"`ItemList` suitable for storing the floats in items for regression. Will add a `log` if this flag is `True`."
|
433 |
+
def __init__(self, items:Iterator, log:bool=False, classes:Collection=None, **kwargs):
|
434 |
+
super().__init__(np.array(items, dtype=np.float32), **kwargs)
|
435 |
+
self.log = log
|
436 |
+
self.copy_new.append('log')
|
437 |
+
self.c = self.items.shape[1] if len(self.items.shape) > 1 else 1
|
438 |
+
self.loss_func = MSELossFlat()
|
439 |
+
|
440 |
+
def get(self, i):
|
441 |
+
o = super().get(i)
|
442 |
+
return FloatItem(np.log(o) if self.log else o)
|
443 |
+
|
444 |
+
def reconstruct(self,t): return FloatItem(t.numpy())
|
445 |
+
|
446 |
+
class ItemLists():
|
447 |
+
"An `ItemList` for each of `train` and `valid` (optional `test`)."
|
448 |
+
def __init__(self, path:PathOrStr, train:ItemList, valid:ItemList):
|
449 |
+
self.path,self.train,self.valid,self.test = Path(path),train,valid,None
|
450 |
+
if not self.train.ignore_empty and len(self.train.items) == 0:
|
451 |
+
warn("Your training set is empty. If this is by design, pass `ignore_empty=True` to remove this warning.")
|
452 |
+
if not self.valid.ignore_empty and len(self.valid.items) == 0:
|
453 |
+
warn("""Your validation set is empty. If this is by design, use `split_none()`
|
454 |
+
or pass `ignore_empty=True` when labelling to remove this warning.""")
|
455 |
+
if isinstance(self.train, LabelList): self.__class__ = LabelLists
|
456 |
+
|
457 |
+
def __dir__(self)->List[str]:
|
458 |
+
default_dir = dir(type(self)) + list(self.__dict__.keys())
|
459 |
+
add_ons = ['label_const', 'label_empty', 'label_from_df', 'label_from_folder', 'label_from_func',
|
460 |
+
'label_from_list', 'label_from_re']
|
461 |
+
return default_dir + add_ons
|
462 |
+
|
463 |
+
def __repr__(self)->str:
|
464 |
+
return f'{self.__class__.__name__};\n\nTrain: {self.train};\n\nValid: {self.valid};\n\nTest: {self.test}'
|
465 |
+
|
466 |
+
def __getattr__(self, k):
|
467 |
+
ft = getattr(self.train, k)
|
468 |
+
if not isinstance(ft, Callable): return ft
|
469 |
+
fv = getattr(self.valid, k)
|
470 |
+
assert isinstance(fv, Callable)
|
471 |
+
def _inner(*args, **kwargs):
|
472 |
+
self.train = ft(*args, from_item_lists=True, **kwargs)
|
473 |
+
assert isinstance(self.train, LabelList)
|
474 |
+
kwargs['label_cls'] = self.train.y.__class__
|
475 |
+
self.valid = fv(*args, from_item_lists=True, **kwargs)
|
476 |
+
self.__class__ = LabelLists
|
477 |
+
self.process()
|
478 |
+
return self
|
479 |
+
return _inner
|
480 |
+
|
481 |
+
def __setstate__(self,data:Any): self.__dict__.update(data)
|
482 |
+
|
483 |
+
@property
|
484 |
+
def lists(self):
|
485 |
+
res = [self.train,self.valid]
|
486 |
+
if self.test is not None: res.append(self.test)
|
487 |
+
return res
|
488 |
+
|
489 |
+
def label_from_lists(self, train_labels:Iterator, valid_labels:Iterator, label_cls:Callable=None, **kwargs)->'LabelList':
|
490 |
+
"Use the labels in `train_labels` and `valid_labels` to label the data. `label_cls` will overwrite the default."
|
491 |
+
label_cls = self.train.get_label_cls(train_labels, label_cls)
|
492 |
+
self.train = self.train._label_list(x=self.train, y=label_cls(train_labels, **kwargs))
|
493 |
+
self.valid = self.valid._label_list(x=self.valid, y=self.train.y.new(valid_labels, **kwargs))
|
494 |
+
self.__class__ = LabelLists
|
495 |
+
self.process()
|
496 |
+
return self
|
497 |
+
|
498 |
+
def transform(self, tfms:Optional[Tuple[TfmList,TfmList]]=(None,None), **kwargs):
|
499 |
+
"Set `tfms` to be applied to the xs of the train and validation set."
|
500 |
+
if not tfms: tfms=(None,None)
|
501 |
+
assert is_listy(tfms) and len(tfms) == 2, "Please pass a list of two lists of transforms (train and valid)."
|
502 |
+
self.train.transform(tfms[0], **kwargs)
|
503 |
+
self.valid.transform(tfms[1], **kwargs)
|
504 |
+
if self.test: self.test.transform(tfms[1], **kwargs)
|
505 |
+
return self
|
506 |
+
|
507 |
+
def transform_y(self, tfms:Optional[Tuple[TfmList,TfmList]]=(None,None), **kwargs):
|
508 |
+
"Set `tfms` to be applied to the ys of the train and validation set."
|
509 |
+
if not tfms: tfms=(None,None)
|
510 |
+
self.train.transform_y(tfms[0], **kwargs)
|
511 |
+
self.valid.transform_y(tfms[1], **kwargs)
|
512 |
+
if self.test: self.test.transform_y(tfms[1], **kwargs)
|
513 |
+
return self
|
514 |
+
|
515 |
+
def databunch(self, **kwargs):
|
516 |
+
"To throw a clear error message when the data wasn't labeled."
|
517 |
+
raise Exception("Your data isn't labeled, can't turn it into a `DataBunch` yet!")
|
518 |
+
|
519 |
+
class LabelLists(ItemLists):
|
520 |
+
"A `LabelList` for each of `train` and `valid` (optional `test`)."
|
521 |
+
def get_processors(self):
|
522 |
+
"Read the default class processors if none have been set."
|
523 |
+
procs_x,procs_y = listify(self.train.x._processor),listify(self.train.y._processor)
|
524 |
+
xp = ifnone(self.train.x.processor, [p(ds=self.train.x) for p in procs_x])
|
525 |
+
yp = ifnone(self.train.y.processor, [p(ds=self.train.y) for p in procs_y])
|
526 |
+
return xp,yp
|
527 |
+
|
528 |
+
def process(self):
|
529 |
+
"Process the inner datasets."
|
530 |
+
xp,yp = self.get_processors()
|
531 |
+
for ds,n in zip(self.lists, ['train','valid','test']): ds.process(xp, yp, name=n)
|
532 |
+
#progress_bar clear the outputs so in some case warnings issued during processing disappear.
|
533 |
+
for ds in self.lists:
|
534 |
+
if getattr(ds, 'warn', False): warn(ds.warn)
|
535 |
+
return self
|
536 |
+
|
537 |
+
def filter_by_func(self, func:Callable):
|
538 |
+
for ds in self.lists: ds.filter_by_func(func)
|
539 |
+
return self
|
540 |
+
|
541 |
+
def databunch(self, path:PathOrStr=None, bs:int=64, val_bs:int=None, num_workers:int=defaults.cpus,
|
542 |
+
dl_tfms:Optional[Collection[Callable]]=None, device:torch.device=None, collate_fn:Callable=data_collate,
|
543 |
+
no_check:bool=False, **kwargs)->'DataBunch':
|
544 |
+
"Create an `DataBunch` from self, `path` will override `self.path`, `kwargs` are passed to `DataBunch.create`."
|
545 |
+
path = Path(ifnone(path, self.path))
|
546 |
+
data = self.x._bunch.create(self.train, self.valid, test_ds=self.test, path=path, bs=bs, val_bs=val_bs,
|
547 |
+
num_workers=num_workers, dl_tfms=dl_tfms, device=device, collate_fn=collate_fn, no_check=no_check, **kwargs)
|
548 |
+
if getattr(self, 'normalize', False):#In case a normalization was serialized
|
549 |
+
norm = self.normalize
|
550 |
+
data.normalize((norm['mean'], norm['std']), do_x=norm['do_x'], do_y=norm['do_y'])
|
551 |
+
data.label_list = self
|
552 |
+
return data
|
553 |
+
|
554 |
+
def add_test(self, items:Iterator, label:Any=None, tfms=None, tfm_y=None):
|
555 |
+
"Add test set containing `items` with an arbitrary `label`."
|
556 |
+
# if no label passed, use label of first training item
|
557 |
+
if label is None: labels = EmptyLabelList([0] * len(items))
|
558 |
+
else: labels = self.valid.y.new([label] * len(items)).process()
|
559 |
+
if isinstance(items, MixedItemList): items = self.valid.x.new(items.item_lists, inner_df=items.inner_df).process()
|
560 |
+
elif isinstance(items, ItemList): items = self.valid.x.new(items.items, inner_df=items.inner_df).process()
|
561 |
+
else: items = self.valid.x.new(items).process()
|
562 |
+
self.test = self.valid.new(items, labels, tfms=tfms, tfm_y=tfm_y)
|
563 |
+
return self
|
564 |
+
|
565 |
+
def add_test_folder(self, test_folder:str='test', label:Any=None, tfms=None, tfm_y=None):
|
566 |
+
"Add test set containing items from `test_folder` and an arbitrary `label`."
|
567 |
+
# note: labels will be ignored if available in the test dataset
|
568 |
+
items = self.x.__class__.from_folder(self.path/test_folder)
|
569 |
+
return self.add_test(items.items, label=label, tfms=tfms, tfm_y=tfm_y)
|
570 |
+
|
571 |
+
@classmethod
|
572 |
+
def load_state(cls, path:PathOrStr, state:dict):
|
573 |
+
"Create a `LabelLists` with empty sets from the serialized `state`."
|
574 |
+
path = Path(path)
|
575 |
+
train_ds = LabelList.load_state(path, state)
|
576 |
+
valid_ds = LabelList.load_state(path, state)
|
577 |
+
return LabelLists(path, train=train_ds, valid=valid_ds)
|
578 |
+
|
579 |
+
@classmethod
|
580 |
+
def load_empty(cls, path:PathOrStr, fn:PathOrStr='export.pkl'):
|
581 |
+
"Create a `LabelLists` with empty sets from the serialized file in `path/fn`."
|
582 |
+
path = Path(path)
|
583 |
+
state = torch.load(open(path/fn, 'rb'))
|
584 |
+
return LabelLists.load_state(path, state)
|
585 |
+
|
586 |
+
def _check_kwargs(ds:ItemList, tfms:TfmList, **kwargs):
|
587 |
+
tfms = listify(tfms)
|
588 |
+
if (tfms is None or len(tfms) == 0) and len(kwargs) == 0: return
|
589 |
+
if len(ds.items) >= 1:
|
590 |
+
x = ds[0]
|
591 |
+
try: x.apply_tfms(tfms, **kwargs)
|
592 |
+
except Exception as e:
|
593 |
+
raise Exception(f"It's not possible to apply those transforms to your dataset:\n {e}")
|
594 |
+
|
595 |
+
class LabelList(Dataset):
|
596 |
+
"A list of inputs `x` and labels `y` with optional `tfms`."
|
597 |
+
def __init__(self, x:ItemList, y:ItemList, tfms:TfmList=None, tfm_y:bool=False, **kwargs):
|
598 |
+
self.x,self.y,self.tfm_y = x,y,tfm_y
|
599 |
+
self.y.x = x
|
600 |
+
self.item=None
|
601 |
+
self.transform(tfms, **kwargs)
|
602 |
+
|
603 |
+
def __len__(self)->int: return len(self.x) if self.item is None else 1
|
604 |
+
|
605 |
+
@contextmanager
|
606 |
+
def set_item(self,item):
|
607 |
+
"For inference, will briefly replace the dataset with one that only contains `item`."
|
608 |
+
self.item = self.x.process_one(item)
|
609 |
+
yield None
|
610 |
+
self.item = None
|
611 |
+
|
612 |
+
def __repr__(self)->str:
|
613 |
+
items = [self[i] for i in range(min(5,len(self.items)))]
|
614 |
+
res = f'{self.__class__.__name__} ({len(self.items)} items)\n'
|
615 |
+
res += f'x: {self.x.__class__.__name__}\n{show_some([i[0] for i in items])}\n'
|
616 |
+
res += f'y: {self.y.__class__.__name__}\n{show_some([i[1] for i in items])}\n'
|
617 |
+
return res + f'Path: {self.path}'
|
618 |
+
|
619 |
+
def predict(self, res):
|
620 |
+
"Delegates predict call on `res` to `self.y`."
|
621 |
+
return self.y.predict(res)
|
622 |
+
|
623 |
+
@property
|
624 |
+
def c(self): return self.y.c
|
625 |
+
|
626 |
+
def new(self, x, y, tfms=None, tfm_y=None, **kwargs)->'LabelList':
|
627 |
+
tfms,tfm_y = ifnone(tfms, self.tfms),ifnone(tfm_y, self.tfm_y)
|
628 |
+
if isinstance(x, ItemList):
|
629 |
+
return self.__class__(x, y, tfms=tfms, tfm_y=tfm_y, **self.tfmargs)
|
630 |
+
else:
|
631 |
+
return self.new(self.x.new(x, **kwargs), self.y.new(y, **kwargs), tfms=tfms, tfm_y=tfm_y).process()
|
632 |
+
|
633 |
+
def __getattr__(self,k:str)->Any:
|
634 |
+
x = super().__getattribute__('x')
|
635 |
+
res = getattr(x, k, None)
|
636 |
+
if res is not None and k not in ['classes', 'c']: return res
|
637 |
+
y = super().__getattribute__('y')
|
638 |
+
res = getattr(y, k, None)
|
639 |
+
if res is not None: return res
|
640 |
+
raise AttributeError(k)
|
641 |
+
|
642 |
+
def __setstate__(self,data:Any): self.__dict__.update(data)
|
643 |
+
|
644 |
+
def __getitem__(self,idxs:Union[int,np.ndarray])->'LabelList':
|
645 |
+
"return a single (x, y) if `idxs` is an integer or a new `LabelList` object if `idxs` is a range."
|
646 |
+
idxs = try_int(idxs)
|
647 |
+
if isinstance(idxs, Integral):
|
648 |
+
if self.item is None: x,y = self.x[idxs],self.y[idxs]
|
649 |
+
else: x,y = self.item ,0
|
650 |
+
if self.tfms or self.tfmargs:
|
651 |
+
x = x.apply_tfms(self.tfms, is_x=True, **self.tfmargs)
|
652 |
+
if hasattr(self, 'tfms_y') and self.tfm_y and self.item is None:
|
653 |
+
y = y.apply_tfms(self.tfms_y, is_x=False, **{**self.tfmargs_y, 'do_resolve':False})
|
654 |
+
if y is None: y=0
|
655 |
+
return x,y
|
656 |
+
else: return self.new(self.x[idxs], self.y[idxs])
|
657 |
+
|
658 |
+
def to_df(self)->None:
|
659 |
+
"Create `pd.DataFrame` containing `items` from `self.x` and `self.y`."
|
660 |
+
return pd.DataFrame(dict(x=self.x._relative_item_paths(), y=[str(o) for o in self.y]))
|
661 |
+
|
662 |
+
def to_csv(self, dest:str)->None:
|
663 |
+
"Save `self.to_df()` to a CSV file in `self.path`/`dest`."
|
664 |
+
self.to_df().to_csv(self.path/dest, index=False)
|
665 |
+
|
666 |
+
def get_state(self, **kwargs):
|
667 |
+
"Return the minimal state for export."
|
668 |
+
state = {'x_cls':self.x.__class__, 'x_proc':self.x.processor,
|
669 |
+
'y_cls':self.y.__class__, 'y_proc':self.y.processor,
|
670 |
+
'tfms':self.tfms, 'tfm_y':self.tfm_y, 'tfmargs':self.tfmargs}
|
671 |
+
if hasattr(self, 'tfms_y'): state['tfms_y'] = self.tfms_y
|
672 |
+
if hasattr(self, 'tfmargs_y'): state['tfmargs_y'] = self.tfmargs_y
|
673 |
+
return {**state, **kwargs}
|
674 |
+
|
675 |
+
def export(self, fn:PathOrStr, **kwargs):
|
676 |
+
"Export the minimal state and save it in `fn` to load an empty version for inference."
|
677 |
+
pickle.dump(self.get_state(**kwargs), open(fn, 'wb'))
|
678 |
+
|
679 |
+
@classmethod
|
680 |
+
def load_empty(cls, path:PathOrStr, fn:PathOrStr):
|
681 |
+
"Load the state in `fn` to create an empty `LabelList` for inference."
|
682 |
+
return cls.load_state(path, pickle.load(open(Path(path)/fn, 'rb')))
|
683 |
+
|
684 |
+
@classmethod
|
685 |
+
def load_state(cls, path:PathOrStr, state:dict) -> 'LabelList':
|
686 |
+
"Create a `LabelList` from `state`."
|
687 |
+
x = state['x_cls']([], path=path, processor=state['x_proc'], ignore_empty=True)
|
688 |
+
y = state['y_cls']([], path=path, processor=state['y_proc'], ignore_empty=True)
|
689 |
+
res = cls(x, y, tfms=state['tfms'], tfm_y=state['tfm_y'], **state['tfmargs']).process()
|
690 |
+
if state.get('tfms_y', False): res.tfms_y = state['tfms_y']
|
691 |
+
if state.get('tfmargs_y', False): res.tfmargs_y = state['tfmargs_y']
|
692 |
+
if state.get('normalize', False): res.normalize = state['normalize']
|
693 |
+
return res
|
694 |
+
|
695 |
+
def process(self, xp:PreProcessor=None, yp:PreProcessor=None, name:str=None):
|
696 |
+
"Launch the processing on `self.x` and `self.y` with `xp` and `yp`."
|
697 |
+
self.y.process(yp)
|
698 |
+
if getattr(self.y, 'filter_missing_y', False):
|
699 |
+
filt = array([o is None for o in self.y.items])
|
700 |
+
if filt.sum()>0:
|
701 |
+
#Warnings are given later since progress_bar might make them disappear.
|
702 |
+
self.warn = f"You are labelling your items with {self.y.__class__.__name__}.\n"
|
703 |
+
self.warn += f"Your {name} set contained the following unknown labels, the corresponding items have been discarded.\n"
|
704 |
+
for p in self.y.processor:
|
705 |
+
if len(getattr(p, 'warns', [])) > 0:
|
706 |
+
warnings = list(set(p.warns))
|
707 |
+
self.warn += ', '.join(warnings[:5])
|
708 |
+
if len(warnings) > 5: self.warn += "..."
|
709 |
+
p.warns = []
|
710 |
+
self.x,self.y = self.x[~filt],self.y[~filt]
|
711 |
+
self.x.process(xp)
|
712 |
+
return self
|
713 |
+
|
714 |
+
def filter_by_func(self, func:Callable):
|
715 |
+
filt = array([func(x,y) for x,y in zip(self.x.items, self.y.items)])
|
716 |
+
self.x,self.y = self.x[~filt],self.y[~filt]
|
717 |
+
return self
|
718 |
+
|
719 |
+
def transform(self, tfms:TfmList, tfm_y:bool=None, **kwargs):
|
720 |
+
"Set the `tfms` and `tfm_y` value to be applied to the inputs and targets."
|
721 |
+
_check_kwargs(self.x, tfms, **kwargs)
|
722 |
+
if tfm_y is None: tfm_y = self.tfm_y
|
723 |
+
tfms_y = None if tfms is None else list(filter(lambda t: getattr(t, 'use_on_y', True), listify(tfms)))
|
724 |
+
if tfm_y: _check_kwargs(self.y, tfms_y, **kwargs)
|
725 |
+
self.tfms,self.tfmargs = tfms,kwargs
|
726 |
+
self.tfm_y,self.tfms_y,self.tfmargs_y = tfm_y,tfms_y,kwargs
|
727 |
+
return self
|
728 |
+
|
729 |
+
def transform_y(self, tfms:TfmList=None, **kwargs):
|
730 |
+
"Set `tfms` to be applied to the targets only."
|
731 |
+
tfms_y = list(filter(lambda t: getattr(t, 'use_on_y', True), listify(self.tfms if tfms is None else tfms)))
|
732 |
+
tfmargs_y = {**self.tfmargs, **kwargs} if tfms is None else kwargs
|
733 |
+
_check_kwargs(self.y, tfms_y, **tfmargs_y)
|
734 |
+
self.tfm_y,self.tfms_y,self.tfmargs_y=True,tfms_y,tfmargs_y
|
735 |
+
return self
|
736 |
+
|
737 |
+
def databunch(self, **kwargs):
|
738 |
+
"To throw a clear error message when the data wasn't split."
|
739 |
+
raise Exception("Your data isn't split, if you don't want a validation set, please use `split_none`")
|
740 |
+
|
741 |
+
@classmethod
|
742 |
+
def _databunch_load_empty(cls, path, fname:str='export.pkl'):
|
743 |
+
"Load an empty `DataBunch` from the exported file in `path/fname` with optional `tfms`."
|
744 |
+
sd = LabelLists.load_empty(path, fn=fname)
|
745 |
+
return sd.databunch()
|
746 |
+
|
747 |
+
DataBunch.load_empty = _databunch_load_empty
|
748 |
+
|
749 |
+
class MixedProcessor(PreProcessor):
|
750 |
+
def __init__(self, procs:Collection[Union[PreProcessor, Collection[PreProcessor]]]):
|
751 |
+
self.procs = procs
|
752 |
+
|
753 |
+
def process_one(self, item:Any):
|
754 |
+
res = []
|
755 |
+
for procs, i in zip(self.procs, item):
|
756 |
+
for p in procs: i = p.process_one(i)
|
757 |
+
res.append(i)
|
758 |
+
return res
|
759 |
+
|
760 |
+
def process(self, ds:Collection):
|
761 |
+
for procs, il in zip(self.procs, ds.item_lists):
|
762 |
+
for p in procs: p.process(il)
|
763 |
+
|
764 |
+
class MixedItem(ItemBase):
|
765 |
+
def __init__(self, items):
|
766 |
+
self.obj = items
|
767 |
+
self.data = [item.data for item in items]
|
768 |
+
|
769 |
+
def __repr__(self): return '\n'.join([f'{self.__class__.__name__}'] + [repr(item) for item in self.obj])
|
770 |
+
|
771 |
+
def apply_tfms(self, tfms:Collection, **kwargs):
|
772 |
+
self.obj = [item.apply_tfms(t, **kwargs) for item,t in zip(self.obj, tfms)]
|
773 |
+
self.data = [item.data for item in self.obj]
|
774 |
+
return self
|
775 |
+
|
776 |
+
class MixedItemList(ItemList):
|
777 |
+
|
778 |
+
def __init__(self, item_lists, path:PathOrStr=None, label_cls:Callable=None, inner_df:Any=None,
|
779 |
+
x:'ItemList'=None, ignore_empty:bool=False, processor=None):
|
780 |
+
self.item_lists = item_lists
|
781 |
+
if processor is None:
|
782 |
+
default_procs = [[p(ds=il) for p in listify(il._processor)] for il in item_lists]
|
783 |
+
processor = MixedProcessor([ifnone(il.processor, dp) for il,dp in zip(item_lists, default_procs)])
|
784 |
+
items = range_of(item_lists[0]) if len(item_lists) >= 1 else []
|
785 |
+
if path is None and len(item_lists) >= 1: path = item_lists[0].path
|
786 |
+
super().__init__(items, processor=processor, path=path,
|
787 |
+
label_cls=label_cls, inner_df=inner_df, x=x, ignore_empty=ignore_empty)
|
788 |
+
|
789 |
+
def new(self, item_lists, processor:PreProcessor=None, **kwargs)->'ItemList':
|
790 |
+
"Create a new `ItemList` from `items`, keeping the same attributes."
|
791 |
+
processor = ifnone(processor, self.processor)
|
792 |
+
copy_d = {o:getattr(self,o) for o in self.copy_new}
|
793 |
+
kwargs = {**copy_d, **kwargs}
|
794 |
+
return self.__class__(item_lists, processor=processor, **kwargs)
|
795 |
+
|
796 |
+
def get(self, i):
|
797 |
+
return MixedItem([il.get(i) for il in self.item_lists])
|
798 |
+
|
799 |
+
def __getitem__(self,idxs:int)->Any:
|
800 |
+
idxs = try_int(idxs)
|
801 |
+
if isinstance(idxs, Integral): return self.get(idxs)
|
802 |
+
else:
|
803 |
+
item_lists = [il.new(il.items[idxs], inner_df=index_row(il.inner_df, idxs)) for il in self.item_lists]
|
804 |
+
return self.new(item_lists, inner_df=index_row(self.inner_df, idxs))
|
DeOldify/fastai/datasets.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .core import *
|
2 |
+
import hashlib
|
3 |
+
|
4 |
+
__all__ = ['URLs', 'Config', 'untar_data', 'download_data', 'datapath4file', 'url2name', 'url2path']
|
5 |
+
|
6 |
+
MODEL_URL = 'http://files.fast.ai/models/'
|
7 |
+
URL = 'http://files.fast.ai/data/examples/'
|
8 |
+
class URLs():
|
9 |
+
"Global constants for dataset and model URLs."
|
10 |
+
LOCAL_PATH = Path.cwd()
|
11 |
+
S3 = 'https://s3.amazonaws.com/fast-ai-'
|
12 |
+
|
13 |
+
S3_IMAGE = f'{S3}imageclas/'
|
14 |
+
S3_IMAGELOC = f'{S3}imagelocal/'
|
15 |
+
S3_NLP = f'{S3}nlp/'
|
16 |
+
S3_COCO = f'{S3}coco/'
|
17 |
+
S3_MODEL = f'{S3}modelzoo/'
|
18 |
+
|
19 |
+
# main datasets
|
20 |
+
ADULT_SAMPLE = f'{URL}adult_sample'
|
21 |
+
BIWI_SAMPLE = f'{URL}biwi_sample'
|
22 |
+
CIFAR = f'{URL}cifar10'
|
23 |
+
COCO_SAMPLE = f'{S3_COCO}coco_sample'
|
24 |
+
COCO_TINY = f'{URL}coco_tiny'
|
25 |
+
HUMAN_NUMBERS = f'{URL}human_numbers'
|
26 |
+
IMDB = f'{S3_NLP}imdb'
|
27 |
+
IMDB_SAMPLE = f'{URL}imdb_sample'
|
28 |
+
ML_SAMPLE = f'{URL}movie_lens_sample'
|
29 |
+
MNIST_SAMPLE = f'{URL}mnist_sample'
|
30 |
+
MNIST_TINY = f'{URL}mnist_tiny'
|
31 |
+
MNIST_VAR_SIZE_TINY = f'{S3_IMAGE}mnist_var_size_tiny'
|
32 |
+
PLANET_SAMPLE = f'{URL}planet_sample'
|
33 |
+
PLANET_TINY = f'{URL}planet_tiny'
|
34 |
+
IMAGENETTE = f'{S3_IMAGE}imagenette'
|
35 |
+
IMAGENETTE_160 = f'{S3_IMAGE}imagenette-160'
|
36 |
+
IMAGENETTE_320 = f'{S3_IMAGE}imagenette-320'
|
37 |
+
IMAGEWOOF = f'{S3_IMAGE}imagewoof'
|
38 |
+
IMAGEWOOF_160 = f'{S3_IMAGE}imagewoof-160'
|
39 |
+
IMAGEWOOF_320 = f'{S3_IMAGE}imagewoof-320'
|
40 |
+
|
41 |
+
# kaggle competitions download dogs-vs-cats -p {DOGS.absolute()}
|
42 |
+
DOGS = f'{URL}dogscats'
|
43 |
+
|
44 |
+
# image classification datasets
|
45 |
+
CALTECH_101 = f'{S3_IMAGE}caltech_101'
|
46 |
+
CARS = f'{S3_IMAGE}stanford-cars'
|
47 |
+
CIFAR_100 = f'{S3_IMAGE}cifar100'
|
48 |
+
CUB_200_2011 = f'{S3_IMAGE}CUB_200_2011'
|
49 |
+
FLOWERS = f'{S3_IMAGE}oxford-102-flowers'
|
50 |
+
FOOD = f'{S3_IMAGE}food-101'
|
51 |
+
MNIST = f'{S3_IMAGE}mnist_png'
|
52 |
+
PETS = f'{S3_IMAGE}oxford-iiit-pet'
|
53 |
+
|
54 |
+
# NLP datasets
|
55 |
+
AG_NEWS = f'{S3_NLP}ag_news_csv'
|
56 |
+
AMAZON_REVIEWS = f'{S3_NLP}amazon_review_full_csv'
|
57 |
+
AMAZON_REVIEWS_POLARITY = f'{S3_NLP}amazon_review_polarity_csv'
|
58 |
+
DBPEDIA = f'{S3_NLP}dbpedia_csv'
|
59 |
+
MT_ENG_FRA = f'{S3_NLP}giga-fren'
|
60 |
+
SOGOU_NEWS = f'{S3_NLP}sogou_news_csv'
|
61 |
+
WIKITEXT = f'{S3_NLP}wikitext-103'
|
62 |
+
WIKITEXT_TINY = f'{S3_NLP}wikitext-2'
|
63 |
+
YAHOO_ANSWERS = f'{S3_NLP}yahoo_answers_csv'
|
64 |
+
YELP_REVIEWS = f'{S3_NLP}yelp_review_full_csv'
|
65 |
+
YELP_REVIEWS_POLARITY = f'{S3_NLP}yelp_review_polarity_csv'
|
66 |
+
|
67 |
+
# Image localization datasets
|
68 |
+
BIWI_HEAD_POSE = f"{S3_IMAGELOC}biwi_head_pose"
|
69 |
+
CAMVID = f'{S3_IMAGELOC}camvid'
|
70 |
+
CAMVID_TINY = f'{URL}camvid_tiny'
|
71 |
+
LSUN_BEDROOMS = f'{S3_IMAGE}bedroom'
|
72 |
+
PASCAL_2007 = f'{S3_IMAGELOC}pascal_2007'
|
73 |
+
PASCAL_2012 = f'{S3_IMAGELOC}pascal_2012'
|
74 |
+
|
75 |
+
#Pretrained models
|
76 |
+
OPENAI_TRANSFORMER = f'{S3_MODEL}transformer'
|
77 |
+
WT103_FWD = f'{S3_MODEL}wt103-fwd'
|
78 |
+
WT103_BWD = f'{S3_MODEL}wt103-bwd'
|
79 |
+
|
80 |
+
# to create/update a checksum for ./mnist_var_size_tiny.tgz, run:
|
81 |
+
# python -c 'import fastai.datasets; print(fastai.datasets._check_file("mnist_var_size_tiny.tgz"))'
|
82 |
+
_checks = {
|
83 |
+
URLs.ADULT_SAMPLE:(968212, '64eb9d7e23732de0b138f7372d15492f'),
|
84 |
+
URLs.AG_NEWS:(11784419, 'b86f328f4dbd072486591cb7a5644dcd'),
|
85 |
+
URLs.AMAZON_REVIEWS_POLARITY:(688339454, '676f7e5208ec343c8274b4bb085bc938'),
|
86 |
+
URLs.AMAZON_REVIEWS:(643695014, '4a1196cf0adaea22f4bc3f592cddde90'),
|
87 |
+
URLs.BIWI_HEAD_POSE:(452316199, '00f4ccf66e8cba184bc292fdc08fb237'),
|
88 |
+
URLs.BIWI_SAMPLE:(593774, '9179f4c1435f4b291f0d5b072d60c2c9'),
|
89 |
+
URLs.CALTECH_101:(131740031, 'd673425306e98ee4619fcdeef8a0e876'),
|
90 |
+
URLs.CAMVID:(598913237, '648371e4f3a833682afb39b08a3ce2aa'),
|
91 |
+
URLs.CAMVID_TINY:(2314212, '2cf6daf91b7a2083ecfa3e9968e9d915'),
|
92 |
+
URLs.CARS:(1957803273, '9045d6673c9ced0889f41816f6bf2f9f'),
|
93 |
+
URLs.CIFAR:(168168549, 'a5f8c31371b63a406b23368042812d3c'),
|
94 |
+
URLs.CIFAR_100:(169168619, 'e5e65dcb54b9d3913f7b8a9ad6607e62'),
|
95 |
+
URLs.COCO_SAMPLE:(3245877008, '006cd55d633d94b36ecaf661467830ec'),
|
96 |
+
URLs.COCO_TINY:(801038, '367467451ac4fba79a647753c2c66d3a'),
|
97 |
+
URLs.CUB_200_2011:(1150585339, 'd2acaa99439dff0483c7bbac1bfe2a92'),
|
98 |
+
URLs.DBPEDIA:(68341743, '239c7837b9e79db34486f3de6a00e38e'),
|
99 |
+
URLs.DOGS:(839285364, '3e483c8d6ef2175e9d395a6027eb92b7'),
|
100 |
+
URLs.FLOWERS:(345236087, '5666e01c1311b4c67fcf20d2b3850a88'),
|
101 |
+
URLs.FOOD:(5686607260, '1a540ebf1fb40b2bf3f2294234ba7907'),
|
102 |
+
URLs.HUMAN_NUMBERS:(30252, '8a19c3bfa2bcb08cd787e741261f3ea2'),
|
103 |
+
URLs.IMDB:(144440600, '90f9b1c4ff43a90d67553c9240dc0249'),
|
104 |
+
URLs.IMDB_SAMPLE:(571827, '0842e61a9867caa2e6fbdb14fa703d61'),
|
105 |
+
URLs.LSUN_BEDROOMS:(4579163978, '35d84f38f8a15fe47e66e460c8800d68'),
|
106 |
+
URLs.ML_SAMPLE:(51790, '10961384dfe7c5181460390a460c1f77'),
|
107 |
+
URLs.MNIST:(15683414, '03639f83c4e3d19e0a3a53a8a997c487'),
|
108 |
+
URLs.MNIST_SAMPLE:(3214948, '2dbc7ec6f9259b583af0072c55816a88'),
|
109 |
+
URLs.MNIST_TINY:(342207, '56143e8f24db90d925d82a5a74141875'),
|
110 |
+
URLs.MNIST_VAR_SIZE_TINY:(565372, 'b71a930f4eb744a4a143a6c7ff7ed67f'),
|
111 |
+
URLs.MT_ENG_FRA:(2598183296, '69573f58e2c850b90f2f954077041d8c'),
|
112 |
+
URLs.OPENAI_TRANSFORMER:(432848315, '024b0d2203ebb0cd1fc64b27cf8af18e'),
|
113 |
+
URLs.PASCAL_2007:(1636130334, 'a70574e9bc592bd3b253f5bf46ce12e3'),
|
114 |
+
URLs.PASCAL_2012:(2611715776, '2ae7897038383836f86ce58f66b09e31'),
|
115 |
+
URLs.PETS:(811706944, 'e4db5c768afd933bb91f5f594d7417a4'),
|
116 |
+
URLs.PLANET_SAMPLE:(15523994, '8bfb174b3162f07fbde09b54555bdb00'),
|
117 |
+
URLs.PLANET_TINY:(997569, '490873c5683454d4b2611fb1f00a68a9'),
|
118 |
+
URLs.SOGOU_NEWS:(384269937, '950f1366d33be52f5b944f8a8b680902'),
|
119 |
+
URLs.WIKITEXT:(190200704, '2dd8cf8693b3d27e9c8f0a7df054b2c7'),
|
120 |
+
URLs.WIKITEXT_TINY:(4070055, '2a82d47a7b85c8b6a8e068dc4c1d37e7'),
|
121 |
+
URLs.WT103_FWD:(105067061, '7d1114cd9684bf9d1ca3c9f6a54da6f9'),
|
122 |
+
URLs.WT103_BWD:(105205312, '20b06f5830fd5a891d21044c28d3097f'),
|
123 |
+
URLs.YAHOO_ANSWERS:(319476345, '0632a0d236ef3a529c0fa4429b339f68'),
|
124 |
+
URLs.YELP_REVIEWS_POLARITY:(166373201, '48c8451c1ad30472334d856b5d294807'),
|
125 |
+
URLs.YELP_REVIEWS:(196146755, '1efd84215ea3e30d90e4c33764b889db'),
|
126 |
+
}
|
127 |
+
|
128 |
+
#TODO: This can probably be coded more shortly and nicely.
|
129 |
+
class Config():
|
130 |
+
"Creates a default config file 'config.yml' in $FASTAI_HOME (default `~/.fastai/`)"
|
131 |
+
DEFAULT_CONFIG_LOCATION = os.path.expanduser(os.getenv('FASTAI_HOME', '~/.fastai'))
|
132 |
+
DEFAULT_CONFIG_PATH = DEFAULT_CONFIG_LOCATION + '/config.yml'
|
133 |
+
DEFAULT_CONFIG = {
|
134 |
+
'data_path': DEFAULT_CONFIG_LOCATION + '/data',
|
135 |
+
'data_archive_path': DEFAULT_CONFIG_LOCATION + '/data',
|
136 |
+
'model_path': DEFAULT_CONFIG_LOCATION + '/models'
|
137 |
+
}
|
138 |
+
|
139 |
+
@classmethod
|
140 |
+
def get_key(cls, key):
|
141 |
+
"Get the path to `key` in the config file."
|
142 |
+
return cls.get().get(key, cls.DEFAULT_CONFIG.get(key,None))
|
143 |
+
|
144 |
+
@classmethod
|
145 |
+
def get_path(cls, path):
|
146 |
+
"Get the `path` in the config file."
|
147 |
+
return _expand_path(cls.get_key(path))
|
148 |
+
|
149 |
+
@classmethod
|
150 |
+
def data_path(cls):
|
151 |
+
"Get the path to data in the config file."
|
152 |
+
return cls.get_path('data_path')
|
153 |
+
|
154 |
+
@classmethod
|
155 |
+
def data_archive_path(cls):
|
156 |
+
"Get the path to data archives in the config file."
|
157 |
+
return cls.get_path('data_archive_path')
|
158 |
+
|
159 |
+
@classmethod
|
160 |
+
def model_path(cls):
|
161 |
+
"Get the path to fastai pretrained models in the config file."
|
162 |
+
return cls.get_path('model_path')
|
163 |
+
|
164 |
+
@classmethod
|
165 |
+
def get(cls, fpath=None, create_missing=True):
|
166 |
+
"Retrieve the `Config` in `fpath`."
|
167 |
+
fpath = _expand_path(fpath or cls.DEFAULT_CONFIG_PATH)
|
168 |
+
if not fpath.exists() and create_missing: cls.create(fpath)
|
169 |
+
assert fpath.exists(), f'Could not find config at: {fpath}. Please create'
|
170 |
+
with open(fpath, 'r') as yaml_file: return yaml.safe_load(yaml_file)
|
171 |
+
|
172 |
+
@classmethod
|
173 |
+
def create(cls, fpath):
|
174 |
+
"Creates a `Config` from `fpath`."
|
175 |
+
fpath = _expand_path(fpath)
|
176 |
+
assert(fpath.suffix == '.yml')
|
177 |
+
if fpath.exists(): return
|
178 |
+
fpath.parent.mkdir(parents=True, exist_ok=True)
|
179 |
+
with open(fpath, 'w') as yaml_file:
|
180 |
+
yaml.dump(cls.DEFAULT_CONFIG, yaml_file, default_flow_style=False)
|
181 |
+
|
182 |
+
def _expand_path(fpath): return Path(fpath).expanduser()
|
183 |
+
def url2name(url): return url.split('/')[-1]
|
184 |
+
|
185 |
+
#TODO: simplify this mess
|
186 |
+
def url2path(url, data=True, ext:str='.tgz'):
|
187 |
+
"Change `url` to a path."
|
188 |
+
name = url2name(url)
|
189 |
+
return datapath4file(name, ext=ext, archive=False) if data else modelpath4file(name, ext=ext)
|
190 |
+
def _url2tgz(url, data=True, ext:str='.tgz'):
|
191 |
+
return datapath4file(f'{url2name(url)}{ext}', ext=ext) if data else modelpath4file(f'{url2name(url)}{ext}', ext=ext)
|
192 |
+
|
193 |
+
def modelpath4file(filename, ext:str='.tgz'):
|
194 |
+
"Return model path to `filename`, checking locally first then in the config file."
|
195 |
+
local_path = URLs.LOCAL_PATH/'models'/filename
|
196 |
+
if local_path.exists() or local_path.with_suffix(ext).exists(): return local_path
|
197 |
+
else: return Config.model_path()/filename
|
198 |
+
|
199 |
+
def datapath4file(filename, ext:str='.tgz', archive=True):
|
200 |
+
"Return data path to `filename`, checking locally first then in the config file."
|
201 |
+
local_path = URLs.LOCAL_PATH/'data'/filename
|
202 |
+
if local_path.exists() or local_path.with_suffix(ext).exists(): return local_path
|
203 |
+
elif archive: return Config.data_archive_path() / filename
|
204 |
+
else: return Config.data_path() / filename
|
205 |
+
|
206 |
+
def download_data(url:str, fname:PathOrStr=None, data:bool=True, ext:str='.tgz') -> Path:
|
207 |
+
"Download `url` to destination `fname`."
|
208 |
+
fname = Path(ifnone(fname, _url2tgz(url, data, ext=ext)))
|
209 |
+
os.makedirs(fname.parent, exist_ok=True)
|
210 |
+
if not fname.exists():
|
211 |
+
print(f'Downloading {url}')
|
212 |
+
download_url(f'{url}{ext}', fname)
|
213 |
+
return fname
|
214 |
+
|
215 |
+
def _check_file(fname):
|
216 |
+
size = os.path.getsize(fname)
|
217 |
+
with open(fname, "rb") as f:
|
218 |
+
hash_nb = hashlib.md5(f.read(2**20)).hexdigest()
|
219 |
+
return size,hash_nb
|
220 |
+
|
221 |
+
def untar_data(url:str, fname:PathOrStr=None, dest:PathOrStr=None, data=True, force_download=False) -> Path:
|
222 |
+
"Download `url` to `fname` if `dest` doesn't exist, and un-tgz to folder `dest`."
|
223 |
+
dest = url2path(url, data) if dest is None else Path(dest)/url2name(url)
|
224 |
+
fname = Path(ifnone(fname, _url2tgz(url, data)))
|
225 |
+
if force_download or (fname.exists() and url in _checks and _check_file(fname) != _checks[url]):
|
226 |
+
print(f"A new version of the {'dataset' if data else 'model'} is available.")
|
227 |
+
if fname.exists(): os.remove(fname)
|
228 |
+
if dest.exists(): shutil.rmtree(dest)
|
229 |
+
if not dest.exists():
|
230 |
+
fname = download_data(url, fname=fname, data=data)
|
231 |
+
if url in _checks:
|
232 |
+
assert _check_file(fname) == _checks[url], f"Downloaded file {fname} does not match checksum expected! Remove that file from {Config().data_archive_path()} and try your code again."
|
233 |
+
tarfile.open(fname, 'r:gz').extractall(dest.parent)
|
234 |
+
return dest
|
DeOldify/fastai/distributed.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .torch_core import *
|
2 |
+
from .basic_train import Learner,LearnerCallback
|
3 |
+
from torch.nn.parallel import DistributedDataParallel, DataParallel
|
4 |
+
from torch.utils.data.distributed import DistributedSampler
|
5 |
+
|
6 |
+
from fastai.text import TextLMDataBunch
|
7 |
+
|
8 |
+
__all__ = ['DistributedRecorder', 'DistributedTrainer', 'read_metrics', 'setup_distrib']
|
9 |
+
|
10 |
+
def rnn_reset(self):
|
11 |
+
if hasattr(self.module, 'reset'): self.module.reset()
|
12 |
+
DistributedDataParallel.reset = rnn_reset
|
13 |
+
|
14 |
+
class ParallelTrainer(LearnerCallback):
|
15 |
+
_order = -20
|
16 |
+
def on_train_begin(self, **kwargs): self.learn.model = DataParallel(self.learn.model)
|
17 |
+
def on_train_end (self, **kwargs): self.learn.model = self.learn.model.module
|
18 |
+
|
19 |
+
class DistributedTrainer(LearnerCallback):
|
20 |
+
_order = -20 # Needs to run before the recorder
|
21 |
+
def __init__(self, learn:Learner, cuda_id:int=0):
|
22 |
+
super().__init__(learn)
|
23 |
+
self.cuda_id,self.train_sampler = cuda_id,None
|
24 |
+
|
25 |
+
def _change_dl(self, dl, shuffle):
|
26 |
+
old_dl = dl
|
27 |
+
sampler = OurDistributedSampler(dl.dataset, shuffle=shuffle)
|
28 |
+
new_dl = dl.new(shuffle=False, sampler=sampler)
|
29 |
+
return old_dl,new_dl,sampler
|
30 |
+
|
31 |
+
def on_train_begin(self, **kwargs):
|
32 |
+
self.learn.model = DistributedDataParallel(self.model, device_ids=[self.cuda_id], output_device=self.cuda_id)
|
33 |
+
shuffle = self.data.train_dl.init_kwargs['shuffle'] if hasattr(self.data.train_dl, 'init_kwargs') else True
|
34 |
+
self.old_train_dl,self.data.train_dl,self.train_sampler = self._change_dl(self.data.train_dl, shuffle)
|
35 |
+
if hasattr(self.data, 'valid_dl') and self.data.valid_dl is not None:
|
36 |
+
self.old_valid_dl,self.data.valid_dl,self.valid_sampler = self._change_dl(self.data.valid_dl, shuffle)
|
37 |
+
self.rank = rank_distrib()
|
38 |
+
self.recorder.silent = (self.rank != 0)
|
39 |
+
|
40 |
+
def on_epoch_begin(self, epoch, **kwargs): self.train_sampler.set_epoch(epoch)
|
41 |
+
|
42 |
+
def on_train_end(self, **kwargs):
|
43 |
+
self.learn.model = self.learn.model.module
|
44 |
+
self.learn.data.train_dl = self.old_train_dl
|
45 |
+
if hasattr(self.learn.data, 'valid_dl') and self.learn.data.valid_dl is not None:
|
46 |
+
self.learn.data.valid_dl = self.old_valid_dl
|
47 |
+
|
48 |
+
class DistributedRecorder(LearnerCallback):
|
49 |
+
def __init__(self, learn:Learner, cuda_id:int=0, cache_dir:PathOrStr='tmp'):
|
50 |
+
super().__init__(learn)
|
51 |
+
self.cuda_id,self.cache_dir = cuda_id,cache_dir
|
52 |
+
|
53 |
+
def on_train_begin(self, **kwargs):
|
54 |
+
os.makedirs(self.learn.path/self.cache_dir, exist_ok=True)
|
55 |
+
|
56 |
+
def on_epoch_end(self, **kwargs): self.save_stats()
|
57 |
+
def on_train_end(self, **kwargs): self.save_stats()
|
58 |
+
|
59 |
+
def save_stats(self):
|
60 |
+
cache_path,recorder = self.learn.path/self.cache_dir,self.learn.recorder
|
61 |
+
np.save(cache_path/f'losses_{self.cuda_id}', np.array(recorder.losses))
|
62 |
+
stats = np.array([[v] + m for v,m in zip(recorder.val_losses,recorder.metrics)])
|
63 |
+
np.save(cache_path/f'metrics_{self.cuda_id}', stats)
|
64 |
+
|
65 |
+
def _learner_parallel(learn:Learner):
|
66 |
+
"Use nn.DataParallel when training and remove when done"
|
67 |
+
if not torch.cuda.is_available(): warnings.warn('CUDA is not available, check your drivers - training will continue on CPU', ResourceWarning)
|
68 |
+
learn.callbacks.append(ParallelTrainer(learn))
|
69 |
+
return learn
|
70 |
+
|
71 |
+
def _learner_distributed(learn:Learner, cuda_id:int, cache_dir:PathOrStr='tmp'):
|
72 |
+
"Put `learn` on distributed training with `cuda_id`."
|
73 |
+
learn.callbacks.append(DistributedTrainer(learn, cuda_id))
|
74 |
+
learn.callbacks.append(DistributedRecorder(learn, cuda_id, cache_dir))
|
75 |
+
return learn
|
76 |
+
|
77 |
+
Learner.to_distributed = _learner_distributed
|
78 |
+
Learner.to_parallel = _learner_parallel
|
79 |
+
|
80 |
+
def read_metrics(cache_path:PathOrStr, n_gpus:int, reduce:bool=True):
|
81 |
+
losses,metrics = [],[]
|
82 |
+
for i in range(n_gpus):
|
83 |
+
losses.append(np.load(cache_path/f'losses_{i}.npy')[None])
|
84 |
+
metrics.append(np.load(cache_path/f'metrics_{i}.npy')[None])
|
85 |
+
if reduce:
|
86 |
+
losses,metrics = np.concatenate(losses,0),np.concatenate(metrics,0)
|
87 |
+
return losses.mean(0),metrics.mean(0)
|
88 |
+
return losses,metrics
|
89 |
+
|
90 |
+
def setup_distrib(gpu:Any=None):
|
91 |
+
if gpu is None: return gpu
|
92 |
+
gpu = int(gpu)
|
93 |
+
torch.cuda.set_device(int(gpu))
|
94 |
+
if num_distrib() > 1:
|
95 |
+
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
96 |
+
return gpu
|
97 |
+
|
98 |
+
class OurDistributedSampler(DistributedSampler):
|
99 |
+
"A sampler for language models with the option to not shuffle."
|
100 |
+
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
|
101 |
+
super().__init__(dataset, num_replicas=num_replicas, rank=rank)
|
102 |
+
self.shuffle = shuffle
|
103 |
+
|
104 |
+
def __iter__(self):
|
105 |
+
if self.shuffle:
|
106 |
+
g = torch.Generator()
|
107 |
+
g.manual_seed(self.epoch)
|
108 |
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
109 |
+
else: indices = torch.arange(len(self.dataset)).tolist()
|
110 |
+
|
111 |
+
# add extra samples to make it evenly divisible
|
112 |
+
indices += indices[:(self.total_size - len(indices))]
|
113 |
+
assert len(indices) == self.total_size
|
114 |
+
|
115 |
+
# subsample
|
116 |
+
indices = indices[self.rank:self.total_size:self.num_replicas]
|
117 |
+
assert len(indices) == self.num_samples
|
118 |
+
|
119 |
+
return iter(indices)
|
DeOldify/fastai/gen_doc/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import gen_notebooks, nbdoc, core, doctest, nbtest
|
DeOldify/fastai/gen_doc/autogen.tpl
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!--
|
2 |
+
|
3 |
+
|
4 |
+
#################################################
|
5 |
+
### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ###
|
6 |
+
#################################################
|
7 |
+
# file to edit: {{ resources.nb_path }}
|
8 |
+
# instructions: https://docs.fast.ai/gen_doc_main.html
|
9 |
+
|
10 |
+
-->
|
DeOldify/fastai/gen_doc/convert2html.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path, re, nbformat, jupyter_contrib_nbextensions
|
2 |
+
from nbconvert.preprocessors import Preprocessor
|
3 |
+
from nbconvert import HTMLExporter
|
4 |
+
from traitlets.config import Config
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
__all__ = ['read_nb', 'convert_nb', 'convert_all']
|
8 |
+
|
9 |
+
exporter = HTMLExporter(Config())
|
10 |
+
exporter.exclude_input_prompt=True
|
11 |
+
exporter.exclude_output_prompt=True
|
12 |
+
#Loads the template to deal with hidden cells.
|
13 |
+
exporter.template_file = 'jekyll.tpl'
|
14 |
+
path = Path(__file__).parent
|
15 |
+
exporter.template_path.append(str(path))
|
16 |
+
|
17 |
+
def read_nb(fname):
|
18 |
+
"Read the notebook in `fname`."
|
19 |
+
with open(fname,'r') as f: return nbformat.reads(f.read(), as_version=4)
|
20 |
+
|
21 |
+
def convert_nb(fname, dest_path='.'):
|
22 |
+
"Convert a notebook `fname` to html file in `dest_path`."
|
23 |
+
from .gen_notebooks import remove_undoc_cells, remove_code_cell_jupyter_widget_state_elem
|
24 |
+
nb = read_nb(fname)
|
25 |
+
nb['cells'] = remove_undoc_cells(nb['cells'])
|
26 |
+
nb['cells'] = remove_code_cell_jupyter_widget_state_elem(nb['cells'])
|
27 |
+
fname = Path(fname).absolute()
|
28 |
+
dest_name = fname.with_suffix('.html').name
|
29 |
+
meta = nb['metadata']
|
30 |
+
meta_jekyll = meta['jekyll'] if 'jekyll' in meta else {'title': fname.with_suffix('').name}
|
31 |
+
meta_jekyll['nb_path'] = f'{fname.parent.name}/{fname.name}'
|
32 |
+
with open(f'{dest_path}/{dest_name}','w') as f:
|
33 |
+
f.write(exporter.from_notebook_node(nb, resources=meta_jekyll)[0])
|
34 |
+
|
35 |
+
def convert_all(folder, dest_path='.', force_all=False):
|
36 |
+
"Convert modified notebooks in `folder` to html pages in `dest_path`."
|
37 |
+
path = Path(folder)
|
38 |
+
|
39 |
+
changed_cnt = 0
|
40 |
+
for fname in path.glob("*.ipynb"):
|
41 |
+
# only rebuild modified files
|
42 |
+
fname_out = Path(dest_path)/fname.with_suffix('.html').name
|
43 |
+
if not force_all and fname_out.exists():
|
44 |
+
in_mod = os.path.getmtime(fname)
|
45 |
+
out_mod = os.path.getmtime(fname_out)
|
46 |
+
if in_mod < out_mod: continue
|
47 |
+
|
48 |
+
print(f"converting: {fname} => {fname_out}")
|
49 |
+
changed_cnt += 1
|
50 |
+
convert_nb(fname, dest_path=dest_path)
|
51 |
+
if not changed_cnt: print("No notebooks were modified")
|
DeOldify/fastai/gen_doc/core.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..core import *
|
2 |
+
import re
|
3 |
+
|
4 |
+
def strip_fastai(s): return re.sub(r'^fastai\.', '', s)
|
5 |
+
|
DeOldify/fastai/gen_doc/docstrings.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/openstack/rally/blob/master/rally/common/plugin/info.py
|
2 |
+
# Copyright 2015: Mirantis Inc.
|
3 |
+
# All Rights Reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
6 |
+
# not use this file except in compliance with the License. You may obtain
|
7 |
+
# a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
13 |
+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
14 |
+
# License for the specific language governing permissions and limitations
|
15 |
+
# under the License.
|
16 |
+
|
17 |
+
import re
|
18 |
+
import sys
|
19 |
+
|
20 |
+
__all__ = ['parse_docstring']
|
21 |
+
|
22 |
+
|
23 |
+
FIELDS = 'param|val' # supported fields
|
24 |
+
PARAM_OR_RETURN_REGEX = re.compile(f":(?:{FIELDS}|return)")
|
25 |
+
RETURN_REGEX = re.compile(":return: (?P<doc>.*)", re.S)
|
26 |
+
NEW_REGEX = re.compile(f":(?P<field>{FIELDS}) (?P<name>[\*\w]+): (?P<doc>.*?)"
|
27 |
+
f"(?:(?=:(?:{FIELDS}|return|raises))|\Z)", re.S)
|
28 |
+
|
29 |
+
def trim(docstring):
|
30 |
+
"""trim function from PEP-257"""
|
31 |
+
if not docstring:
|
32 |
+
return ""
|
33 |
+
# Convert tabs to spaces (following the normal Python rules)
|
34 |
+
# and split into a list of lines:
|
35 |
+
lines = docstring.expandtabs().splitlines()
|
36 |
+
# Determine minimum indentation (first line doesn't count):
|
37 |
+
indent = sys.maxsize
|
38 |
+
for line in lines[1:]:
|
39 |
+
stripped = line.lstrip()
|
40 |
+
if stripped:
|
41 |
+
indent = min(indent, len(line) - len(stripped))
|
42 |
+
# Remove indentation (first line is special):
|
43 |
+
trimmed = [lines[0].strip()]
|
44 |
+
if indent < sys.maxsize:
|
45 |
+
for line in lines[1:]:
|
46 |
+
trimmed.append(line[indent:].rstrip())
|
47 |
+
# Strip off trailing and leading blank lines:
|
48 |
+
while trimmed and not trimmed[-1]:
|
49 |
+
trimmed.pop()
|
50 |
+
while trimmed and not trimmed[0]:
|
51 |
+
trimmed.pop(0)
|
52 |
+
|
53 |
+
# Current code/unittests expects a line return at
|
54 |
+
# end of multiline docstrings
|
55 |
+
# workaround expected behavior from unittests
|
56 |
+
if "\n" in docstring:
|
57 |
+
trimmed.append("")
|
58 |
+
|
59 |
+
# Return a single string:
|
60 |
+
return "\n".join(trimmed)
|
61 |
+
|
62 |
+
|
63 |
+
def reindent(string):
|
64 |
+
return "\n".join(l.strip() for l in string.strip().split("\n"))
|
65 |
+
|
66 |
+
|
67 |
+
def parse_docstring(docstring):
|
68 |
+
"""Parse the docstring into its components.
|
69 |
+
|
70 |
+
:return: a dictionary of form
|
71 |
+
{
|
72 |
+
"short_description": ...,
|
73 |
+
"long_description": ...,
|
74 |
+
"params": [{"name": ..., "doc": ...}, ...],
|
75 |
+
"vals": [{"name": ..., "doc": ...}, ...],
|
76 |
+
"return": ...
|
77 |
+
}
|
78 |
+
"""
|
79 |
+
|
80 |
+
short_description = long_description = return_str = ""
|
81 |
+
args = []
|
82 |
+
|
83 |
+
if docstring:
|
84 |
+
docstring = trim(docstring.lstrip("\n"))
|
85 |
+
|
86 |
+
lines = docstring.split("\n", 1)
|
87 |
+
short_description = lines[0]
|
88 |
+
|
89 |
+
if len(lines) > 1:
|
90 |
+
long_description = lines[1].strip()
|
91 |
+
|
92 |
+
params_return_desc = None
|
93 |
+
|
94 |
+
match = PARAM_OR_RETURN_REGEX.search(long_description)
|
95 |
+
if match:
|
96 |
+
long_desc_end = match.start()
|
97 |
+
params_return_desc = long_description[long_desc_end:].strip()
|
98 |
+
long_description = long_description[:long_desc_end].rstrip()
|
99 |
+
|
100 |
+
if params_return_desc:
|
101 |
+
args = [
|
102 |
+
{"name": name, "doc": trim(doc), "field": field}
|
103 |
+
for field, name, doc in NEW_REGEX.findall(params_return_desc)
|
104 |
+
]
|
105 |
+
match = RETURN_REGEX.search(params_return_desc)
|
106 |
+
if match:
|
107 |
+
return_str = reindent(match.group("doc"))
|
108 |
+
comments = {p['name']: p['doc'] for p in args}
|
109 |
+
return {
|
110 |
+
"short_description": short_description,
|
111 |
+
"long_description": long_description,
|
112 |
+
"args": args,
|
113 |
+
"comments": comments,
|
114 |
+
"return": return_str
|
115 |
+
}
|
116 |
+
|
117 |
+
|
118 |
+
class InfoMixin(object):
|
119 |
+
|
120 |
+
@classmethod
|
121 |
+
def _get_doc(cls):
|
122 |
+
"""Return documentary of class
|
123 |
+
|
124 |
+
By default it returns docstring of class, but it can be overridden
|
125 |
+
for example for cases like merging own docstring with parent
|
126 |
+
"""
|
127 |
+
return cls.__doc__
|
128 |
+
|
129 |
+
@classmethod
|
130 |
+
def get_info(cls):
|
131 |
+
doc = parse_docstring(cls._get_doc())
|
132 |
+
|
133 |
+
return {
|
134 |
+
"name": cls.get_name(),
|
135 |
+
"platform": cls.get_platform(),
|
136 |
+
"module": cls.__module__,
|
137 |
+
"title": doc["short_description"],
|
138 |
+
"description": doc["long_description"],
|
139 |
+
"parameters": doc["params"],
|
140 |
+
"schema": getattr(cls, "CONFIG_SCHEMA", None),
|
141 |
+
"return": doc["return"]
|
142 |
+
}
|
DeOldify/fastai/gen_doc/doctest.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, re, json, pprint
|
2 |
+
from pathlib import Path
|
3 |
+
from collections import defaultdict
|
4 |
+
from inspect import currentframe, getframeinfo, ismodule
|
5 |
+
|
6 |
+
__all__ = ['this_tests']
|
7 |
+
|
8 |
+
DB_NAME = 'test_registry.json'
|
9 |
+
|
10 |
+
def _json_set_default(obj):
|
11 |
+
if isinstance(obj, set): return list(obj)
|
12 |
+
raise TypeError
|
13 |
+
|
14 |
+
class TestRegistry:
|
15 |
+
"Tests register which API they validate using this class."
|
16 |
+
registry = defaultdict(list)
|
17 |
+
this_tests_check = None
|
18 |
+
missing_this_tests = set()
|
19 |
+
|
20 |
+
# logic for checking whether each test calls `this_tests`:
|
21 |
+
# 1. `this_tests_check` is set to True during test's 'setup' stage if it wasn't skipped
|
22 |
+
# 2. if the test is dynamically skipped `this_tests_check` is set to False
|
23 |
+
# 3. `this_tests` sets this flag to False when it's successfully completes
|
24 |
+
# 4. if during the 'teardown' stage `this_tests_check` is still True then we
|
25 |
+
# know that this test needs `this_tests_check`
|
26 |
+
|
27 |
+
@staticmethod
|
28 |
+
def this_tests(*funcs):
|
29 |
+
prev_frame = currentframe().f_back.f_back
|
30 |
+
file_name, lineno, test_name, _, _ = getframeinfo(prev_frame)
|
31 |
+
parent_func_lineno, _ = get_parent_func(lineno, get_lines(file_name))
|
32 |
+
entry = {'file': relative_test_path(file_name), 'test': test_name , 'line': parent_func_lineno}
|
33 |
+
for func in funcs:
|
34 |
+
if func == 'na':
|
35 |
+
# special case when we can't find a function to declare, e.g.
|
36 |
+
# when attributes are tested
|
37 |
+
continue
|
38 |
+
try:
|
39 |
+
func_fq = get_func_fq_name(func)
|
40 |
+
except:
|
41 |
+
raise Exception(f"'{func}' is not a function") from None
|
42 |
+
if re.match(r'fastai\.', func_fq):
|
43 |
+
if entry not in TestRegistry.registry[func_fq]:
|
44 |
+
TestRegistry.registry[func_fq].append(entry)
|
45 |
+
else:
|
46 |
+
raise Exception(f"'{func}' is not in the fastai API") from None
|
47 |
+
TestRegistry.this_tests_check = False
|
48 |
+
|
49 |
+
def this_tests_check_on():
|
50 |
+
TestRegistry.this_tests_check = True
|
51 |
+
|
52 |
+
def this_tests_check_off():
|
53 |
+
TestRegistry.this_tests_check = False
|
54 |
+
|
55 |
+
def this_tests_check_run(file_name, test_name):
|
56 |
+
if TestRegistry.this_tests_check:
|
57 |
+
TestRegistry.missing_this_tests.add(f"{file_name}::{test_name}")
|
58 |
+
|
59 |
+
def registry_save():
|
60 |
+
if TestRegistry.registry:
|
61 |
+
path = Path(__file__).parent.parent.resolve()/DB_NAME
|
62 |
+
if path.exists():
|
63 |
+
#print("\n*** Merging with the existing test registry")
|
64 |
+
with open(path, 'r') as f: old_registry = json.load(f)
|
65 |
+
TestRegistry.registry = merge_registries(old_registry, TestRegistry.registry)
|
66 |
+
#print(f"\n*** Saving test registry @ {path}")
|
67 |
+
with open(path, 'w') as f:
|
68 |
+
json.dump(obj=TestRegistry.registry, fp=f, indent=4, sort_keys=True, default=_json_set_default)
|
69 |
+
|
70 |
+
def missing_this_tests_alert():
|
71 |
+
if TestRegistry.missing_this_tests:
|
72 |
+
tests = '\n '.join(sorted(TestRegistry.missing_this_tests))
|
73 |
+
print(f"""
|
74 |
+
*** Attention ***
|
75 |
+
Please include `this_tests` call in each of the following tests:
|
76 |
+
{tests}
|
77 |
+
For details see: https://docs.fast.ai/dev/test.html#test-registry""")
|
78 |
+
|
79 |
+
# merge_registries helpers
|
80 |
+
# merge dict of lists of dict
|
81 |
+
def a2k(a): return '::'.join([a['file'], a['test']]), a['line']
|
82 |
+
def k2a(k, v): f,t = k.split('::'); return {"file": f, "line": v, "test": t}
|
83 |
+
# merge by key that is a combination of 2 values: test, file
|
84 |
+
def merge_lists(a, b):
|
85 |
+
x = dict(map(a2k, [*a, *b])) # pack + merge
|
86 |
+
return [k2a(k, v) for k,v in x.items()] # unpack
|
87 |
+
def merge_registries(a, b):
|
88 |
+
for i in b: a[i] = merge_lists(a[i], b[i]) if i in a else b[i]
|
89 |
+
return a
|
90 |
+
|
91 |
+
def this_tests(*funcs): TestRegistry.this_tests(*funcs)
|
92 |
+
|
93 |
+
def str2func(name):
|
94 |
+
"Converts 'fastai.foo.bar' into an function 'object' if such exists"
|
95 |
+
if isinstance(name, str): subpaths = name.split('.')
|
96 |
+
else: return None
|
97 |
+
|
98 |
+
module = subpaths.pop(0)
|
99 |
+
if module in sys.modules: obj = sys.modules[module]
|
100 |
+
else: return None
|
101 |
+
|
102 |
+
for subpath in subpaths:
|
103 |
+
obj = getattr(obj, subpath, None)
|
104 |
+
if obj == None: return None
|
105 |
+
return obj
|
106 |
+
|
107 |
+
def get_func_fq_name(func):
|
108 |
+
if ismodule(func): return func.__name__
|
109 |
+
if isinstance(func, str): func = str2func(func)
|
110 |
+
name = None
|
111 |
+
if hasattr(func, '__qualname__'): name = func.__qualname__
|
112 |
+
elif hasattr(func, '__name__'): name = func.__name__
|
113 |
+
elif hasattr(func, '__wrapped__'): return get_func_fq_name(func.__wrapped__)
|
114 |
+
elif hasattr(func, '__class__'): name = func.__class__.__name__
|
115 |
+
else: raise Exception(f"'{func}' is not a func or class")
|
116 |
+
return f'{func.__module__}.{name}'
|
117 |
+
|
118 |
+
def get_parent_func(lineno, lines, ignore_missing=False):
|
119 |
+
"Find any lines where `elt` is called and return the parent test function"
|
120 |
+
for idx,l in enumerate(reversed(lines[:lineno])):
|
121 |
+
if re.match(f'\s*def test', l): return (lineno - idx), l # 1 based index for github
|
122 |
+
if re.match(f'\w+', l): break # top level indent - out of function scope
|
123 |
+
if ignore_missing: return None
|
124 |
+
raise LookupError('Could not find parent function for line:', lineno, lines[:lineno])
|
125 |
+
|
126 |
+
def relative_test_path(test_file:Path)->str:
|
127 |
+
"Path relative to the `fastai` parent directory"
|
128 |
+
test_file = Path(test_file)
|
129 |
+
testdir_idx = list(reversed(test_file.parts)).index('tests')
|
130 |
+
return '/'.join(test_file.parts[-(testdir_idx+1):])
|
131 |
+
|
132 |
+
def get_lines(file):
|
133 |
+
with open(file, 'r') as f: return f.readlines()
|
DeOldify/fastai/gen_doc/gen_notebooks.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"`gen_doc.nbdoc` generates notebook documentation from module functions and links to correct places"
|
2 |
+
import pkgutil, inspect, sys,os, importlib,json,enum,warnings,nbformat,re
|
3 |
+
from IPython.core.display import display, Markdown
|
4 |
+
from nbconvert.preprocessors import ExecutePreprocessor
|
5 |
+
from nbformat.sign import NotebookNotary
|
6 |
+
from pathlib import Path
|
7 |
+
from .core import *
|
8 |
+
from .nbdoc import *
|
9 |
+
|
10 |
+
__all__ = ['create_module_page', 'update_module_page', 'import_mod',
|
11 |
+
'link_nb', 'update_notebooks', 'generate_missing_metadata', 'update_nb_metadata']
|
12 |
+
|
13 |
+
def get_empty_notebook():
|
14 |
+
"Default notbook with the minimum metadata."
|
15 |
+
#TODO: check python version and nbformat
|
16 |
+
return {'metadata': {'kernelspec': {'display_name': 'Python 3',
|
17 |
+
'language': 'python',
|
18 |
+
'name': 'python3'},
|
19 |
+
'language_info': {'codemirror_mode': {'name': 'ipython', 'version': 3},
|
20 |
+
'file_extension': '.py',
|
21 |
+
'mimetype': 'text/x-python',
|
22 |
+
'name': 'python',
|
23 |
+
'nbconvert_exporter': 'python',
|
24 |
+
'pygments_lexer': 'ipython3',
|
25 |
+
'version': '3.6.6'}},
|
26 |
+
'nbformat': 4,
|
27 |
+
'nbformat_minor': 2}
|
28 |
+
|
29 |
+
def get_md_cell(source, metadata=None):
|
30 |
+
"Markdown cell containing `source` with `metadata`."
|
31 |
+
return {'cell_type': 'markdown',
|
32 |
+
'metadata': {} if metadata is None else metadata,
|
33 |
+
'source': source}
|
34 |
+
|
35 |
+
def get_empty_cell(ctype='markdown'):
|
36 |
+
"Empty cell of type `ctype`."
|
37 |
+
return {'cell_type': ctype, 'metadata': {}, 'source': []}
|
38 |
+
|
39 |
+
def get_code_cell(code, hidden=False):
|
40 |
+
"Code cell containing `code` that may be `hidden`."
|
41 |
+
return {'cell_type' : 'code',
|
42 |
+
'execution_count': 0,
|
43 |
+
'metadata' : {'hide_input': hidden, 'trusted':True},
|
44 |
+
'source' : code,
|
45 |
+
'outputs': []}
|
46 |
+
|
47 |
+
def get_doc_cell(func_name):
|
48 |
+
"Code cell with the command to show the doc of `func_name`."
|
49 |
+
code = f"show_doc({func_name})"
|
50 |
+
return get_code_cell(code, True)
|
51 |
+
|
52 |
+
def get_global_vars(mod):
|
53 |
+
"Return globally assigned variables."
|
54 |
+
# https://stackoverflow.com/questions/8820276/docstring-for-variable/31764368#31764368
|
55 |
+
import ast,re
|
56 |
+
with open(mod.__file__, 'r') as f: fstr = f.read()
|
57 |
+
flines = fstr.splitlines()
|
58 |
+
d = {}
|
59 |
+
for node in ast.walk(ast.parse(fstr)):
|
60 |
+
if isinstance(node,ast.Assign) and hasattr(node.targets[0], 'id'):
|
61 |
+
key,lineno = node.targets[0].id,node.targets[0].lineno
|
62 |
+
codestr = flines[lineno]
|
63 |
+
match = re.match(f"^({key})\s*=\s*.*", codestr)
|
64 |
+
if match and match.group(1) != '__all__': # only top level assignment
|
65 |
+
d[key] = f'`{codestr}` {get_source_link(mod, lineno)}'
|
66 |
+
return d
|
67 |
+
|
68 |
+
def write_nb(nb, nb_path, mode='w'):
|
69 |
+
with open(nb_path, mode) as f: f.write(nbformat.writes(nbformat.from_dict(nb), version=4))
|
70 |
+
|
71 |
+
class ExecuteShowDocPreprocessor(ExecutePreprocessor):
|
72 |
+
"An ExecutePreprocessor that only executes show_doc cells"
|
73 |
+
def preprocess_cell(self, cell, resources, index):
|
74 |
+
if 'source' in cell and cell.cell_type == "code":
|
75 |
+
if IMPORT_RE.search(cell['source']) or SHOW_DOC_RE.search(cell['source']):
|
76 |
+
return super().preprocess_cell(cell, resources, index)
|
77 |
+
return cell, resources
|
78 |
+
|
79 |
+
def execute_nb(fname, metadata=None, save=True, show_doc_only=False):
|
80 |
+
"Execute notebook `fname` with `metadata` for preprocessing."
|
81 |
+
# Any module used in the notebook that isn't inside must be in the same directory as this script
|
82 |
+
with open(fname) as f: nb = nbformat.read(f, as_version=4)
|
83 |
+
ep_class = ExecuteShowDocPreprocessor if show_doc_only else ExecutePreprocessor
|
84 |
+
ep = ep_class(timeout=600, kernel_name='python3')
|
85 |
+
metadata = metadata or {}
|
86 |
+
ep.preprocess(nb, metadata)
|
87 |
+
if save:
|
88 |
+
with open(fname, 'wt') as f: nbformat.write(nb, f)
|
89 |
+
NotebookNotary().sign(nb)
|
90 |
+
|
91 |
+
def _symbol_skeleton(name): return [get_doc_cell(name), get_md_cell(f"`{name}`")]
|
92 |
+
|
93 |
+
def create_module_page(mod, dest_path, force=False):
|
94 |
+
"Create the documentation notebook for module `mod_name` in path `dest_path`"
|
95 |
+
nb = get_empty_notebook()
|
96 |
+
mod_name = mod.__name__
|
97 |
+
strip_name = strip_fastai(mod_name)
|
98 |
+
init_cell = [get_md_cell(f'## Title for {strip_name} (use plain english, not module name!)'), get_md_cell('Type an introduction of the package here.')]
|
99 |
+
cells = [get_code_cell(f'from fastai.gen_doc.nbdoc import *\nfrom {mod_name} import * ', True)]
|
100 |
+
|
101 |
+
gvar_map = get_global_vars(mod)
|
102 |
+
if gvar_map: cells.append(get_md_cell('### Global Variable Definitions:'))
|
103 |
+
for name in get_exports(mod):
|
104 |
+
if name in gvar_map: cells.append(get_md_cell(gvar_map[name]))
|
105 |
+
|
106 |
+
for ft_name in get_ft_names(mod, include_inner=True):
|
107 |
+
if not hasattr(mod, ft_name):
|
108 |
+
warnings.warn(f"Module {strip_name} doesn't have a function named {ft_name}.")
|
109 |
+
continue
|
110 |
+
cells += _symbol_skeleton(ft_name)
|
111 |
+
elt = getattr(mod, ft_name)
|
112 |
+
nb['cells'] = init_cell + cells + [get_md_cell(UNDOC_HEADER)]
|
113 |
+
|
114 |
+
doc_path = get_doc_path(mod, dest_path)
|
115 |
+
write_nb(nb, doc_path, 'w' if force else 'x')
|
116 |
+
execute_nb(doc_path)
|
117 |
+
return doc_path
|
118 |
+
|
119 |
+
_default_exclude = ['.ipynb_checkpoints', '__pycache__', '__init__.py', 'imports']
|
120 |
+
|
121 |
+
def get_module_names(path_dir, exclude=None):
|
122 |
+
if exclude is None: exclude = _default_exclude
|
123 |
+
"Search a given `path_dir` and return all the modules contained inside except those in `exclude`"
|
124 |
+
files = sorted(path_dir.glob('*'), key=lambda x: (x.is_dir(), x.name), reverse=True) # directories first
|
125 |
+
res = [f'{path_dir.name}']
|
126 |
+
for f in files:
|
127 |
+
if f.is_dir() and f.name in exclude: continue # exclude directories
|
128 |
+
if any([f.name.endswith(ex) for ex in exclude]): continue # exclude extensions
|
129 |
+
|
130 |
+
if f.suffix == '.py': res.append(f'{path_dir.name}.{f.stem}')
|
131 |
+
elif f.is_dir(): res += [f'{path_dir.name}.{name}' for name in get_module_names(f)]
|
132 |
+
return res
|
133 |
+
|
134 |
+
def read_nb(fname):
|
135 |
+
"Read a notebook in `fname` and return its corresponding json"
|
136 |
+
with open(fname,'r') as f: return nbformat.reads(f.read(), as_version=4)
|
137 |
+
|
138 |
+
SHOW_DOC_RE = re.compile(r"show_doc\(([\w\.]*)")
|
139 |
+
def read_nb_content(cells, mod_name):
|
140 |
+
"Build a dictionary containing the position of the `cells`."
|
141 |
+
doc_fns = {}
|
142 |
+
for i, cell in enumerate(cells):
|
143 |
+
if cell['cell_type'] == 'code':
|
144 |
+
for match in SHOW_DOC_RE.findall(cell['source']):
|
145 |
+
doc_fns[match] = i
|
146 |
+
return doc_fns
|
147 |
+
|
148 |
+
def read_nb_types(cells):
|
149 |
+
doc_fns = {}
|
150 |
+
for i, cell in enumerate(cells):
|
151 |
+
if cell['cell_type'] == 'markdown':
|
152 |
+
match = re.match(r"^(?:<code>|`)?(\w*)\s*=\s*", cell['source'])
|
153 |
+
if match is not None: doc_fns[match.group(1)] = i
|
154 |
+
return doc_fns
|
155 |
+
|
156 |
+
def link_markdown_cells(cells, modules):
|
157 |
+
"Create documentation links for all cells in markdown with backticks."
|
158 |
+
for i, cell in enumerate(cells):
|
159 |
+
if cell['cell_type'] == 'markdown':
|
160 |
+
cell['source'] = link_docstring(modules, cell['source'])
|
161 |
+
|
162 |
+
def get_insert_idx(pos_dict, name):
|
163 |
+
"Return the position to insert a given function doc in a notebook."
|
164 |
+
keys,i = list(pos_dict.keys()),0
|
165 |
+
while i < len(keys) and str.lower(keys[i]) < str.lower(name): i+=1
|
166 |
+
if i == len(keys): return -1
|
167 |
+
else: return pos_dict[keys[i]]
|
168 |
+
|
169 |
+
def update_pos(pos_dict, start_key, nbr=2):
|
170 |
+
"Update the `pos_dict` by moving all positions after `start_key` by `nbr`."
|
171 |
+
for key,idx in pos_dict.items():
|
172 |
+
if str.lower(key) >= str.lower(start_key): pos_dict[key] += nbr
|
173 |
+
return pos_dict
|
174 |
+
|
175 |
+
def insert_cells(cells, pos_dict, ft_name, append=False):
|
176 |
+
"Insert the function doc `cells` at their correct position and updates `pos_dict`."
|
177 |
+
idx = get_insert_idx(pos_dict, ft_name)
|
178 |
+
if append or idx == -1: cells += [get_doc_cell(ft_name), get_empty_cell()]
|
179 |
+
else:
|
180 |
+
cells.insert(idx, get_doc_cell(ft_name))
|
181 |
+
cells.insert(idx+1, get_empty_cell())
|
182 |
+
pos_dict = update_pos(pos_dict, ft_name, 2)
|
183 |
+
return cells, pos_dict
|
184 |
+
|
185 |
+
def get_doc_path(mod, dest_path):
|
186 |
+
strip_name = strip_fastai(mod.__name__)
|
187 |
+
return os.path.join(dest_path,f'{strip_name}.ipynb')
|
188 |
+
|
189 |
+
def generate_missing_metadata(dest_file):
|
190 |
+
fn = Path(dest_file)
|
191 |
+
meta_fn = fn.parent/'jekyll_metadata.ipynb'
|
192 |
+
if not fn.exists() or not meta_fn.exists(): return print('Could not find notebooks:', fn, meta_fn)
|
193 |
+
metadata_nb = read_nb(meta_fn)
|
194 |
+
|
195 |
+
if has_metadata_cell(metadata_nb['cells'], fn.name): return
|
196 |
+
nb = read_nb(fn)
|
197 |
+
jmd = nb['metadata'].get('jekyll', {})
|
198 |
+
fmt_params = ''
|
199 |
+
for k,v in jmd.items(): fmt_params += f',\n {k}={stringify(v)}'
|
200 |
+
metadata_cell = get_code_cell(f"update_nb_metadata('{Path(fn).name}'{fmt_params})", hidden=False)
|
201 |
+
metadata_nb['cells'].append(metadata_cell)
|
202 |
+
write_nb(metadata_nb, meta_fn)
|
203 |
+
|
204 |
+
def update_nb_metadata(nb_path=None, title=None, summary=None, keywords='fastai', overwrite=True, **kwargs):
|
205 |
+
"Creates jekyll metadata for given notebook path."
|
206 |
+
nb = read_nb(nb_path)
|
207 |
+
data = {'title': title, 'summary': summary, 'keywords': keywords, **kwargs}
|
208 |
+
data = {k:v for (k,v) in data.items() if v is not None} # remove none values
|
209 |
+
if not data: return
|
210 |
+
nb['metadata']['jekyll'] = data
|
211 |
+
write_nb(nb, nb_path)
|
212 |
+
NotebookNotary().sign(nb)
|
213 |
+
|
214 |
+
def has_metadata_cell(cells, fn):
|
215 |
+
for c in cells:
|
216 |
+
if re.search(f"update_nb_metadata\('{fn}'", c['source']): return c
|
217 |
+
|
218 |
+
def stringify(s): return f'\'{s}\'' if isinstance(s, str) else s
|
219 |
+
|
220 |
+
IMPORT_RE = re.compile(r"from (fastai[\.\w_]*)")
|
221 |
+
def get_imported_modules(cells, nb_module_name=''):
|
222 |
+
"Finds all submodules of notebook - sorted by submodules > top level modules > manual imports. This gives notebook imports priority"
|
223 |
+
module_names = get_top_level_modules()
|
224 |
+
nb_imports = [match.group(1) for cell in cells for match in IMPORT_RE.finditer(cell['source']) if cell['cell_type'] == 'code']
|
225 |
+
parts = nb_module_name.split('.')
|
226 |
+
parent_modules = ['.'.join(parts[:(x+1)]) for x in range_of(parts)] # Imports parent modules - a.b.c = [a, a.b, a.b.c]
|
227 |
+
all_modules = module_names + nb_imports + parent_modules
|
228 |
+
mods = [import_mod(m, ignore_errors=True) for m in all_modules]
|
229 |
+
return [m for m in mods if m is not None]
|
230 |
+
|
231 |
+
def get_top_level_modules(num_levels=1):
|
232 |
+
mod_dir = Path(import_mod('fastai').__file__).parent
|
233 |
+
filtered_n = filter(lambda x: x.count('.')<=num_levels, get_module_names(mod_dir))
|
234 |
+
return sorted(filtered_n, key=lambda s: s.count('.'), reverse=True) # Submodules first (sorted by periods)
|
235 |
+
|
236 |
+
NEW_FT_HEADER = '## New Methods - Please document or move to the undocumented section'
|
237 |
+
UNDOC_HEADER = '## Undocumented Methods - Methods moved below this line will intentionally be hidden'
|
238 |
+
def parse_sections(cells):
|
239 |
+
old_cells, undoc_cells, new_cells = [], [], []
|
240 |
+
current_section = old_cells
|
241 |
+
for cell in cells:
|
242 |
+
if cell['cell_type'] == 'markdown':
|
243 |
+
if re.match(UNDOC_HEADER, cell['source']): current_section = undoc_cells
|
244 |
+
if re.match(NEW_FT_HEADER, cell['source']): current_section = new_cells
|
245 |
+
current_section.append(cell)
|
246 |
+
undoc_cells = undoc_cells or [get_md_cell(UNDOC_HEADER)]
|
247 |
+
new_cells = new_cells or [get_md_cell(NEW_FT_HEADER)]
|
248 |
+
return old_cells, undoc_cells, new_cells
|
249 |
+
|
250 |
+
def remove_undoc_cells(cells):
|
251 |
+
old, _, _ = parse_sections(cells)
|
252 |
+
return old
|
253 |
+
|
254 |
+
# currently code vbox sub-cells mainly
|
255 |
+
def remove_code_cell_jupyter_widget_state_elem(cells):
|
256 |
+
for c in cells:
|
257 |
+
if c['cell_type'] == 'code':
|
258 |
+
if 'outputs' in c:
|
259 |
+
c['outputs'] = [l for l in c['outputs'] if not ('data' in l and 'application/vnd.jupyter.widget-view+json' in l.data)]
|
260 |
+
return cells
|
261 |
+
|
262 |
+
def update_module_page(mod, dest_path='.'):
|
263 |
+
"Update the documentation notebook of a given module."
|
264 |
+
doc_path = get_doc_path(mod, dest_path)
|
265 |
+
strip_name = strip_fastai(mod.__name__)
|
266 |
+
nb = read_nb(doc_path)
|
267 |
+
cells = nb['cells']
|
268 |
+
|
269 |
+
link_markdown_cells(cells, get_imported_modules(cells, mod.__name__))
|
270 |
+
|
271 |
+
type_dict = read_nb_types(cells)
|
272 |
+
gvar_map = get_global_vars(mod)
|
273 |
+
for name in get_exports(mod):
|
274 |
+
if name not in gvar_map: continue
|
275 |
+
code = gvar_map[name]
|
276 |
+
if name in type_dict: cells[type_dict[name]] = get_md_cell(code)
|
277 |
+
else: cells.append(get_md_cell(code))
|
278 |
+
|
279 |
+
pos_dict = read_nb_content(cells, strip_name)
|
280 |
+
ft_names = get_ft_names(mod, include_inner=True)
|
281 |
+
new_fts = list(set(ft_names) - set(pos_dict.keys()))
|
282 |
+
if new_fts: print(f'Found new fuctions for {mod}. Please document:\n{new_fts}')
|
283 |
+
existing, undoc_cells, new_cells = parse_sections(cells)
|
284 |
+
for ft_name in new_fts: new_cells.extend([get_doc_cell(ft_name), get_empty_cell()])
|
285 |
+
if len(new_cells) > 1: nb['cells'] = existing + undoc_cells + new_cells
|
286 |
+
|
287 |
+
write_nb(nb, doc_path)
|
288 |
+
return doc_path
|
289 |
+
|
290 |
+
def link_nb(nb_path):
|
291 |
+
nb = read_nb(nb_path)
|
292 |
+
cells = nb['cells']
|
293 |
+
link_markdown_cells(cells, get_imported_modules(cells, Path(nb_path).stem))
|
294 |
+
write_nb(nb, nb_path)
|
295 |
+
NotebookNotary().sign(read_nb(nb_path))
|
296 |
+
|
297 |
+
def get_module_from_notebook(doc_path):
|
298 |
+
"Find module given a source path. Assume it belongs to fastai directory"
|
299 |
+
return f'fastai.{Path(doc_path).stem}'
|
300 |
+
|
301 |
+
def check_nbconvert_version():
|
302 |
+
import nbconvert
|
303 |
+
assert nbconvert.version_info >= (5,4,0), "Please update nbconvert to >=5.4 for consistent .html output"
|
304 |
+
|
305 |
+
def update_notebooks(source_path, dest_path=None, update_html=True, document_new_fns=False,
|
306 |
+
update_nb_links=True, html_path=None, force=False):
|
307 |
+
"`source_path` can be a directory or a file. Assume all modules reside in the fastai directory."
|
308 |
+
from .convert2html import convert_nb
|
309 |
+
source_path = Path(source_path)
|
310 |
+
|
311 |
+
if source_path.is_file():
|
312 |
+
dest_path = source_path.parent if dest_path is None else Path(dest_path)
|
313 |
+
html_path = dest_path/'..'/'docs' if html_path is None else Path(html_path)
|
314 |
+
doc_path = source_path
|
315 |
+
assert source_path.suffix == '.ipynb', 'Must update from notebook or module'
|
316 |
+
if document_new_fns:
|
317 |
+
mod = import_mod(get_module_from_notebook(source_path))
|
318 |
+
if not mod: print('Could not find module for path:', source_path)
|
319 |
+
elif mod.__file__.endswith('__init__.py'): pass
|
320 |
+
else: update_module_page(mod, dest_path)
|
321 |
+
generate_missing_metadata(doc_path)
|
322 |
+
if update_nb_links:
|
323 |
+
print(f'Updating notebook {doc_path}. Please wait...')
|
324 |
+
link_nb(doc_path)
|
325 |
+
execute_nb(doc_path, {'metadata': {'path': doc_path.parent}}, show_doc_only=True)
|
326 |
+
if update_html:
|
327 |
+
check_nbconvert_version()
|
328 |
+
html_fn = html_path/doc_path.with_suffix('.html').name
|
329 |
+
if not force and html_fn.is_file():
|
330 |
+
in_mod = os.path.getmtime(doc_path)
|
331 |
+
out_mod = os.path.getmtime(html_fn)
|
332 |
+
if in_mod < out_mod: return
|
333 |
+
convert_nb(doc_path, html_path)
|
334 |
+
|
335 |
+
elif (source_path.name.startswith('fastai.')):
|
336 |
+
# Do module update
|
337 |
+
assert dest_path is not None, 'To update a module, you must specify a destination folder for where notebook resides'
|
338 |
+
mod = import_mod(source_path.name)
|
339 |
+
if not mod: return print('Could not find module for:', source_path)
|
340 |
+
doc_path = Path(dest_path)/(strip_fastai(mod.__name__)+'.ipynb')
|
341 |
+
if not doc_path.exists():
|
342 |
+
print('Notebook does not exist. Creating:', doc_path)
|
343 |
+
create_module_page(mod, dest_path)
|
344 |
+
update_notebooks(doc_path, dest_path=dest_path, update_html=update_html, document_new_fns=document_new_fns,
|
345 |
+
update_nb_links=update_nb_links, html_path=html_path)
|
346 |
+
elif source_path.is_dir():
|
347 |
+
for f in sorted(Path(source_path).glob('*.ipynb')):
|
348 |
+
update_notebooks(f, dest_path=dest_path, update_html=update_html, document_new_fns=document_new_fns,
|
349 |
+
update_nb_links=update_nb_links, html_path=html_path)
|
350 |
+
else: print('Could not resolve source file:', source_path)
|