Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/mainscripts/Sorter.py
628 views
1
import math
2
import multiprocessing
3
import operator
4
import os
5
import sys
6
import tempfile
7
from functools import cmp_to_key
8
from pathlib import Path
9
10
import cv2
11
import numpy as np
12
from numpy import linalg as npla
13
14
from core import imagelib, mathlib, pathex
15
from core.cv2ex import *
16
from core.imagelib import estimate_sharpness
17
from core.interact import interact as io
18
from core.joblib import Subprocessor
19
from core.leras import nn
20
from DFLIMG import *
21
from facelib import LandmarksProcessor
22
23
24
class BlurEstimatorSubprocessor(Subprocessor):
    """Multiprocess estimator of a per-face sharpness score.

    Each worker masks the image down to the landmark hull and scores it;
    files that are not DFL images get a score of 0 and are collected
    separately as trash.
    """

    class Cli(Subprocessor.Cli):
        def on_initialize(self, client_dict):
            # Select the Laplacian-variance metric instead of the generic
            # sharpness estimator.
            self.estimate_motion_blur = client_dict['estimate_motion_blur']

        #override
        def process_data(self, data):
            filepath = Path(data[0])
            dflimg = DFLIMG.load(filepath)

            if dflimg is None or not dflimg.has_data():
                self.log_err(f"{filepath.name} is not a dfl image file")
                return [str(filepath), 0]

            image = cv2_imread(str(filepath))

            # Zero out everything outside the face hull before scoring.
            face_mask = LandmarksProcessor.get_image_hull_mask(image.shape, dflimg.get_landmarks())
            image = (image * face_mask).astype(np.uint8)

            if self.estimate_motion_blur:
                score = cv2.Laplacian(image, cv2.CV_64F, ksize=11).var()
            else:
                score = estimate_sharpness(image)

            return [str(filepath), score]

        #override
        def get_data_name(self, data):
            # String identificator of the data item.
            return data[0]

    #override
    def __init__(self, input_data, estimate_motion_blur=False):
        self.input_data = input_data
        self.estimate_motion_blur = estimate_motion_blur
        self.img_list = []
        self.trash_img_list = []
        super().__init__('BlurEstimator', BlurEstimatorSubprocessor.Cli, 60)

    #override
    def on_clients_initialized(self):
        io.progress_bar("", len(self.input_data))

    #override
    def on_clients_finalized(self):
        io.progress_bar_close()

    #override
    def process_info_generator(self):
        cpu_count = multiprocessing.cpu_count()
        io.log_info(f'Running on {cpu_count} CPUs')

        for i in range(cpu_count):
            yield 'CPU%d' % (i), {}, {'estimate_motion_blur': self.estimate_motion_blur}

    #override
    def get_data(self, host_dict):
        if self.input_data:
            return self.input_data.pop(0)
        return None

    #override
    def on_data_return(self, host_dict, data):
        self.input_data.insert(0, data)

    #override
    def on_result(self, host_dict, data, result):
        # A score of 0 marks an unreadable / non-DFL file (see Cli above).
        if result[1] == 0:
            self.trash_img_list.append(result)
        else:
            self.img_list.append(result)

        io.progress_bar_inc(1)

    #override
    def get_result(self):
        return self.img_list, self.trash_img_list
104
105
106
def sort_by_blur(input_path):
    """Return (img_list, trash_img_list) ordered by descending face sharpness."""
    io.log_info ("Sorting by blur...")

    paths = pathex.get_image_paths(input_path)
    img_list, trash_img_list = BlurEstimatorSubprocessor([(p, []) for p in paths]).run()

    io.log_info ("Sorting...")
    img_list.sort(key=operator.itemgetter(1), reverse=True)

    return img_list, trash_img_list
116
117
def sort_by_motion_blur(input_path):
    """Return (img_list, trash_img_list) ordered by descending Laplacian variance."""
    io.log_info ("Sorting by motion blur...")

    paths = pathex.get_image_paths(input_path)
    img_list, trash_img_list = BlurEstimatorSubprocessor([(p, []) for p in paths], estimate_motion_blur=True).run()

    io.log_info ("Sorting...")
    img_list.sort(key=operator.itemgetter(1), reverse=True)

    return img_list, trash_img_list
127
128
def sort_by_face_yaw(input_path):
    """Order face images by estimated head yaw, most positive first.

    Non-DFL files are diverted to the trash list.
    """
    io.log_info ("Sorting by face yaw...")
    img_list, trash_img_list = [], []
    for path in io.progress_bar_generator(pathex.get_image_paths(input_path), "Loading"):
        path = Path(path)
        dflimg = DFLIMG.load(path)

        if dflimg is None or not dflimg.has_data():
            io.log_err(f"{path.name} is not a dfl image file")
            trash_img_list.append([str(path)])
            continue

        # Only the yaw component of the pose estimate is used here.
        _, yaw, _ = LandmarksProcessor.estimate_pitch_yaw_roll(dflimg.get_landmarks(), size=dflimg.get_shape()[1])
        img_list.append([str(path), yaw])

    io.log_info ("Sorting...")
    img_list.sort(key=operator.itemgetter(1), reverse=True)

    return img_list, trash_img_list
150
151
def sort_by_face_pitch(input_path):
    """Order face images by estimated head pitch, most positive first.

    Non-DFL files are diverted to the trash list.
    """
    io.log_info ("Sorting by face pitch...")
    img_list, trash_img_list = [], []
    for path in io.progress_bar_generator(pathex.get_image_paths(input_path), "Loading"):
        path = Path(path)
        dflimg = DFLIMG.load(path)

        if dflimg is None or not dflimg.has_data():
            io.log_err(f"{path.name} is not a dfl image file")
            trash_img_list.append([str(path)])
            continue

        # Only the pitch component of the pose estimate is used here.
        pitch, _, _ = LandmarksProcessor.estimate_pitch_yaw_roll(dflimg.get_landmarks(), size=dflimg.get_shape()[1])
        img_list.append([str(path), pitch])

    io.log_info ("Sorting...")
    img_list.sort(key=operator.itemgetter(1), reverse=True)

    return img_list, trash_img_list
173
174
def sort_by_face_source_rect_size(input_path):
    """Order faces by the area of their detection rect in the source frame."""
    io.log_info ("Sorting by face rect size...")
    img_list, trash_img_list = [], []
    for path in io.progress_bar_generator(pathex.get_image_paths(input_path), "Loading"):
        path = Path(path)
        dflimg = DFLIMG.load(path)

        if dflimg is None or not dflimg.has_data():
            io.log_err(f"{path.name} is not a dfl image file")
            trash_img_list.append([str(path)])
            continue

        rect = dflimg.get_source_rect()
        # Area of the rect expressed as its four corner points.
        rect_area = mathlib.polygon_area(np.array(rect[[0,2,2,0]]).astype(np.float32),
                                         np.array(rect[[1,1,3,3]]).astype(np.float32))
        img_list.append([str(path), rect_area])

    io.log_info ("Sorting...")
    img_list.sort(key=operator.itemgetter(1), reverse=True)

    return img_list, trash_img_list
197
198
199
200
class HistSsimSubprocessor(Subprocessor):
    """Greedy histogram-similarity chaining of image paths across workers.

    The path list is cut into chunks; each worker computes per-channel BGR
    histograms for its chunk and reorders it so consecutive images have the
    smallest summed Bhattacharyya distance.
    """

    class Cli(Subprocessor.Cli):
        #override
        def process_data(self, data):
            items = []
            for path in data:
                img = cv2_imread(path)
                items.append([path,
                              cv2.calcHist([img], [0], None, [256], [0, 256]),
                              cv2.calcHist([img], [1], None, [256], [0, 256]),
                              cv2.calcHist([img], [2], None, [256], [0, 256])])

            # Greedy nearest-neighbour chain: for each position, pull the
            # histogram-closest remaining image next to it.
            for i in range(len(items) - 1):
                best_score = float("inf")
                best_j = i + 1
                for j in range(i + 1, len(items)):
                    score = cv2.compareHist(items[i][1], items[j][1], cv2.HISTCMP_BHATTACHARYYA) + \
                            cv2.compareHist(items[i][2], items[j][2], cv2.HISTCMP_BHATTACHARYYA) + \
                            cv2.compareHist(items[i][3], items[j][3], cv2.HISTCMP_BHATTACHARYYA)
                    if score < best_score:
                        best_score = score
                        best_j = j
                items[i + 1], items[best_j] = items[best_j], items[i + 1]

                self.progress_bar_inc(1)

            return items

        #override
        def get_data_name (self, data):
            return "Bunch of images"

    #override
    def __init__(self, img_list ):
        self.img_list = img_list
        self.img_list_len = len(img_list)

        # Aim for ~20000 images per chunk, but cap the number of chunks at 12
        # by enlarging the chunk size for very long lists.
        slice_count = 20000
        sliced_count = self.img_list_len // slice_count

        if sliced_count > 12:
            sliced_count = 11.9
            slice_count = int(self.img_list_len / sliced_count)
            sliced_count = self.img_list_len // slice_count

        self.img_chunks_list = [ self.img_list[i*slice_count : (i+1)*slice_count] for i in range(sliced_count) ] + \
                               [ self.img_list[sliced_count*slice_count:] ]

        self.result = []
        super().__init__('HistSsim', HistSsimSubprocessor.Cli, 0)

    #override
    def process_info_generator(self):
        cpu_count = len(self.img_chunks_list)
        io.log_info(f'Running on {cpu_count} threads')
        for i in range(cpu_count):
            yield 'CPU%d' % (i), {'i':i}, {}

    #override
    def on_clients_initialized(self):
        io.progress_bar ("Sorting", len(self.img_list))
        # Each chunk contributes len(chunk)-1 increments; pre-credit one per
        # chunk so the bar can reach 100%.
        io.progress_bar_inc(len(self.img_chunks_list))

    #override
    def on_clients_finalized(self):
        io.progress_bar_close()

    #override
    def get_data(self, host_dict):
        if self.img_chunks_list:
            return self.img_chunks_list.pop(0)
        return None

    #override
    def on_data_return (self, host_dict, data):
        raise Exception("Fail to process data. Decrease number of images and try again.")

    #override
    def on_result (self, host_dict, data, result):
        self.result += result
        return 0

    #override
    def get_result(self):
        return self.result
286
287
def sort_by_hist(input_path):
    """Chain images by BGR-histogram similarity; never trashes anything."""
    io.log_info ("Sorting by histogram similarity...")
    return HistSsimSubprocessor(pathex.get_image_paths(input_path)).run(), []
291
292
class HistDissimSubprocessor(Subprocessor):
    """Computes, per image, the summed Bhattacharyya distance to every other
    image's histogram.

    Each worker holds the full [path, hist, score] list; a job is a single
    index, and the host writes the returned total back into slot 2.
    """

    class Cli(Subprocessor.Cli):
        #override
        def on_initialize(self, client_dict):
            self.img_list = client_dict['img_list']
            self.img_list_len = len(self.img_list)

        #override
        def process_data(self, data):
            idx = data[0]
            total = 0
            for other in range(self.img_list_len):
                if other == idx:
                    continue
                total += cv2.compareHist(self.img_list[idx][1], self.img_list[other][1], cv2.HISTCMP_BHATTACHARYYA)

            return total

        #override
        def get_data_name (self, data):
            # String identificator of the data item.
            return self.img_list[data[0]][0]

    #override
    def __init__(self, img_list ):
        self.img_list = img_list
        self.img_list_range = list(range(len(img_list)))
        self.result = []
        super().__init__('HistDissim', HistDissimSubprocessor.Cli, 60)

    #override
    def on_clients_initialized(self):
        io.progress_bar ("Sorting", len (self.img_list) )

    #override
    def on_clients_finalized(self):
        io.progress_bar_close()

    #override
    def process_info_generator(self):
        cpu_count = min(multiprocessing.cpu_count(), 8)
        io.log_info(f'Running on {cpu_count} CPUs')
        for i in range(cpu_count):
            yield 'CPU%d' % (i), {}, {'img_list' : self.img_list}

    #override
    def get_data(self, host_dict):
        if self.img_list_range:
            return [self.img_list_range.pop(0)]
        return None

    #override
    def on_data_return (self, host_dict, data):
        self.img_list_range.insert(0, data[0])

    #override
    def on_result (self, host_dict, data, result):
        self.img_list[data[0]][2] = result
        io.progress_bar_inc(1)

    #override
    def get_result(self):
        return self.img_list
356
357
def sort_by_hist_dissim(input_path):
    """Order images by how dissimilar each grayscale histogram is to all others."""
    io.log_info ("Sorting by histogram dissimilarity...")

    img_list = []
    trash_img_list = []
    for path in io.progress_bar_generator(pathex.get_image_paths(input_path), "Loading"):
        path = Path(path)
        dflimg = DFLIMG.load(path)
        image = cv2_imread(str(path))

        # For DFL faces, restrict the histogram to the face-hull region;
        # plain images are histogrammed whole.
        if dflimg is not None and dflimg.has_data():
            mask = LandmarksProcessor.get_image_hull_mask(image.shape, dflimg.get_landmarks())
            image = (image * mask).astype(np.uint8)

        hist = cv2.calcHist([cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)], [0], None, [256], [0, 256])
        img_list.append([str(path), hist, 0])

    img_list = HistDissimSubprocessor(img_list).run()

    io.log_info ("Sorting...")
    img_list.sort(key=operator.itemgetter(2), reverse=True)

    return img_list, trash_img_list
381
382
def sort_by_brightness(input_path):
    """Order images by mean HSV value channel (brightness), brightest first."""
    io.log_info ("Sorting by brightness...")
    img_list = []
    for path in io.progress_bar_generator(pathex.get_image_paths(input_path), "Loading"):
        brightness = np.mean(cv2.cvtColor(cv2_imread(path), cv2.COLOR_BGR2HSV)[..., 2].flatten())
        img_list.append([path, brightness])
    io.log_info ("Sorting...")
    img_list.sort(key=operator.itemgetter(1), reverse=True)
    return img_list, []
388
389
def sort_by_hue(input_path):
    """Order images by mean HSV hue channel, highest first."""
    io.log_info ("Sorting by hue...")
    img_list = []
    for path in io.progress_bar_generator(pathex.get_image_paths(input_path), "Loading"):
        hue = np.mean(cv2.cvtColor(cv2_imread(path), cv2.COLOR_BGR2HSV)[..., 0].flatten())
        img_list.append([path, hue])
    io.log_info ("Sorting...")
    img_list.sort(key=operator.itemgetter(1), reverse=True)
    return img_list, []
395
396
def sort_by_black(input_path):
    """Order images by the count of zero-valued pixel components, fewest first."""
    io.log_info ("Sorting by amount of black pixels...")

    img_list = []
    for path in io.progress_bar_generator(pathex.get_image_paths(input_path), "Loading"):
        img = cv2_imread(path)
        img_list.append([path, img[(img == 0)].size])

    io.log_info ("Sorting...")
    img_list.sort(key=operator.itemgetter(1))

    return img_list, []
408
409
def sort_by_origname(input_path):
    """Order DFL images alphabetically by their recorded source filename."""
    io.log_info ("Sort by original filename...")

    img_list, trash_img_list = [], []
    for path in io.progress_bar_generator(pathex.get_image_paths(input_path), "Loading"):
        path = Path(path)
        dflimg = DFLIMG.load(path)

        if dflimg is None or not dflimg.has_data():
            io.log_err(f"{path.name} is not a dfl image file")
            trash_img_list.append([str(path)])
            continue

        img_list.append([str(path), dflimg.get_source_filename()])

    io.log_info ("Sorting...")
    img_list.sort(key=operator.itemgetter(1))
    return img_list, trash_img_list
429
430
def sort_by_oneface_in_image(input_path):
    """Keep only faces whose source frame produced exactly one face.

    Relies on the '<frame>_<faceindex>' filename convention: a frame that has
    any face with index != 0 contributed multiple faces, and ALL its faces
    are trashed.  Returns (img_list, trash_img_list) of 1-tuples of paths.
    """
    io.log_info ("Sort by one face in images...")
    image_paths = pathex.get_image_paths(input_path)
    # Rows of (frame_number, face_index) for every parseable '<a>_<b>' stem.
    a = np.array ([ ( int(x[0]), int(x[1]) ) \
                    for x in [ Path(filepath).stem.split('_') for filepath in image_paths ] if len(x) == 2
                  ])
    if len(a) > 0:
        # Frame numbers that have a face with index != 0, i.e. multi-face frames.
        idxs = np.ndarray.flatten ( np.argwhere ( a[:,1] != 0 ) )
        idxs = np.unique ( a[idxs][:,0] )
        # Positions (into image_paths) of every face belonging to those frames.
        idxs = np.ndarray.flatten ( np.argwhere ( np.array([ x[0] in idxs for x in a ]) == True ) )
        if len(idxs) > 0:
            io.log_info ("Found %d images." % (len(idxs)) )
            img_list = [ (path,) for i,path in enumerate(image_paths) if i not in idxs ]
            trash_img_list = [ (image_paths[x],) for x in idxs ]
            return img_list, trash_img_list

    io.log_info ("Nothing found. Possible recover original filenames first.")
    return [], []
448
449
class FinalLoaderSubprocessor(Subprocessor):
    """Loads the per-image metadata that sort_best consumes.

    Successful rows are [path, sharpness, hist, yaw, pitch]; unreadable or
    non-DFL files end up in the trash result.  With faster=True the
    source-rect area stands in for the sharpness estimate.
    """

    class Cli(Subprocessor.Cli):
        #override
        def on_initialize(self, client_dict):
            self.faster = client_dict['faster']

        #override
        def process_data(self, data):
            filepath = Path(data[0])

            try:
                dflimg = DFLIMG.load (filepath)

                if dflimg is None or not dflimg.has_data():
                    self.log_err (f"{filepath.name} is not a dfl image file")
                    return [ 1, [str(filepath)] ]

                bgr = cv2_imread(str(filepath))
                if bgr is None:
                    raise Exception ("Unable to load %s" % (filepath.name) )

                gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
                if self.faster:
                    # Cheap proxy: score by source-rect area instead of blur.
                    rect = dflimg.get_source_rect()
                    sharpness = mathlib.polygon_area(np.array(rect[[0,2,2,0]]).astype(np.float32),
                                                     np.array(rect[[1,1,3,3]]).astype(np.float32))
                else:
                    face_mask = LandmarksProcessor.get_image_hull_mask (gray.shape, dflimg.get_landmarks())
                    sharpness = estimate_sharpness( (gray[...,None]*face_mask).astype(np.uint8) )

                pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks(), size=dflimg.get_shape()[1] )

                hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
            except Exception as e:
                self.log_err (e)
                return [ 1, [str(filepath)] ]

            # Leading 0/1 flags success/trash for on_result.
            return [ 0, [str(filepath), sharpness, hist, yaw, pitch ] ]

        #override
        def get_data_name (self, data):
            # String identificator of the data item.
            return data[0]

    #override
    def __init__(self, img_list, faster ):
        self.img_list = img_list

        self.faster = faster
        self.result = []
        self.result_trash = []

        super().__init__('FinalLoader', FinalLoaderSubprocessor.Cli, 60)

    #override
    def on_clients_initialized(self):
        io.progress_bar ("Loading", len (self.img_list))

    #override
    def on_clients_finalized(self):
        io.progress_bar_close()

    #override
    def process_info_generator(self):
        cpu_count = min(multiprocessing.cpu_count(), 8)
        io.log_info(f'Running on {cpu_count} CPUs')

        for i in range(cpu_count):
            yield 'CPU%d' % (i), {}, {'faster': self.faster}

    #override
    def get_data(self, host_dict):
        if self.img_list:
            return [self.img_list.pop(0)]
        return None

    #override
    def on_data_return (self, host_dict, data):
        self.img_list.insert(0, data[0])

    #override
    def on_result (self, host_dict, data, result):
        if result[0] == 0:
            self.result.append (result[1])
        else:
            self.result_trash.append (result[1])
        io.progress_bar_inc(1)

    #override
    def get_result(self):
        return self.result, self.result_trash
540
541
class FinalHistDissimSubprocessor(Subprocessor):
    """Scores and sorts each pitch bucket of every yaw cell by histogram
    dissimilarity.

    A job is (yaw_index, list_of_pitch_buckets); each bucket holds
    [path, sharpness, hist, yaw, pitch] rows.  Slot 3 is overwritten with the
    summed Bhattacharyya distance at this stage (yaw is no longer needed),
    and each bucket is returned sorted most-dissimilar-first.
    """

    class Cli(Subprocessor.Cli):
        #override
        def process_data(self, data):
            idx, pitch_yaw_img_list = data

            for p in range ( len(pitch_yaw_img_list) ):

                img_list = pitch_yaw_img_list[p]
                if img_list is not None:
                    # Sum of distances from each image to all others in the
                    # same bucket.
                    for i in range( len(img_list) ):
                        score_total = 0
                        for j in range( len(img_list) ):
                            if i == j:
                                continue
                            score_total += cv2.compareHist(img_list[i][2], img_list[j][2], cv2.HISTCMP_BHATTACHARYYA)
                        img_list[i][3] = score_total

                    pitch_yaw_img_list[p] = sorted(img_list, key=operator.itemgetter(3), reverse=True)

            return idx, pitch_yaw_img_list

        #override
        def get_data_name (self, data):
            return "Bunch of images"

    #override
    def __init__(self, pitch_yaw_sample_list ):
        self.pitch_yaw_sample_list = pitch_yaw_sample_list
        self.pitch_yaw_sample_list_len = len(pitch_yaw_sample_list)

        self.pitch_yaw_sample_list_idxs = [ i for i in range(self.pitch_yaw_sample_list_len) if self.pitch_yaw_sample_list[i] is not None ]
        self.result = [ None for _ in range(self.pitch_yaw_sample_list_len) ]
        super().__init__('FinalHistDissimSubprocessor', FinalHistDissimSubprocessor.Cli)

    #override
    def process_info_generator(self):
        cpu_count = min(multiprocessing.cpu_count(), 8)
        io.log_info(f'Running on {cpu_count} CPUs')
        for i in range(cpu_count):
            yield 'CPU%d' % (i), {}, {}

    #override
    def on_clients_initialized(self):
        io.progress_bar ("Sort by hist-dissim", len(self.pitch_yaw_sample_list_idxs) )

    #override
    def on_clients_finalized(self):
        io.progress_bar_close()

    #override
    def get_data(self, host_dict):
        if len (self.pitch_yaw_sample_list_idxs) > 0:
            idx = self.pitch_yaw_sample_list_idxs.pop(0)

            return idx, self.pitch_yaw_sample_list[idx]
        return None

    #override
    def on_data_return (self, host_dict, data):
        self.pitch_yaw_sample_list_idxs.insert(0, data[0])

    #override
    def on_result (self, host_dict, data, result):
        # BUGFIX: previously unpacked `data` (the host-side, unprocessed
        # input) instead of `result`, which discarded the worker's scoring
        # and sorting since workers run in separate processes.
        idx, pitch_yaw_img_list = result
        self.result[idx] = pitch_yaw_img_list
        io.progress_bar_inc(1)

    #override
    def get_result(self):
        return self.result
612
613
def sort_best_faster(input_path):
    """Shortcut for sort_best() using the faster source-rect-area metric."""
    return sort_best(input_path, faster=True)
615
616
def sort_best(input_path, faster=False):
    """Pick roughly target_count of the best faces, balanced across poses.

    Pipeline: load [path, sharpness, hist, yaw, pitch] rows; bucket by yaw
    into 128 bins; keep the sharpest images per bin; sub-bucket each bin by
    pitch; sort each sub-bucket by histogram dissimilarity; then round-robin
    across pitch buckets to fill the final list.  Everything not picked goes
    to the trash list.  faster=True makes the loader score by source-rect
    area instead of blur.
    """
    target_count = io.input_int ("Target number of faces?", 2000)

    io.log_info ("Performing sort by best faces.")
    if faster:
        io.log_info("Using faster algorithm. Faces will be sorted by source-rect-area instead of blur.")

    img_list, trash_img_list = FinalLoaderSubprocessor( pathex.get_image_paths(input_path), faster ).run()
    final_img_list = []

    grads = 128
    imgs_per_grad = round (target_count / grads)

    #instead of math.pi / 2, using -1.2,+1.2 because actually maximum yaw for 2DFAN landmarks are -1.2+1.2
    grads_space = np.linspace (-1.2, 1.2,grads)

    # Bucket images by (negated) yaw into the 128 bins of grads_space.
    yaws_sample_list = [None]*grads
    for g in io.progress_bar_generator ( range(grads), "Sort by yaw"):
        yaw = grads_space[g]
        next_yaw = grads_space[g+1] if g < grads-1 else yaw

        yaw_samples = []
        for img in img_list:
            s_yaw = -img[3]
            # First/last bins are open-ended so out-of-range yaws still land
            # somewhere.
            if (g == 0 and s_yaw < next_yaw) or \
               (g < grads-1 and s_yaw >= yaw and s_yaw < next_yaw) or \
               (g == grads-1 and s_yaw >= yaw):
                yaw_samples += [ img ]
        if len(yaw_samples) > 0:
            yaws_sample_list[g] = yaw_samples

    # Redistribute quota from under-filled bins evenly over all bins.
    total_lack = 0
    for g in io.progress_bar_generator ( range(grads), ""):
        img_list = yaws_sample_list[g]
        img_list_len = len(img_list) if img_list is not None else 0

        lack = imgs_per_grad - img_list_len
        total_lack += max(lack, 0)

    imgs_per_grad += total_lack // grads


    # Within each yaw bin keep only the 10*quota sharpest images.
    sharpned_imgs_per_grad = imgs_per_grad*10
    for g in io.progress_bar_generator ( range (grads), "Sort by blur"):
        img_list = yaws_sample_list[g]
        if img_list is None:
            continue

        img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True)

        if len(img_list) > sharpned_imgs_per_grad:
            trash_img_list += img_list[sharpned_imgs_per_grad:]
            img_list = img_list[0:sharpned_imgs_per_grad]

        yaws_sample_list[g] = img_list


    # Sub-bucket each yaw bin by pitch (one pitch bin per quota slot).
    yaw_pitch_sample_list = [None]*grads
    pitch_grads = imgs_per_grad

    for g in io.progress_bar_generator ( range (grads), "Sort by pitch"):
        img_list = yaws_sample_list[g]
        if img_list is None:
            continue

        pitch_sample_list = [None]*pitch_grads

        grads_space = np.linspace (-math.pi / 2,math.pi / 2, pitch_grads )

        for pg in range (pitch_grads):

            pitch = grads_space[pg]
            next_pitch = grads_space[pg+1] if pg < pitch_grads-1 else pitch

            pitch_samples = []
            for img in img_list:
                s_pitch = img[4]
                if (pg == 0 and s_pitch < next_pitch) or \
                   (pg < pitch_grads-1 and s_pitch >= pitch and s_pitch < next_pitch) or \
                   (pg == pitch_grads-1 and s_pitch >= pitch):
                    pitch_samples += [ img ]

            if len(pitch_samples) > 0:
                pitch_sample_list[pg] = pitch_samples
        yaw_pitch_sample_list[g] = pitch_sample_list

    # Sort every pitch bucket most-hist-dissimilar-first (in workers).
    yaw_pitch_sample_list = FinalHistDissimSubprocessor(yaw_pitch_sample_list).run()

    # Round-robin across pitch buckets until each yaw bin's quota is filled
    # or the bin runs dry; leftovers become trash.
    for g in io.progress_bar_generator (range (grads), "Fetching the best"):
        pitch_sample_list = yaw_pitch_sample_list[g]
        if pitch_sample_list is None:
            continue

        n = imgs_per_grad

        while n > 0:
            n_prev = n
            for pg in range(pitch_grads):
                img_list = pitch_sample_list[pg]
                if img_list is None:
                    continue
                final_img_list += [ img_list.pop(0) ]
                if len(img_list) == 0:
                    pitch_sample_list[pg] = None
                n -= 1
                if n == 0:
                    break
            # No bucket yielded anything this pass: the bin is exhausted.
            if n_prev == n:
                break

        for pg in range(pitch_grads):
            img_list = pitch_sample_list[pg]
            if img_list is None:
                continue
            trash_img_list += img_list

    return final_img_list, trash_img_list
733
734
"""
735
def sort_by_vggface(input_path):
736
io.log_info ("Sorting by face similarity using VGGFace model...")
737
738
model = VGGFace()
739
740
final_img_list = []
741
trash_img_list = []
742
743
image_paths = pathex.get_image_paths(input_path)
744
img_list = [ (x,) for x in image_paths ]
745
img_list_len = len(img_list)
746
img_list_range = [*range(img_list_len)]
747
748
feats = [None]*img_list_len
749
for i in io.progress_bar_generator(img_list_range, "Loading"):
750
img = cv2_imread( img_list[i][0] ).astype(np.float32)
751
img = imagelib.normalize_channels (img, 3)
752
img = cv2.resize (img, (224,224) )
753
img = img[..., ::-1]
754
img[..., 0] -= 93.5940
755
img[..., 1] -= 104.7624
756
img[..., 2] -= 129.1863
757
feats[i] = model.predict( img[None,...] )[0]
758
759
tmp = np.zeros( (img_list_len,) )
760
float_inf = float("inf")
761
for i in io.progress_bar_generator ( range(img_list_len-1), "Sorting" ):
762
i_feat = feats[i]
763
764
for j in img_list_range:
765
tmp[j] = npla.norm(i_feat-feats[j]) if j >= i+1 else float_inf
766
767
idx = np.argmin(tmp)
768
769
img_list[i+1], img_list[idx] = img_list[idx], img_list[i+1]
770
feats[i+1], feats[idx] = feats[idx], feats[i+1]
771
772
return img_list, trash_img_list
773
"""
774
775
def sort_by_absdiff(input_path):
    """Order images by pairwise absolute pixel difference (TF-accelerated).

    Builds the upper triangle of the NxN difference matrix in batches,
    spilled to an HDF5 file in the temp dir, then greedily chains images
    starting from index 0 — each step picks the closest (is_sim=True) or
    farthest (is_sim=False) not-yet-used image.  Returns (img_list, []).
    """
    io.log_info ("Sorting by absolute difference...")

    is_sim = io.input_bool ("Sort by similar?", True, help_message="Otherwise sort by dissimilar.")

    from core.leras import nn

    device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True)
    nn.initialize( device_config=device_config, data_format="NHWC" )
    tf = nn.tf

    image_paths = pathex.get_image_paths(input_path)
    image_paths_len = len(image_paths)

    batch_size = 512
    batch_size_remain = image_paths_len % batch_size

    i_t = tf.placeholder (tf.float32, (None,None,None,None) )
    j_t = tf.placeholder (tf.float32, (None,None,None,None) )

    # One reduce-sum op per j-row; the 'remain' variant covers the last,
    # short batch.
    outputs_full = []
    outputs_remain = []

    for i in range(batch_size):
        diff_t = tf.reduce_sum( tf.abs(i_t-j_t[i]), axis=[1,2,3] )
        outputs_full.append(diff_t)
        if i < batch_size_remain:
            outputs_remain.append(diff_t)

    def func_bs_full(i,j):
        return nn.tf_sess.run (outputs_full, feed_dict={i_t:i,j_t:j})

    def func_bs_remain(i,j):
        return nn.tf_sess.run (outputs_remain, feed_dict={i_t:i,j_t:j})

    import h5py
    db_file_path = Path(tempfile.gettempdir()) / 'sort_cache.hdf5'
    db_file = h5py.File( str(db_file_path), "w")
    db = db_file.create_dataset("results", (image_paths_len,image_paths_len), compression="gzip")

    # Progress total: number of batch-pair blocks in the upper triangle
    # (diagonal included).
    pg_len = image_paths_len // batch_size
    if batch_size_remain != 0:
        pg_len += 1

    pg_len = int( ( pg_len*pg_len - pg_len ) / 2 + pg_len )

    io.progress_bar ("Computing", pg_len)
    j = 0
    while j < image_paths_len:
        j_images = [ cv2_imread(x) for x in image_paths[j:j+batch_size] ]
        j_images_len = len(j_images)

        func = func_bs_remain if image_paths_len-j < batch_size else func_bs_full

        i = 0
        while i < image_paths_len:
            if i >= j:
                i_images = [ cv2_imread(x) for x in image_paths[i:i+batch_size] ]
                i_images_len = len(i_images)
                result = func (i_images,j_images)
                db[j:j+j_images_len,i:i+i_images_len] = np.array(result)
                io.progress_bar_inc(1)

            i += batch_size
        db_file.flush()
        j += batch_size

    io.progress_bar_close()

    next_id = 0
    # Renamed from 'sorted': never shadow the builtin.
    sorted_ids = [next_id]
    for i in io.progress_bar_generator ( range(image_paths_len-1), "Sorting" ):
        # Distances from next_id to everything: the column above the diagonal
        # plus the row to the right of it.
        id_ar = np.concatenate ( [ db[:next_id,next_id], db[next_id,next_id:] ] )
        id_ar = np.argsort(id_ar)

        # setdiff1d with assume_unique=True preserves id_ar's (distance)
        # order, so [0] is closest and [-1] is farthest unused image.
        next_id = np.setdiff1d(id_ar, sorted_ids, True)[ 0 if is_sim else -1]
        sorted_ids += [next_id]
    db_file.close()
    db_file_path.unlink()

    img_list = [ (image_paths[x],) for x in sorted_ids]
    return img_list, []
858
859
def final_process(input_path, img_list, trash_img_list):
    """Apply a sort result to disk.

    Moves every trash_img_list entry into a sibling '<input>_trash' folder
    (emptied first), then renames img_list files in place: first to
    '<index>_<originalname>' (collision-free intermediate step), then to
    '<index><ext>'.  Entries are sequences whose first element is the path.
    Individual rename failures are logged and skipped (best-effort).
    """
    if len(trash_img_list) != 0:
        parent_input_path = input_path.parent
        trash_path = parent_input_path / (input_path.stem + '_trash')
        trash_path.mkdir (exist_ok=True)

        io.log_info ("Trashing %d items to %s" % ( len(trash_img_list), str(trash_path) ) )

        # Clear any leftovers from a previous run.
        for filename in pathex.get_image_paths(trash_path):
            Path(filename).unlink()

        for i in io.progress_bar_generator( range(len(trash_img_list)), "Moving trash", leave=False):
            src = Path (trash_img_list[i][0])
            dst = trash_path / src.name
            try:
                src.rename (dst)
            except Exception:
                # Was a bare 'except:': keep best-effort behavior but do not
                # swallow KeyboardInterrupt/SystemExit.
                io.log_info ('fail to trashing %s' % (src.name) )

        io.log_info ("")

    if len(img_list) != 0:
        # Pass 1: prefix with the new index while keeping the original name,
        # so renames cannot collide with files not yet renamed.
        for i in io.progress_bar_generator( [*range(len(img_list))], "Renaming", leave=False):
            src = Path (img_list[i][0])
            dst = input_path / ('%.5d_%s' % (i, src.name ))
            try:
                src.rename (dst)
            except Exception:
                io.log_info ('fail to rename %s' % (src.name) )

        # Pass 2: collapse to '<index><ext>'.
        for i in io.progress_bar_generator( [*range(len(img_list))], "Renaming"):
            src = Path (img_list[i][0])
            src = input_path / ('%.5d_%s' % (i, src.name))
            dst = input_path / ('%.5d%s' % (i, src.suffix))
            try:
                src.rename (dst)
            except Exception:
                io.log_info ('fail to rename %s' % (src.name) )
897
898
# Maps CLI method key -> (human-readable description, sort function).
# The description is what main() shows in its interactive chooser.
sort_func_methods = {
    'blur': ("blur", sort_by_blur),
    'motion-blur': ("motion_blur", sort_by_motion_blur),
    'face-yaw': ("face yaw direction", sort_by_face_yaw),
    'face-pitch': ("face pitch direction", sort_by_face_pitch),
    'face-source-rect-size' : ("face rect size in source image", sort_by_face_source_rect_size),
    'hist': ("histogram similarity", sort_by_hist),
    'hist-dissim': ("histogram dissimilarity", sort_by_hist_dissim),
    'brightness': ("brightness", sort_by_brightness),
    'hue': ("hue", sort_by_hue),
    'black': ("amount of black pixels", sort_by_black),
    'origname': ("original filename", sort_by_origname),
    'oneface': ("one face in image", sort_by_oneface_in_image),
    'absdiff': ("absolute pixel difference", sort_by_absdiff),
    'final': ("best faces", sort_best),
    'final-fast': ("best faces faster", sort_best_faster),
}
915
916
def main (input_path, sort_by_method=None):
    """Entry point of the sort tool.

    input_path: Path of the folder whose images will be sorted/renamed.
    sort_by_method: key of sort_func_methods; when None, the user is asked
    interactively (default choice: index 5, 'hist').
    """
    io.log_info ("Running sort tool.\r\n")

    if sort_by_method is None:
        io.log_info("Choose sorting method:")

        key_list = list(sort_func_methods.keys())
        for i, key in enumerate(key_list):
            desc, func = sort_func_methods[key]
            io.log_info(f"[{i}] {desc}")

        io.log_info("")
        # Renamed from 'id': never shadow the builtin.
        method_id = io.input_int("", 5, valid_list=[*range(len(key_list))] )

        sort_by_method = key_list[method_id]
    else:
        sort_by_method = sort_by_method.lower()

    desc, func = sort_func_methods[sort_by_method]
    img_list, trash_img_list = func(input_path)

    final_process (input_path, img_list, trash_img_list)
938
939