Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/samplelib/SampleGeneratorFaceCelebAMaskHQ.py
628 views
1
import multiprocessing
2
import pickle
3
import time
4
import traceback
5
from enum import IntEnum
6
7
import cv2
8
import numpy as np
9
10
from core import imagelib, mplib, pathex
11
from core.cv2ex import *
12
from core.interact import interact as io
13
from core.joblib import SubprocessGenerator, ThisThreadGenerator
14
from facelib import LandmarksProcessor
15
from samplelib import SampleGeneratorBase
16
17
18
class MaskType(IntEnum):
    """Integer ids of the per-part annotation masks used by this module.

    The member names are the suffixes looked up when a mask file stem is
    parsed as ``<file id>_<member name>``.
    """
    none   = 0
    cloth  = 1
    ear_r  = 2
    eye_g  = 3
    hair   = 4
    hat    = 5
    l_brow = 6
    l_ear  = 7
    l_eye  = 8
    l_lip  = 9
    mouth  = 10
    neck   = 11
    neck_l = 12
    nose   = 13
    r_brow = 14
    r_ear  = 15
    r_eye  = 16
    skin   = 17
    u_lip  = 18
38
39
40
41
# Plain-int mask id -> mask name, in MaskType declaration order.
# The names are exactly the enum member names ('none', 'cloth', ..., 'u_lip').
MaskType_to_name = { int(mask_type) : mask_type.name for mask_type in MaskType }

# Reverse lookup: mask name -> plain-int mask id.
MaskType_from_name = { name : type_id for type_id, name in MaskType_to_name.items() }
64
65
class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
    """Endless generator of augmented (image, skin mask) batches.

    Expected layout under ``root_path``::

        CelebAMask-HQ/
            CelebA-HQ-img/<file id>.jpg
            CelebAMask-HQ-mask-anno/**/<file id>_<mask name>.<ext>

    where ``<mask name>`` is a key of ``MaskType_from_name``. The target mask
    of a sample is its ``skin`` mask with the ``hair`` and ``hat`` masks
    (when present) cut out, binarised at 0.5 after augmentation.
    """

    def __init__ (self, root_path, debug=False, batch_size=1, resolution=256,
                        generators_count=4, data_format="NHWC",
                        **kwargs):
        """Index the dataset and start the batch generators.

        Parameters:
            root_path: path object (must support ``/`` and ``.exists()``)
                of the directory containing ``CelebAMask-HQ``.
            debug: when true a single in-thread generator is used instead of
                subprocess generators.
            batch_size: number of samples per yielded batch.
            resolution: side length in pixels of the produced square images
                and masks.
            generators_count: number of subprocess generators (at least 1 is
                always used; ignored in debug mode).
            data_format: ``"NCHW"`` produces channel-first samples; any other
                value leaves them as height, width, channels.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: a dataset directory is missing, or no usable
                training sample was found.
        """
        super().__init__(debug, batch_size)
        self.initialized = False

        # batch_func reads this, so it must be set before any generator is created.
        self.resolution = resolution

        dataset_path = root_path / 'CelebAMask-HQ'
        images_path = dataset_path / 'CelebA-HQ-img'
        masks_path = dataset_path / 'CelebAMask-HQ-mask-anno'
        for required_path in (dataset_path, images_path, masks_path):
            if not required_path.exists():
                raise ValueError(f'Unable to find {required_path}')

        self.generators_count = 1 if self.debug else max(1, generators_count)

        source_images_paths = pathex.get_image_paths(images_path, return_Path_class=True)
        mask_images_paths = pathex.get_image_paths(masks_path, subdirs=True, return_Path_class=True)

        if len(source_images_paths) == 0 or len(mask_images_paths) == 0:
            raise ValueError('No training data provided.')

        # file id -> { mask type id : mask path relative to masks_path }
        mask_file_id_hash = {}
        for filepath in io.progress_bar_generator(mask_images_paths, "Loading"):
            file_id, mask_type = filepath.stem.split('_', 1)
            masks = mask_file_id_hash.setdefault(int(file_id), {})
            masks[ MaskType_from_name[mask_type] ] = str(filepath.relative_to(masks_path))

        source_file_id_set = { int(filepath.stem) for filepath in source_images_paths }

        # Report masks that have no source image and drop them, otherwise every
        # draw of such an id would fail inside batch_func and only spam the log.
        for file_id in list(mask_file_id_hash):
            if file_id not in source_file_id_set:
                io.log_err (f"Corrupted dataset: {file_id} not in {images_path}")
                del mask_file_id_hash[file_id]

        # Fail here, in the caller's process, rather than letting the
        # generators spin forever without being able to produce a sample.
        if not any(MaskType.skin in masks for masks in mask_file_id_hash.values()):
            raise ValueError('No training data provided.')

        generator_param = (images_path, masks_path, mask_file_id_hash, data_format)
        if self.debug:
            self.generators = [ThisThreadGenerator ( self.batch_func, generator_param )]
        else:
            self.generators = [SubprocessGenerator ( self.batch_func, generator_param, start_now=False )
                               for _ in range(self.generators_count) ]

            SubprocessGenerator.start_in_parallel( self.generators )

        # Incremented before use in __next__, so the first batch comes from generator 0.
        self.generator_counter = -1

        self.initialized = True

    #overridable
    def is_initialized(self):
        """Return True once __init__ has run to completion."""
        return self.initialized

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch, taking the generators in round-robin order."""
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators) ]
        return next(generator)

    @staticmethod
    def _read_mask(mask_path):
        """Read a mask image; return its first channel (H, W, 1) as float32 divided by 255."""
        return cv2_imread(mask_path)[...,0:1].astype(np.float32) / 255.0

    def batch_func(self, param ):
        """Generator body run by every batch generator; never terminates.

        Parameters:
            param: tuple ``(images_path, masks_path, mask_file_id_hash,
                data_format)`` as built in ``__init__``.

        Yields:
            ``[images, masks]``: two numpy arrays stacking ``batch_size``
            augmented images and their binarised masks.

        Raises:
            ValueError: no entry of ``mask_file_id_hash`` has a skin mask.

        A sample that fails to load or process is logged and replaced by
        another one, so a yielded batch always holds ``batch_size`` samples.
        """
        images_path, masks_path, mask_file_id_hash, data_format = param

        # Only samples with a skin mask can produce a target.
        file_ids = [ file_id for file_id, masks in mask_file_id_hash.items()
                     if masks.get(MaskType.skin, None) is not None ]
        if len(file_ids) == 0:
            raise ValueError('No training data provided.')

        shuffle_file_ids = []

        # Fallback keeps batch_func usable on instances that did not go
        # through this class's __init__.
        resolution = getattr(self, 'resolution', 256)
        random_flip = True
        rotation_range=[-15,15]
        scale_range=[-0.10, 0.95]
        tx_range=[-0.3, 0.3]
        ty_range=[-0.3, 0.3]

        # (chance in percent, maximum size reduction in percent of resolution)
        random_bilinear_resize = (25,75)
        # (chance in percent, maximum kernel size)
        motion_blur = (25, 5)
        # (chance in percent, number of odd kernel sizes to pick from)
        gaussian_blur = (25, 5)

        bs = self.batch_size
        while True:
            batch_images = []
            batch_masks = []

            while len(batch_images) < bs:
                try:
                    # Start a new shuffled pass once every id has been served.
                    if len(shuffle_file_ids) == 0:
                        shuffle_file_ids = file_ids.copy()
                        np.random.shuffle(shuffle_file_ids)

                    file_id = shuffle_file_ids.pop()
                    masks = mask_file_id_hash[file_id]
                    image_path = images_path / f'{file_id}.jpg'

                    img = cv2_imread(image_path).astype(np.float32) / 255.0
                    mask = self._read_mask(masks_path / masks[MaskType.skin])

                    # Cut hair and hat out of the skin mask.
                    # The neck mask (MaskType.neck) is currently not merged in.
                    for occluder_type in (MaskType.hair, MaskType.hat):
                        occluder_path = masks.get(occluder_type, None)
                        if occluder_path is None:
                            continue
                        occluder_path = masks_path / occluder_path
                        if occluder_path.exists():
                            mask *= (1 - self._read_mask(occluder_path))

                    # The same parameters are applied to image and mask below
                    # so that both stay aligned.
                    warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range )

                    # interpolation must be passed by keyword: the third
                    # positional argument of cv2.resize is dst.
                    img = cv2.resize( img, (resolution,resolution), interpolation=cv2.INTER_LANCZOS4 )

                    # Random colour jitter in HSV space.
                    h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
                    h = ( h + np.random.randint(360) ) % 360
                    s = np.clip ( s + np.random.random()-0.5, 0, 1 )
                    v = np.clip ( v + np.random.random()/2-0.25, 0, 1 )
                    img = np.clip( cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR) , 0, 1 )

                    if motion_blur is not None:
                        chance, mb_max_size = motion_blur
                        chance = np.clip(chance, 0, 100)

                        mblur_rnd_chance = np.random.randint(100)
                        mblur_rnd_kernel = np.random.randint(mb_max_size)+1
                        mblur_rnd_deg = np.random.randint(360)

                        if mblur_rnd_chance < chance:
                            img = imagelib.LinearMotionBlur (img, mblur_rnd_kernel, mblur_rnd_deg )

                    img = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)

                    if gaussian_blur is not None:
                        chance, kernel_max_size = gaussian_blur
                        chance = np.clip(chance, 0, 100)

                        gblur_rnd_chance = np.random.randint(100)
                        # Always odd, as cv2.GaussianBlur requires.
                        gblur_rnd_kernel = np.random.randint(kernel_max_size)*2+1

                        if gblur_rnd_chance < chance:
                            img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) *2 , 0)

                    if random_bilinear_resize is not None:
                        chance, max_size_per = random_bilinear_resize
                        chance = np.clip(chance, 0, 100)

                        pick_chance = np.random.randint(100)
                        resize_to = resolution - int( np.random.rand()* int(resolution*(max_size_per/100.0)) )

                        # Degrade by shrinking and re-enlarging, only for the
                        # configured share of samples.
                        if pick_chance < chance:
                            img = cv2.resize (img, (resize_to,resize_to), interpolation=cv2.INTER_LINEAR )
                            img = cv2.resize (img, (resolution,resolution), interpolation=cv2.INTER_LINEAR )

                    # cv2.resize returns the single-channel mask as 2D; restore the channel axis.
                    mask = cv2.resize( mask, (resolution,resolution), interpolation=cv2.INTER_LANCZOS4 )[...,None]
                    mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)

                    # Binarise: interpolation and warping leave fractional values.
                    mask[mask < 0.5] = 0.0
                    mask[mask >= 0.5] = 1.0

                    if data_format == "NCHW":
                        img = np.transpose(img, (2,0,1) )
                        mask = np.transpose(mask, (2,0,1) )

                    # Append both only after the whole sample succeeded, so the
                    # two lists never get out of step.
                    batch_images.append ( img )
                    batch_masks.append ( mask )
                except Exception:
                    # Best effort: log the failure and draw another sample.
                    # KeyboardInterrupt and SystemExit are deliberately not caught.
                    io.log_err ( traceback.format_exc() )

            yield [ np.array(batch_images), np.array(batch_masks) ]
270
271