Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/samplelib/SampleLoader.py
623 views
1
import multiprocessing
2
import operator
3
import pickle
4
import traceback
5
from pathlib import Path
6
7
import samplelib.PackedFaceset
8
from core import pathex
9
from core.mplib import MPSharedList
10
from core.interact import interact as io
11
from core.joblib import Subprocessor
12
from DFLIMG import *
13
from facelib import FaceType, LandmarksProcessor
14
15
from .Sample import Sample, SampleType
16
17
18
class SampleLoader:
    """Load samples from a directory and cache them per path and sample type.

    All methods are static; the class is used as a namespace around the
    shared ``samples_cache``.
    """

    # Class-level cache: str(samples_path) -> list indexed by SampleType,
    # each slot holding the loaded samples or None if not loaded yet.
    samples_cache = dict()

    @staticmethod
    def get_person_id_max_count(samples_path):
        """Return the number of distinct ``person_name`` values in a packed faceset.

        Parameters:
            samples_path: location handed to ``samplelib.PackedFaceset.load``.

        Raises:
            ValueError: if no packed faceset could be loaded (the loader
                returned None or raised; the error is logged first).
        """
        samples = None
        try:
            samples = samplelib.PackedFaceset.load(samples_path)
        except Exception:
            # Best effort: log the failure and fall through to the ValueError below.
            io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(samples_path)}, {traceback.format_exc()}")

        if samples is None:
            raise ValueError("packed faceset not found.")

        return len({sample.person_name for sample in samples})

    @staticmethod
    def load(sample_type, samples_path, subdirs=False):
        """
        Return the samples of ``sample_type`` found at ``samples_path``.

        Results are cached in ``SampleLoader.samples_cache`` keyed by
        ``str(samples_path)``, so a second call with the same path and type
        returns the already-loaded object.

        Parameters:
            sample_type: one of the ``SampleType`` values.
            samples_path: directory containing the images / packed faceset.
            subdirs: forwarded to ``pathex.get_image_paths`` when image files
                have to be enumerated.

        Returns:
            ``SampleType.IMAGE``: a plain list of ``Sample``.
            ``SampleType.FACE`` / ``SampleType.FACE_TEMPORAL_SORTED``: an
            ``MPSharedList`` of ``Sample``.
            Any other type: whatever the cache slot holds (None if never set).
        """
        samples_cache = SampleLoader.samples_cache
        cache_key = str(samples_path)

        if cache_key not in samples_cache:
            samples_cache[cache_key] = [None] * SampleType.QTY

        samples = samples_cache[cache_key]

        if sample_type == SampleType.IMAGE:
            if samples[sample_type] is None:
                image_paths = pathex.get_image_paths(samples_path, subdirs=subdirs)
                samples[sample_type] = [Sample(filename=filename)
                                        for filename in io.progress_bar_generator(image_paths, "Loading")]

        elif sample_type == SampleType.FACE:
            if samples[sample_type] is None:
                # Must be bound before the try: if the packed loader raises we
                # still need to reach the loose-file fallback below.
                result = None
                try:
                    result = samplelib.PackedFaceset.load(samples_path)
                except Exception:
                    io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(samples_path)}, {traceback.format_exc()}")

                if result is not None:
                    io.log_info(f"Loaded {len(result)} packed faces from {samples_path}")
                else:
                    # No packed faceset available: read metadata from the individual image files.
                    result = SampleLoader.load_face_samples(pathex.get_image_paths(samples_path, subdirs=subdirs))

                samples[sample_type] = MPSharedList(result)

        elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
            if samples[sample_type] is None:
                # Built on top of the (cached) FACE samples, honouring the caller's subdirs choice.
                result = SampleLoader.load(SampleType.FACE, samples_path, subdirs=subdirs)
                result = SampleLoader.upgradeToFaceTemporalSortedSamples(result)
                samples[sample_type] = MPSharedList(result)

        return samples[sample_type]

    @staticmethod
    def load_face_samples(image_paths):
        """Build FACE ``Sample`` objects from the given image files.

        The files are read by ``FaceSamplesLoaderSubprocessor``; entries for
        which it reported no data are skipped.

        Returns:
            list of ``Sample`` with ``sample_type=SampleType.FACE``.
        """
        result = FaceSamplesLoaderSubprocessor(image_paths).run()
        sample_list = []

        for filename, data in result:
            if data is None:
                continue

            (face_type,
             shape,
             landmarks,
             seg_ie_polys,
             xseg_mask_compressed,
             eyebrows_expand_mod,
             source_filename) = data

            sample_list.append(Sample(filename=filename,
                                      sample_type=SampleType.FACE,
                                      face_type=FaceType.fromString(face_type),
                                      shape=shape,
                                      landmarks=landmarks,
                                      seg_ie_polys=seg_ie_polys,
                                      xseg_mask_compressed=xseg_mask_compressed,
                                      eyebrows_expand_mod=eyebrows_expand_mod,
                                      source_filename=source_filename,
                                      ))
        return sample_list

    @staticmethod
    def upgradeToFaceTemporalSortedSamples(samples):
        """Return a new list of ``samples`` ordered by ``source_filename``.

        The sort is stable, so samples sharing a ``source_filename`` keep
        their relative input order. The input is not modified.
        """
        return sorted(samples, key=operator.attrgetter('source_filename'))
106
107
108
class FaceSamplesLoaderSubprocessor(Subprocessor):
    """Subprocessor that extracts DFL face metadata from a list of image files.

    Work items are ``(index, path)`` pairs; each result is stored at the same
    index of ``self.result`` as ``(path, metadata_or_None)``.
    """

    # override
    def __init__(self, image_paths):
        total = len(image_paths)
        self.image_paths = image_paths
        self.image_paths_len = total
        # Indices still waiting to be handed to a worker.
        self.idxs = list(range(total))
        # One slot per input path, filled in on_result.
        self.result = [None] * total
        # NOTE(review): the meaning of 60 is defined by Subprocessor.__init__,
        # which is not visible here - presumably a timeout in seconds; confirm.
        super().__init__('FaceSamplesLoader', FaceSamplesLoaderSubprocessor.Cli, 60)

    # override
    def on_clients_initialized(self):
        # Open a progress bar sized to the number of input paths.
        io.progress_bar("Loading samples", len(self.image_paths))

    # override
    def on_clients_finalized(self):
        io.progress_bar_close()

    # override
    def process_info_generator(self):
        # One entry per CPU core, capped at 8.
        worker_count = min(multiprocessing.cpu_count(), 8)
        for cpu_idx in range(worker_count):
            yield f'CPU{cpu_idx}', {}, {}

    # override
    def get_data(self, host_dict):
        # Hand out the next pending (index, path) pair, or None when nothing is left.
        if not self.idxs:
            return None

        next_idx = self.idxs.pop(0)
        return next_idx, self.image_paths[next_idx]

    # override
    def on_data_return(self, host_dict, data):
        # Put the returned item's index back at the front of the pending queue.
        returned_idx = data[0]
        self.idxs.insert(0, returned_idx)

    # override
    def on_result(self, host_dict, data, result):
        idx, face_data = result
        self.result[idx] = (self.image_paths[idx], face_data)
        io.progress_bar_inc(1)

    # override
    def get_result(self):
        return self.result

    class Cli(Subprocessor.Cli):
        """Worker side: loads one image and returns its metadata tuple."""

        # override
        def process_data(self, data):
            idx, filename = data
            dflimg = DFLIMG.load(Path(filename))

            if dflimg is not None and dflimg.has_data():
                # Field order must match the unpacking in SampleLoader.load_face_samples.
                return idx, (dflimg.get_face_type(),
                             dflimg.get_shape(),
                             dflimg.get_landmarks(),
                             dflimg.get_seg_ie_polys(),
                             dflimg.get_xseg_mask_compressed(),
                             dflimg.get_eyebrows_expand_mod(),
                             dflimg.get_source_filename())

            self.log_err(f"FaceSamplesLoader: {filename} is not a dfl image file.")
            return idx, None

        # override
        def get_data_name(self, data):
            # String identifier of a work item: its file path.
            return data[1]
176
177