Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/mainscripts/Trainer.py
628 views
1
import os
2
import sys
3
import traceback
4
import queue
5
import threading
6
import time
7
import numpy as np
8
import itertools
9
from pathlib import Path
10
from core import pathex
11
from core import imagelib
12
import cv2
13
import models
14
from core.interact import interact as io
15
16
def trainerThread (s2c, c2s, e,
                    model_class_name = None,
                    saved_models_path = None,
                    training_data_src_path = None,
                    training_data_dst_path = None,
                    pretraining_data_path = None,
                    pretrained_model_path = None,
                    no_preview=False,
                    force_model_name=None,
                    force_gpu_idxs=None,
                    cpu_only=None,
                    silent_start=False,
                    execute_programs = None,
                    debug=False,
                    **kwargs):
    """Build the requested model and run its training loop (thread target).

    Messaging with the thread that started this one:
      * ``s2c`` (caller -> trainer): queue of dicts whose ``'op'`` is one of
        ``'save'``, ``'backup'``, ``'preview'`` or ``'close'``.
      * ``c2s`` (trainer -> caller): queue receiving ``{'op': 'show', ...}``
        preview payloads and, as the last message, ``{'op': 'close'}``.
      * ``e``: event set after the first preview has been queued. It is also
        set when this function exits, so a caller blocked in ``e.wait()`` is
        released even if building the model fails.

    Args:
        model_class_name: name passed to ``models.import_model``.
        saved_models_path, training_data_src_path, training_data_dst_path:
            path objects with ``exists``/``mkdir`` (e.g. ``pathlib.Path``);
            created when missing. ``None`` entries are skipped.
        execute_programs: optional iterable of ``(prog_time, prog)`` pairs.
            ``prog_time > 0`` runs ``prog`` once, ``prog_time`` seconds after
            start; ``prog_time < 0`` runs it every ``-prog_time`` seconds.
            ``prog`` is passed to ``exec`` inside this function.
        debug: when true, no training or scheduled programs run; previews are
            produced by ``model.debug_one_iter()`` on request and nothing is
            saved.
        The other named keyword arguments are forwarded to the model
        constructor; ``**kwargs`` is accepted and ignored.
    """
    try:
        start_time = time.time()

        # Autosave period, in minutes.
        save_interval_min = 25

        for data_path in (training_data_src_path, training_data_dst_path, saved_models_path):
            if data_path is not None and not data_path.exists():
                data_path.mkdir(exist_ok=True, parents=True)

        model = models.import_model(model_class_name)(
                    is_training=True,
                    saved_models_path=saved_models_path,
                    training_data_src_path=training_data_src_path,
                    training_data_dst_path=training_data_dst_path,
                    pretraining_data_path=pretraining_data_path,
                    pretrained_model_path=pretrained_model_path,
                    no_preview=no_preview,
                    force_model_name=force_model_name,
                    force_gpu_idxs=force_gpu_idxs,
                    cpu_only=cpu_only,
                    silent_start=silent_start,
                    debug=debug)

        is_reached_goal = model.is_reached_iter_goal()

        # Mutable holder so the nested model_save() can flag "a save just
        # happened" for the main loop, which then prints a mean-loss line.
        shared_state = { 'after_save' : False }
        loss_string = ""
        save_iter = model.get_iter()

        def model_save():
            """Save the model unless in debug mode or the target is already reached."""
            if not debug and not is_reached_goal:
                io.log_info ("Saving....", end='\r')
                model.save()
                shared_state['after_save'] = True

        def model_backup():
            """Create a model backup unless in debug mode or the target is reached."""
            if not debug and not is_reached_goal:
                model.create_backup()

        def send_preview():
            """Queue a 'show' message for the GUI thread and mark it ready."""
            if not debug:
                previews = model.get_previews()
                c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } )
            else:
                previews = [( 'debug, press update for new', model.debug_one_iter())]
                c2s.put ( {'op':'show', 'previews': previews} )
            e.set() #Set the GUI Thread as Ready

        if model.get_target_iter() != 0:
            if is_reached_goal:
                io.log_info('Model already trained to target iteration. You can use preview.')
            else:
                io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % ( model.get_target_iter() ) )
        else:
            io.log_info('Starting. Press "Enter" to stop training and save model.')

        last_save_time = time.time()

        # Each entry becomes [prog_time, prog, last_run_time]. A None default
        # means "no scheduled programs"; iterating None directly raised TypeError.
        execute_programs = [ [x[0], x[1], time.time() ] for x in (execute_programs or []) ]

        stop_requested = False
        for i in itertools.count(0,1):
            if not debug:
                cur_time = time.time()

                for x in execute_programs:
                    prog_time, prog, last_time = x
                    exec_prog = False
                    if prog_time > 0 and (cur_time - start_time) >= prog_time:
                        x[0] = 0  # one-shot program: disable after it fires
                        exec_prog = True
                    elif prog_time < 0 and (cur_time - last_time) >= -prog_time:
                        x[2] = cur_time  # periodic program: remember last run
                        exec_prog = True

                    if exec_prog:
                        try:
                            # NOTE(security): executes arbitrary Python supplied by
                            # the caller, with this function's locals visible.
                            # Only pass trusted programs.
                            exec(prog)
                        except Exception:
                            # Must not be bound as ``e``: Python deletes an
                            # ``except ... as`` target afterwards, which would
                            # destroy the Event that send_preview() sets.
                            print("Unable to execute program: %s" % (prog) )
                            traceback.print_exc()

                if not is_reached_goal:

                    if model.get_iter() == 0:
                        io.log_info("")
                        io.log_info("Trying to do the first iteration. If an error occurs, reduce the model parameters.")
                        io.log_info("")

                        if sys.platform[0:3] == 'win':
                            io.log_info("!!!")
                            io.log_info("Windows 10 users IMPORTANT notice. You should set this setting in order to work correctly.")
                            io.log_info("https://i.imgur.com/B7cmDCB.jpg")
                            io.log_info("!!!")

                    iter, iter_time = model.train_one_iter()

                    loss_history = model.get_loss_history()
                    time_str = time.strftime("[%H:%M:%S]")
                    if iter_time >= 10:
                        loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, iter, '{:0.4f}'.format(iter_time) )
                    else:
                        loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, iter, int(iter_time*1000) )

                    if shared_state['after_save']:
                        shared_state['after_save'] = False

                        # After a save, print one persistent line with the mean
                        # of each loss since the previous such line.
                        mean_loss = np.mean ( loss_history[save_iter:iter], axis=0)

                        for loss_value in mean_loss:
                            loss_string += "[%.4f]" % (loss_value)

                        io.log_info (loss_string)

                        save_iter = iter
                    else:
                        for loss_value in loss_history[-1]:
                            loss_string += "[%.4f]" % (loss_value)

                        # Overwrite the same console line with the latest losses.
                        if io.is_colab():
                            io.log_info ('\r' + loss_string, end='')
                        else:
                            io.log_info (loss_string, end='\r')

                    if model.get_iter() == 1:
                        model_save()

                    if model.get_target_iter() != 0 and model.is_reached_iter_goal():
                        io.log_info ('Reached target iteration.')
                        model_save()
                        is_reached_goal = True
                        io.log_info ('You can use preview now.')

                # Catch up on every elapsed autosave period, but save only once.
                need_save = False
                while time.time() - last_save_time >= save_interval_min*60:
                    last_save_time += save_interval_min*60
                    need_save = True

                if not is_reached_goal and need_save:
                    model_save()
                    send_preview()

            if i==0:
                if is_reached_goal:
                    model.pass_one_iter()
                send_preview()

            if debug:
                time.sleep(0.005)

            while not s2c.empty():
                msg = s2c.get()
                op = msg['op']
                if op == 'save':
                    model_save()
                elif op == 'backup':
                    model_backup()
                elif op == 'preview':
                    if is_reached_goal:
                        model.pass_one_iter()
                    send_preview()
                elif op == 'close':
                    model_save()
                    stop_requested = True
                    break

            if stop_requested:
                break

        model.finalize()

    except Exception as ex:
        print ('Error: %s' % (str(ex)))
        traceback.print_exc()

    c2s.put ( {'op':'close'} )
    # If we failed before the first preview, the caller is still blocked in
    # e.wait(). Release it now; 'close' is already queued, so it exits cleanly.
    e.set()
212
213
214
215
def _equalize_preview_sizes(previews, max_size=800):
    """Return ``previews`` resized to one common size, preserving their order.

    The common size uses the largest height and the largest width found. If
    that height exceeds ``max_size``, both dimensions are scaled by the same
    factor so the height becomes ``max_size``. Previews already at the common
    size are kept as-is.

    previews: non-empty list of ``(name, image)`` pairs where ``image.shape``
        unpacks to ``(h, w, c)``.
    """
    max_h = 0
    max_w = 0
    for _, preview_rgb in previews:
        h, w, _ = preview_rgb.shape
        max_h = max(max_h, h)
        max_w = max(max_w, w)

    if max_h > max_size:
        max_w = int( max_w / (max_h / max_size) )
        max_h = max_size

    resized = []
    for preview_name, preview_rgb in previews:
        h, w, _ = preview_rgb.shape
        if h != max_h or w != max_w:
            preview_rgb = cv2.resize(preview_rgb, (max_w, max_h))
        resized.append( (preview_name, preview_rgb) )
    return resized


def _compose_preview_image(previews, selected_preview, loss_history, show_last_history_iters_count, cur_iter):
    """Build the uint8 window image for the currently selected preview.

    The image is stacked vertically:
      1. a header with key help and the preview name/index,
      2. the loss-history graph, when ``loss_history`` is not None,
      3. the selected preview.

    Pixel values are assumed to be floats in [0, 1]; they are clipped there
    and scaled to 0-255.
    ``show_last_history_iters_count == 0`` means "show the whole history".
    """
    selected_preview_name, selected_preview_rgb = previews[selected_preview]
    h, w, c = selected_preview_rgb.shape

    # HEAD
    head_lines = [
        '[s]:save [b]:backup [enter]:exit',
        '[p]:update [space]:next preview [l]:change history range',
        'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews) )
        ]
    head_line_height = 15
    head_height = len(head_lines) * head_line_height
    head = np.ones ( (head_height,w,c) ) * 0.1

    for line_idx, line_text in enumerate(head_lines):
        top = line_idx*head_line_height
        bottom = top + head_line_height
        head[top:bottom, 0:w] += imagelib.get_text_image ( (head_line_height,w,c) , line_text, color=[0.8]*c )

    final = head

    if loss_history is not None:
        if show_last_history_iters_count == 0:
            loss_history_to_show = loss_history
        else:
            loss_history_to_show = loss_history[-show_last_history_iters_count:]

        lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, cur_iter, w, c)
        final = np.concatenate ( [final, lh_img], axis=0 )

    final = np.concatenate ( [final, selected_preview_rgb], axis=0 )
    final = np.clip(final, 0, 1)
    return (final*255).astype(np.uint8)


def main(**kwargs):
    """Run ``trainerThread`` in a background thread and drive the user interface.

    All ``kwargs`` are forwarded to ``trainerThread``. After the trainer
    signals readiness through the shared event:
      * ``no_preview=True``: poll for the trainer's ``'close'`` message;
        Ctrl+C asks the trainer to close.
      * otherwise: show a preview window and translate key presses into
        trainer commands:
          - Enter: close
          - s: save
          - b: backup
          - p: request a new preview
          - l: cycle the loss-history range
          - space: next preview
    """
    io.log_info ("Running trainer.\r\n")

    no_preview = kwargs.get('no_preview', False)

    s2c = queue.Queue()
    c2s = queue.Queue()

    e = threading.Event()
    thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e), kwargs=kwargs )
    thread.start()

    e.wait() #Wait for inital load to occur.

    if no_preview:
        while True:
            if not c2s.empty():
                msg = c2s.get()
                if msg.get('op','') == 'close':
                    break
            try:
                io.process_messages(0.1)
            except KeyboardInterrupt:
                s2c.put ( {'op': 'close'} )
    else:
        wnd_name = "Training preview"
        io.named_window(wnd_name)
        io.capture_keys(wnd_name)

        # Values cycled by the 'l' key; 0 means "whole history".
        history_ranges = (0, 5000, 10000, 50000, 100000)

        previews = None
        loss_history = None
        selected_preview = 0
        update_preview = False
        is_waiting_preview = False
        show_last_history_iters_count = 0
        cur_iter = 0
        while True:
            if not c2s.empty():
                msg = c2s.get()
                op = msg['op']
                if op == 'show':
                    is_waiting_preview = False
                    loss_history = msg.get('loss_history', None)
                    previews = msg.get('previews', None)
                    cur_iter = msg.get('iter', 0)
                    # An empty list would make the modulo below divide by zero.
                    if previews:
                        previews = _equalize_preview_sizes(previews)
                        selected_preview = selected_preview % len(previews)
                        update_preview = True
                elif op == 'close':
                    break

            if update_preview:
                update_preview = False
                if previews:
                    io.show_image( wnd_name, _compose_preview_image(previews, selected_preview, loss_history, show_last_history_iters_count, cur_iter) )

            key_events = io.get_key_events(wnd_name)
            key = key_events[-1][0] if len(key_events) > 0 else 0

            if key in (ord('\n'), ord('\r')):
                s2c.put ( {'op': 'close'} )
            elif key == ord('s'):
                s2c.put ( {'op': 'save'} )
            elif key == ord('b'):
                s2c.put ( {'op': 'backup'} )
            elif key == ord('p'):
                # Only one preview request may be in flight at a time.
                if not is_waiting_preview:
                    is_waiting_preview = True
                    s2c.put ( {'op': 'preview'} )
            elif key == ord('l'):
                range_idx = history_ranges.index(show_last_history_iters_count)
                show_last_history_iters_count = history_ranges[(range_idx + 1) % len(history_ranges)]
                update_preview = True
            elif key == ord(' '):
                if previews:
                    selected_preview = (selected_preview + 1) % len(previews)
                    update_preview = True

            try:
                io.process_messages(0.1)
            except KeyboardInterrupt:
                s2c.put ( {'op': 'close'} )

        io.destroy_all_windows()
361