# Path: blob/master/samplelib/SampleGeneratorFaceXSeg.py
# 628 views
import multiprocessing
import pickle
import time
import traceback
from collections import deque
from enum import IntEnum

import cv2
import numpy as np
from pathlib import Path
from core import imagelib, mplib, pathex
from core.imagelib import sd
from core.cv2ex import *
from core.interact import interact as io
from core.joblib import Subprocessor, SubprocessGenerator, ThisThreadGenerator
from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleType)


class SampleGeneratorFaceXSeg(SampleGeneratorBase):
    """Endless batch generator for XSeg (face segmentation) training.

    Every item produced by iterating this object is ``[images, masks]``:
    ``batch_size`` augmented face crops and their binary masks. Samples with
    drawn segmentation polygons are preferred; if there are none, samples
    with a stored XSeg mask are used instead.
    """

    def __init__ (self, paths, debug=False, batch_size=1, resolution=256, face_type=None,
                        generators_count=4, data_format="NHWC",
                        **kwargs):
        """
        Args:
            paths: faceset directories; the samples loaded from each are
                combined with ``sum`` (assumes the loader's return type supports
                ``+`` with an int start value — TODO confirm).
            debug: if True, batches are produced in the calling thread by a
                single generator instead of subprocesses.
            batch_size: number of samples per batch.
            resolution: side length in pixels of the square output images/masks.
            face_type: requested face type; samples of a different face type are
                re-aligned from their landmarks.
            generators_count: number of subprocess generators (at least 1;
                forced to 1 in debug mode).
            data_format: "NHWC" (default) or "NCHW" layout of each sample.
            **kwargs: accepted and ignored.

        Raises:
            Exception: if no sample has segmentation polygons or an XSeg mask.
        """
        super().__init__(debug, batch_size)
        self.initialized = False

        samples = sum([ SampleLoader.load (SampleType.FACE, path) for path in paths ] )

        # Prefer samples with drawn polygons; fall back to stored XSeg masks.
        seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples).run()

        if len(seg_sample_idxs) == 0:
            seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples, count_xseg_mask=True).run()
            if len(seg_sample_idxs) == 0:
                raise Exception("No segmented faces found.")
            else:
                io.log_info(f"Using {len(seg_sample_idxs)} xseg labeled samples.")
        else:
            io.log_info(f"Using {len(seg_sample_idxs)} segmented samples.")

        if self.debug:
            self.generators_count = 1
        else:
            self.generators_count = max(1, generators_count)

        args = (samples, seg_sample_idxs, resolution, face_type, data_format)
        if self.debug:
            self.generators = [ThisThreadGenerator ( self.batch_func, args )]
        else:
            self.generators = [SubprocessGenerator ( self.batch_func, args, start_now=False ) for i in range(self.generators_count) ]

            # Only the SubprocessGenerators were created with start_now=False and
            # need starting; the debug ThisThreadGenerator is not a SubprocessGenerator.
            SubprocessGenerator.start_in_parallel( self.generators )

        # Incremented before use in __next__, so the first batch comes from generator 0.
        self.generator_counter = -1

        self.initialized = True

    #overridable
    def is_initialized(self):
        return self.initialized

    def __iter__(self):
        return self

    def __next__(self):
        # Round-robin over the generators.
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators) ]
        return next(generator)

    def batch_func(self, param ):
        """Infinite generator of augmented ``[images, masks]`` batches.

        Args:
            param: tuple ``(samples, seg_sample_idxs, resolution, face_type,
                data_format)`` as packed by ``__init__``.

        Yields:
            ``[np.array(images), np.array(masks)]``, each holding
            ``self.batch_size`` entries. Masks are binarised to 0/1 and carry a
            single channel axis (last for NHWC, first for NCHW).
        """
        samples, seg_sample_idxs, resolution, face_type, data_format = param

        shuffle_idxs = []
        bg_shuffle_idxs = []

        # Geometric augmentation of the training face.
        random_flip = True
        rotation_range=[-10,10]
        scale_range=[-0.05, 0.05]
        tx_range=[-0.05, 0.05]
        ty_range=[-0.05, 0.05]

        # Degradation augmentation settings passed through to imagelib
        # (the "chance" values are presumably percentages — confirm in imagelib).
        random_bilinear_resize_chance, random_bilinear_resize_max_size_per = 25,75
        sharpen_chance, sharpen_kernel_max_size = 25, 5
        motion_blur_chance, motion_blur_mb_max_size = 25, 5
        gaussian_blur_chance, gaussian_blur_kernel_max_size = 25, 5
        random_jpeg_compress_chance = 25

        def gen_img_mask(sample):
            """Load ``sample``'s image and mask at ``resolution`` x ``resolution``.

            The mask is rasterised from the sample's polygons when present,
            otherwise taken from its stored XSeg mask and thresholded at 0.5.
            Returns ``(img, mask)``; ``mask`` always has a trailing channel axis.

            Raises:
                Exception: if the sample has neither polygons nor an XSeg mask.
            """
            img = sample.load_bgr()
            h,w,c = img.shape

            if sample.seg_ie_polys.has_polys():
                mask = np.zeros ((h,w,1), dtype=np.float32)
                sample.seg_ie_polys.overlay_mask(mask)
            elif sample.has_xseg_mask():
                mask = sample.get_xseg_mask()
                mask[mask < 0.5] = 0.0
                mask[mask >= 0.5] = 1.0
            else:
                raise Exception(f'no mask in sample {sample.filename}')

            if face_type == sample.face_type:
                # Already aligned for the requested face type: only rescale if needed.
                if w != resolution:
                    img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4 )
                    mask = cv2.resize( mask, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4 )
            else:
                # Re-align from landmarks to the requested face type.
                mat = LandmarksProcessor.get_transform_mat (sample.landmarks, resolution, face_type)
                img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 )
                mask = cv2.warpAffine( mask, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 )

            # cv2 resize/warp may return a 2-D array for single-channel input;
            # restore the channel axis.
            if len(mask.shape) == 2:
                mask = mask[...,None]
            return img, mask

        bs = self.batch_size
        while True:
            batches = [ [], [] ]

            n_batch = 0
            while n_batch < bs:
                try:
                    # Draw faces without replacement from a pool that is
                    # reshuffled each time it runs empty.
                    if len(shuffle_idxs) == 0:
                        shuffle_idxs = seg_sample_idxs.copy()
                        np.random.shuffle(shuffle_idxs)
                    sample = samples[shuffle_idxs.pop()]
                    img, mask = gen_img_mask(sample)

                    if np.random.randint(2) == 0:
                        # Half the time, mix in the background of another masked sample.
                        if len(bg_shuffle_idxs) == 0:
                            bg_shuffle_idxs = seg_sample_idxs.copy()
                            np.random.shuffle(bg_shuffle_idxs)
                        bg_sample = samples[bg_shuffle_idxs.pop()]

                        bg_img, bg_mask = gen_img_mask(bg_sample)

                        bg_wp = imagelib.gen_warp_params(resolution, True, rotation_range=[-180,180], scale_range=[-0.10, 0.10], tx_range=[-0.10, 0.10], ty_range=[-0.10, 0.10] )
                        bg_img = imagelib.warp_by_params (bg_wp, bg_img, can_warp=False, can_transform=True, can_flip=True, border_replicate=True)
                        bg_mask = imagelib.warp_by_params (bg_wp, bg_mask, can_warp=False, can_transform=True, can_flip=True, border_replicate=False)
                        # Blank out the other sample's masked region, keeping only its background.
                        bg_img = bg_img*(1-bg_mask)
                        if np.random.randint(2) == 0:
                            bg_img = imagelib.apply_random_hsv_shift(bg_img)
                        else:
                            bg_img = imagelib.apply_random_rgb_levels(bg_img)

                        # c_mask is the union of both masks. Inside it the face
                        # image is kept as-is; outside it the face image is
                        # blended with the other background by a random weight.
                        c_mask = 1.0 - (1-bg_mask) * (1-mask)
                        rnd = 0.15 + np.random.uniform()*0.85
                        img = img*(c_mask) + img*(1-c_mask)*rnd + bg_img*(1-c_mask)*(1-rnd)

                    # The same warp is applied to image and mask so they stay in register.
                    warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range )
                    img = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=True)
                    mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False)

                    img = np.clip(img.astype(np.float32), 0, 1)
                    # Re-binarise the mask after interpolation.
                    mask[mask < 0.5] = 0.0
                    mask[mask >= 0.5] = 1.0
                    mask = np.clip(mask, 0, 1)

                    if np.random.randint(2) == 0:
                        # random face flare
                        krn = np.random.randint( resolution//4, resolution )
                        krn = krn - krn % 2 + 1  # GaussianBlur needs an odd kernel size
                        img = img + cv2.GaussianBlur(img*mask, (krn,krn), 0)

                    if np.random.randint(2) == 0:
                        # random bg flare
                        krn = np.random.randint( resolution//4, resolution )
                        krn = krn - krn % 2 + 1  # GaussianBlur needs an odd kernel size
                        img = img + cv2.GaussianBlur(img*(1-mask), (krn,krn), 0)

                    if np.random.randint(2) == 0:
                        img = imagelib.apply_random_hsv_shift(img, mask=sd.random_circle_faded ([resolution,resolution]))
                    else:
                        img = imagelib.apply_random_rgb_levels(img, mask=sd.random_circle_faded ([resolution,resolution]))

                    if np.random.randint(2) == 0:
                        img = imagelib.apply_random_sharpen( img, sharpen_chance, sharpen_kernel_max_size, mask=sd.random_circle_faded ([resolution,resolution]))
                    else:
                        # NOTE(review): indentation was lost in the scraped source;
                        # the gaussian blur is assumed to be part of the blur branch — confirm.
                        img = imagelib.apply_random_motion_blur( img, motion_blur_chance, motion_blur_mb_max_size, mask=sd.random_circle_faded ([resolution,resolution]))
                        img = imagelib.apply_random_gaussian_blur( img, gaussian_blur_chance, gaussian_blur_kernel_max_size, mask=sd.random_circle_faded ([resolution,resolution]))

                    if np.random.randint(2) == 0:
                        img = imagelib.apply_random_nearest_resize( img, random_bilinear_resize_chance, random_bilinear_resize_max_size_per, mask=sd.random_circle_faded ([resolution,resolution]))
                    else:
                        img = imagelib.apply_random_bilinear_resize( img, random_bilinear_resize_chance, random_bilinear_resize_max_size_per, mask=sd.random_circle_faded ([resolution,resolution]))
                    # The flares above can push values above 1.
                    img = np.clip(img, 0, 1)

                    img = imagelib.apply_random_jpeg_compress( img, random_jpeg_compress_chance, mask=sd.random_circle_faded ([resolution,resolution]))

                    if data_format == "NCHW":
                        img = np.transpose(img, (2,0,1) )
                        mask = np.transpose(mask, (2,0,1) )

                    batches[0].append ( img )
                    batches[1].append ( mask )

                    n_batch += 1
                except Exception:
                    # A broken sample is logged and skipped; the batch is filled
                    # from the next sample. KeyboardInterrupt/SystemExit still propagate.
                    io.log_err ( traceback.format_exc() )

            yield [ np.array(batch) for batch in batches]


class SegmentedSampleFilterSubprocessor(Subprocessor):
    """Find, across worker processes, the indices of samples that have segmentation data.

    With ``count_xseg_mask=False`` a sample qualifies when
    ``seg_ie_polys.get_pts_count() != 0``; with ``count_xseg_mask=True`` it
    qualifies when ``has_xseg_mask()`` is true. Qualifying indices are
    collected in arrival order (not necessarily sample order) and returned by
    ``get_result()``.
    """

    #override
    def __init__(self, samples, count_xseg_mask=False ):
        self.samples = samples
        self.samples_len = len(self.samples)
        self.count_xseg_mask = count_xseg_mask

        # Work queue of sample indices. A deque gives O(1) pops and
        # re-insertions at the front, where a list would be O(n).
        self.idxs = deque(range(self.samples_len))
        self.result = []
        super().__init__('SegmentedSampleFilterSubprocessor', SegmentedSampleFilterSubprocessor.Cli, 60)

    #override
    def process_info_generator(self):
        # One client per CPU; each receives the full sample collection.
        for i in range(multiprocessing.cpu_count()):
            yield 'CPU%d' % (i), {}, {'samples':self.samples, 'count_xseg_mask':self.count_xseg_mask}

    #override
    def on_clients_initialized(self):
        io.progress_bar ("Filtering", self.samples_len)

    #override
    def on_clients_finalized(self):
        io.progress_bar_close()

    #override
    def get_data(self, host_dict):
        if self.idxs:
            return self.idxs.popleft()

        return None

    #override
    def on_data_return (self, host_dict, data):
        # Put an unprocessed index back at the front of the queue.
        self.idxs.appendleft(data)

    #override
    def on_result (self, host_dict, data, result):
        idx, is_ok = result
        if is_ok:
            self.result.append(idx)
        io.progress_bar_inc(1)

    def get_result(self):
        return self.result

    class Cli(Subprocessor.Cli):
        #overridable optional
        def on_initialize(self, client_dict):
            self.samples = client_dict['samples']
            self.count_xseg_mask = client_dict['count_xseg_mask']

        def process_data(self, idx):
            """Return ``(idx, qualifies)`` for the sample at ``idx``."""
            if self.count_xseg_mask:
                return idx, self.samples[idx].has_xseg_mask()
            else:
                return idx, self.samples[idx].seg_ie_polys.get_pts_count() != 0


# NOTE: the string below is disabled legacy code (random backgrounds loaded
# from an aligned\backgrounds folder) kept for reference; it is never executed.
"""
bg_path = None
for path in paths:
    bg_path = Path(path) / 'backgrounds'
    if bg_path.exists():

        break
if bg_path is None:
    io.log_info(f'Random backgrounds will not be used. Place no face jpg images to aligned\backgrounds folder. ')
    bg_pathes = None
else:
    bg_pathes = pathex.get_image_paths(bg_path, image_extensions=['.jpg'], return_Path_class=True)
    io.log_info(f'Using {len(bg_pathes)} random backgrounds from {bg_path}')

if bg_pathes is not None:
    bg_path = bg_pathes[ np.random.randint(len(bg_pathes)) ]

    bg_img = cv2_imread(bg_path)
    if bg_img is not None:
        bg_img = bg_img.astype(np.float32) / 255.0
        bg_img = imagelib.normalize_channels(bg_img, 3)

        bg_img = imagelib.random_crop(bg_img, resolution, resolution)
        bg_img = cv2.resize(bg_img, (resolution, resolution), interpolation=cv2.INTER_LINEAR)

    if np.random.randint(2) == 0:
        bg_img = imagelib.apply_random_hsv_shift(bg_img)
    else:
        bg_img = imagelib.apply_random_rgb_levels(bg_img)

    bg_wp = imagelib.gen_warp_params(resolution, True, rotation_range=[-180,180], scale_range=[0,0], tx_range=[0,0], ty_range=[0,0])
    bg_img = imagelib.warp_by_params (bg_wp, bg_img, can_warp=False, can_transform=True, can_flip=True, border_replicate=True)

    bg = img*(1-mask)
    fg = img*mask

    c_mask = sd.random_circle_faded ([resolution,resolution])
    bg = ( bg_img*c_mask + bg*(1-c_mask) )*(1-mask)

    img = fg+bg

else:
"""