Path: blob/master/samplelib/SampleGeneratorFaceCelebAMaskHQ.py
import multiprocessing
import pickle
import time
import traceback
from enum import IntEnum

import cv2
import numpy as np

from core import imagelib, mplib, pathex
from core.cv2ex import *
from core.interact import interact as io
from core.joblib import SubprocessGenerator, ThisThreadGenerator
from facelib import LandmarksProcessor
from samplelib import SampleGeneratorBase


# Segmentation classes of the CelebAMask-HQ annotation set.
class MaskType(IntEnum):
    none   = 0
    cloth  = 1
    ear_r  = 2
    eye_g  = 3
    hair   = 4
    hat    = 5
    l_brow = 6
    l_ear  = 7
    l_eye  = 8
    l_lip  = 9
    mouth  = 10
    neck   = 11
    neck_l = 12
    nose   = 13
    r_brow = 14
    r_ear  = 15
    r_eye  = 16
    skin   = 17
    u_lip  = 18


MaskType_to_name = {
    int(MaskType.none  ) : 'none',
    int(MaskType.cloth ) : 'cloth',
    int(MaskType.ear_r ) : 'ear_r',
    int(MaskType.eye_g ) : 'eye_g',
    int(MaskType.hair  ) : 'hair',
    int(MaskType.hat   ) : 'hat',
    int(MaskType.l_brow) : 'l_brow',
    int(MaskType.l_ear ) : 'l_ear',
    int(MaskType.l_eye ) : 'l_eye',
    int(MaskType.l_lip ) : 'l_lip',
    int(MaskType.mouth ) : 'mouth',
    int(MaskType.neck  ) : 'neck',
    int(MaskType.neck_l) : 'neck_l',
    int(MaskType.nose  ) : 'nose',
    int(MaskType.r_brow) : 'r_brow',
    int(MaskType.r_ear ) : 'r_ear',
    int(MaskType.r_eye ) : 'r_eye',
    int(MaskType.skin  ) : 'skin',
    int(MaskType.u_lip ) : 'u_lip',
}

MaskType_from_name = { MaskType_to_name[k] : k for k in MaskType_to_name.keys() }


class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
    def __init__(self, root_path, debug=False, batch_size=1, resolution=256,
                       generators_count=4, data_format="NHWC",
                       **kwargs):

        super().__init__(debug, batch_size)
        self.initialized = False

        dataset_path = root_path / 'CelebAMask-HQ'
        if not dataset_path.exists():
            raise ValueError(f'Unable to find {dataset_path}')

        images_path = dataset_path / 'CelebA-HQ-img'
        if not images_path.exists():
            raise ValueError(f'Unable to find {images_path}')

        masks_path = dataset_path / 'CelebAMask-HQ-mask-anno'
        if not masks_path.exists():
            raise ValueError(f'Unable to find {masks_path}')

        if self.debug:
            self.generators_count = 1
        else:
            self.generators_count = max(1, generators_count)

        source_images_paths = pathex.get_image_paths(images_path, return_Path_class=True)
        source_images_paths_len = len(source_images_paths)
        mask_images_paths = pathex.get_image_paths(masks_path, subdirs=True, return_Path_class=True)

        if source_images_paths_len == 0 or len(mask_images_paths) == 0:
            raise ValueError('No training data provided.')

        # Index every annotation file by image id and mask type,
        # e.g. '12345_skin.png' -> mask_file_id_hash[12345][MaskType.skin]
        mask_file_id_hash = {}

        for filepath in io.progress_bar_generator(mask_images_paths, "Loading"):
            stem = filepath.stem

            file_id, mask_type = stem.split('_', 1)
            file_id = int(file_id)

            if file_id not in mask_file_id_hash:
                mask_file_id_hash[file_id] = {}

            mask_file_id_hash[file_id][ MaskType_from_name[mask_type] ] = str(filepath.relative_to(masks_path))

        source_file_id_set = set()

        for filepath in source_images_paths:
            stem = filepath.stem

            file_id = int(stem)
            source_file_id_set.update( {file_id} )

        for k in mask_file_id_hash.keys():
            if k not in source_file_id_set:
                io.log_err(f"Corrupted dataset: {k} not in {images_path}")

        if self.debug:
            self.generators = [ThisThreadGenerator( self.batch_func, (images_path, masks_path, mask_file_id_hash, data_format) )]
        else:
            self.generators = [SubprocessGenerator( self.batch_func, (images_path, masks_path, mask_file_id_hash, data_format), start_now=False )
                               for i in range(self.generators_count) ]

            SubprocessGenerator.start_in_parallel( self.generators )

        self.generator_counter = -1

        self.initialized = True

    #overridable
    def is_initialized(self):
        return self.initialized

    def __iter__(self):
        return self

    def __next__(self):
        self.generator_counter += 1
        generator = self.generators[ self.generator_counter % len(self.generators) ]
        return next(generator)

    def batch_func(self, param):
        images_path, masks_path, mask_file_id_hash, data_format = param

        file_ids = list(mask_file_id_hash.keys())

        shuffle_file_ids = []

        # Augmentation parameters. Note that the resolution is fixed here,
        # regardless of the constructor argument.
        resolution = 256
        random_flip = True
        rotation_range = [-15, 15]
        scale_range = [-0.10, 0.95]
        tx_range = [-0.3, 0.3]
        ty_range = [-0.3, 0.3]

        # (chance in %, strength) for each random degradation
        random_bilinear_resize = (25, 75)
        motion_blur = (25, 5)
        gaussian_blur = (25, 5)

        bs = self.batch_size
        while True:
            batches = None

            n_batch = 0
            while n_batch < bs:
                try:
                    if len(shuffle_file_ids) == 0:
                        shuffle_file_ids = file_ids.copy()
                        np.random.shuffle(shuffle_file_ids)

                    file_id = shuffle_file_ids.pop()
                    masks = mask_file_id_hash[file_id]
                    image_path = images_path / f'{file_id}.jpg'

                    skin_path = masks.get(MaskType.skin, None)
                    hair_path = masks.get(MaskType.hair, None)
                    hat_path  = masks.get(MaskType.hat, None)
                    #neck_path = masks.get(MaskType.neck, None)

                    img  = cv2_imread(image_path).astype(np.float32) / 255.0
                    mask = cv2_imread(masks_path / skin_path)[...,0:1].astype(np.float32) / 255.0

                    # Subtract hair and hat regions from the skin mask.
                    if hair_path is not None:
                        hair_path = masks_path / hair_path
                        if hair_path.exists():
                            hair = cv2_imread(hair_path)[...,0:1].astype(np.float32) / 255.0
                            mask *= (1 - hair)

                    if hat_path is not None:
                        hat_path = masks_path / hat_path
                        if hat_path.exists():
                            hat = cv2_imread(hat_path)[...,0:1].astype(np.float32) / 255.0
                            mask *= (1 - hat)

                    #if neck_path is not None:
                    #    neck_path = masks_path / neck_path
                    #    if neck_path.exists():
                    #        neck = cv2_imread(neck_path)[...,0:1].astype(np.float32) / 255.0
                    #        mask = np.clip(mask + neck, 0, 1)

                    warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range)

                    # Random hue/saturation/value jitter.
                    img = cv2.resize(img, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4)
                    h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
                    h = (h + np.random.randint(360)) % 360
                    s = np.clip(s + np.random.random() - 0.5, 0, 1)
                    v = np.clip(v + np.random.random() / 2 - 0.25, 0, 1)
                    img = np.clip( cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR), 0, 1 )

                    if motion_blur is not None:
                        chance, mb_max_size = motion_blur
                        chance = np.clip(chance, 0, 100)

                        mblur_rnd_chance = np.random.randint(100)
                        mblur_rnd_kernel = np.random.randint(mb_max_size) + 1
                        mblur_rnd_deg    = np.random.randint(360)

                        if mblur_rnd_chance < chance:
                            img = imagelib.LinearMotionBlur(img, mblur_rnd_kernel, mblur_rnd_deg)

                    img = imagelib.warp_by_params(warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)

                    if gaussian_blur is not None:
                        chance, kernel_max_size = gaussian_blur
                        chance = np.clip(chance, 0, 100)

                        gblur_rnd_chance = np.random.randint(100)
                        gblur_rnd_kernel = np.random.randint(kernel_max_size) * 2 + 1

                        if gblur_rnd_chance < chance:
                            img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) * 2, 0)

                    if random_bilinear_resize is not None:
                        chance, max_size_per = random_bilinear_resize
                        chance = np.clip(chance, 0, 100)
                        pick_chance = np.random.randint(100)
                        if pick_chance < chance:
                            # Downscale and upscale again to simulate a low-quality source.
                            resize_to = resolution - int( np.random.rand() * int(resolution * (max_size_per / 100.0)) )
                            img = cv2.resize(img, (resize_to, resize_to), interpolation=cv2.INTER_LINEAR)
                            img = cv2.resize(img, (resolution, resolution), interpolation=cv2.INTER_LINEAR)

                    # Apply the same geometric warp to the mask, then binarize it.
                    mask = cv2.resize(mask, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4)[...,None]
                    mask = imagelib.warp_by_params(warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)
                    mask[mask < 0.5] = 0.0
                    mask[mask >= 0.5] = 1.0
                    mask = np.clip(mask, 0, 1)

                    if data_format == "NCHW":
                        img = np.transpose(img, (2, 0, 1))
                        mask = np.transpose(mask, (2, 0, 1))

                    if batches is None:
                        batches = [ [], [] ]

                    batches[0].append(img)
                    batches[1].append(mask)

                    n_batch += 1
                except:
                    io.log_err( traceback.format_exc() )

            yield [ np.array(batch) for batch in batches ]
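A minimal usage sketch (not part of the file above), assuming a DeepFaceLab checkout on sys.path so the `samplelib` import resolves, and a CelebAMask-HQ dataset unpacked under `root_path / 'CelebAMask-HQ'`; the `/data/datasets` path is hypothetical. Note that `batch_func` hardcodes `resolution = 256`, so the constructor's `resolution` argument does not change the output size.

from pathlib import Path

from samplelib import SampleGeneratorFaceCelebAMaskHQ

# Hypothetical dataset root containing 'CelebAMask-HQ/CelebA-HQ-img' and
# 'CelebAMask-HQ/CelebAMask-HQ-mask-anno'.
gen = SampleGeneratorFaceCelebAMaskHQ(root_path=Path('/data/datasets'),
                                      batch_size=8,
                                      generators_count=4,
                                      data_format="NHWC")

# Each iteration yields [images, masks]:
# images: (8, 256, 256, 3) float32 in [0, 1], masks: (8, 256, 256, 1) float32 in {0, 1}
imgs, masks = next(gen)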