Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/samplelib/SampleGeneratorFaceTemporal.py
628 views
1
import multiprocessing
2
import pickle
3
import time
4
import traceback
5
6
import cv2
7
import numpy as np
8
9
from core import mplib
10
from core.joblib import SubprocessGenerator, ThisThreadGenerator
11
from facelib import LandmarksProcessor
12
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
13
SampleType)
14
15
16
class SampleGeneratorFaceTemporal(SampleGeneratorBase):
    """Batch generator that yields windows of consecutive face samples.

    Samples are loaded with ``SampleType.FACE_TEMPORAL_SORTED`` (presumably
    ordered by frame -- TODO confirm in SampleLoader). Each batch element is a
    window of ``temporal_image_count`` samples taken at indexes
    ``idx, idx + mult, idx + 2*mult, ...``. The start ``idx`` comes from an
    ``mplib.IndexHost`` client. Every sample in the window is run through
    ``SampleProcessor.process``, and the flattened outputs are stacked across
    the batch with ``np.array``.
    """

    # Upper bound on the stride between frames of one window. The stride is
    # drawn uniformly from [1, _MULT_MAX]; with 1 the windows are contiguous.
    # Shared by __init__ (size of the index range) and batch_func (sampling)
    # so the two cannot disagree.
    _MULT_MAX = 1

    def __init__(self, samples_path, debug, batch_size,
                 temporal_image_count=3,
                 sample_process_options=None,
                 output_sample_types=None,
                 generators_count=2,
                 **kwargs):
        """Load the samples and start the batch-producing generators.

        Args:
            samples_path: passed to ``SampleLoader.load``.
            debug: when truthy, a single in-thread generator is used instead
                of ``generators_count`` subprocess generators.
            batch_size: number of windows per yielded batch.
            temporal_image_count: number of samples per window.
            sample_process_options: forwarded to ``SampleProcessor.process``.
                Defaults to a fresh ``SampleProcessor.Options()`` per instance.
            output_sample_types: forwarded to ``SampleProcessor.process``.
                Defaults to a fresh empty list per instance.
            generators_count: number of subprocess generators when not
                debugging.
            **kwargs: accepted for call-site compatibility and ignored.

        Raises:
            ValueError: if no samples are found, or if there are too few
                samples to form even one window.
        """
        super().__init__(debug, batch_size)

        # Build the defaults per call. Default values in the signature would be
        # evaluated once and shared by every instance.
        if sample_process_options is None:
            sample_process_options = SampleProcessor.Options()
        if output_sample_types is None:
            output_sample_types = []

        self.temporal_image_count = temporal_image_count
        self.sample_process_options = sample_process_options
        self.output_sample_types = output_sample_types

        self.generators_count = 1 if self.debug else generators_count

        samples = SampleLoader.load(SampleType.FACE_TEMPORAL_SORTED, samples_path)
        samples_len = len(samples)
        if samples_len == 0:
            raise ValueError('No training data provided.')

        # The last frame of a window starting at idx is
        # idx + (temporal_image_count - 1) * mult, and mult <= _MULT_MAX. So
        # the valid start indexes are [0, samples_len - (T - 1) * _MULT_MAX).
        start_count = samples_len - (self.temporal_image_count - 1) * self._MULT_MAX
        if start_count <= 0:
            raise ValueError(
                'Not enough training data: %d samples found, but a window of '
                '%d temporal images needs at least %d.'
                % (samples_len, self.temporal_image_count,
                   samples_len - start_count + 1))
        index_host = mplib.IndexHost(start_count)

        # Serialize the samples once. Each batch_func invocation unpickles its
        # own copy, presumably so the payload can be handed to a subprocess.
        pickled_samples = pickle.dumps(samples, 4)
        if self.debug:
            self.generators = [ThisThreadGenerator(self.batch_func, (pickled_samples, index_host.create_cli(),))]
        else:
            self.generators = [SubprocessGenerator(self.batch_func, (pickled_samples, index_host.create_cli(),))
                               for _ in range(self.generators_count)]

        self.generator_counter = -1

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch, rotating round-robin over the generators."""
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators)]
        return next(generator)

    def batch_func(self, param):
        """Worker generator body that yields batches indefinitely.

        Args:
            param: tuple ``(pickled_samples, index_host_cli)`` as built in
                ``__init__``.

        Yields:
            A list of ``np.array``. Entry ``k`` stacks, across the batch, the
            ``k``-th item of each window's flattened processor outputs. The
            outputs of each window are concatenated in time order, sample by
            sample (presumably one item per ``output_sample_types`` entry per
            sample -- TODO confirm in SampleProcessor).

        Raises:
            Exception: wrapping any error raised while processing a sample,
                with the sample's filename and traceback in the message.
        """
        bs = self.batch_size
        pickled_samples, index_host = param
        samples = pickle.loads(pickled_samples)

        while True:
            batches = None

            indexes = index_host.multi_get(bs)

            for n_batch in range(bs):
                idx = indexes[n_batch]

                # Stride between the frames of this window, drawn from [1, _MULT_MAX].
                mult = np.random.randint(self._MULT_MAX) + 1

                temporal_samples = []
                for i in range(self.temporal_image_count):
                    sample = samples[idx + i * mult]
                    try:
                        temporal_samples += SampleProcessor.process([sample], self.sample_process_options, self.output_sample_types, self.debug)[0]
                    except Exception as e:
                        # Catch Exception rather than using a bare except, so that
                        # KeyboardInterrupt, SystemExit and GeneratorExit pass through
                        # unwrapped. Chain the original error for debugging.
                        raise Exception("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc())) from e

                # The first window fixes how many output columns there are.
                if batches is None:
                    batches = [[] for _ in range(len(temporal_samples))]

                for i, item in enumerate(temporal_samples):
                    batches[i].append(item)

            yield [np.array(batch) for batch in batches]
89
90