Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/core/joblib/SubprocessGenerator.py
628 views
1
import multiprocessing
2
import queue as Queue
3
import threading
4
import time
5
6
7
class SubprocessGenerator(object):
    """
    Run a generator in a child process and iterate its items from the parent.

    ``generator_func(user_param)`` is called inside a ``multiprocessing.Process``.
    Every item it yields is sent to the parent through ``cs_queue``. After the
    parent consumes an item it puts an acknowledgement on ``sc_queue``, which
    lets the child produce one more. The child therefore runs at most
    ``prefetch + 1`` items ahead of the consumer.

    ``None`` is used as the end-of-stream marker. The wrapped generator must
    therefore not yield ``None`` as a regular item.
    """

    # Seconds to wait on cs_queue before re-checking that the child is alive.
    _POLL_INTERVAL = 0.1

    @staticmethod
    def launch_thread(generator):
        """Thread target used by start_in_parallel: start one generator's process."""
        generator._start()

    @staticmethod
    def start_in_parallel(generator_list):
        """
        Start list of generators in parallel.

        Each generator's child process is launched from its own daemon thread.
        This call blocks until every generator reports that it has started.
        """
        for generator in generator_list:
            thread = threading.Thread(target=SubprocessGenerator.launch_thread, args=(generator,))
            thread.daemon = True
            thread.start()

        # Poll until every launcher thread has assigned its Process handle.
        while not all(generator._is_started() for generator in generator_list):
            time.sleep(0.005)

    def __init__(self, generator_func, user_param=None, prefetch=2, start_now=True):
        """
        Parameters:
            generator_func  callable taking ``user_param`` and returning an
                            iterator; it is called inside the child process
            user_param      argument passed to ``generator_func``
            prefetch        how far the child may run ahead of the consumer
                            (up to ``prefetch + 1`` unacknowledged items)
            start_now       launch the child process immediately; otherwise it
                            is started on the first ``next()`` call or by
                            ``start_in_parallel``
        """
        super().__init__()
        self.prefetch = prefetch
        self.generator_func = generator_func
        self.user_param = user_param
        self.sc_queue = multiprocessing.Queue()  # parent -> child: acknowledgements
        self.cs_queue = multiprocessing.Queue()  # child -> parent: items, None = end
        self.p = None
        self._exhausted = False  # True once the end-of-stream marker was received
        if start_now:
            self._start()

    def _start(self):
        """Launch the child process if it has not been launched yet (no-op otherwise)."""
        if self.p is None:
            user_param = self.user_param
            # user_param reaches the child via the Process args only; drop the
            # reference on self (presumably so it is not also carried along with
            # the bound process_func target).
            self.user_param = None
            p = multiprocessing.Process(target=self.process_func, args=(user_param,))
            p.daemon = True
            p.start()
            self.p = p

    def _is_started(self):
        """Return True once a child Process has been created and started."""
        return self.p is not None

    def process_func(self, user_param):
        """
        Child-process body: produce items from ``generator_func(user_param)``.

        Puts up to ``prefetch + 1`` items on cs_queue, then waits for one
        acknowledgement on sc_queue before each further item. Puts ``None``
        and returns once the generator is exhausted. An exception raised by
        the generator terminates the child; the parent detects this in
        ``__next__``.
        """
        self.generator_func = self.generator_func(user_param)
        while True:
            while self.prefetch > -1:
                try:
                    gen_data = next(self.generator_func)
                except StopIteration:
                    self.cs_queue.put(None)
                    return
                self.cs_queue.put(gen_data)
                self.prefetch -= 1
            # Block until the parent has consumed an item, then allow one more.
            self.sc_queue.get()
            self.prefetch += 1

    def __iter__(self):
        return self

    def __getstate__(self):
        """Pickled state: everything except the parent-side Process handle ``p``."""
        self_dict = self.__dict__.copy()
        # The Process handle only makes sense in the parent; presumably it must
        # not travel with self when self is pickled (e.g. as part of the child's
        # bound-method target).
        del self_dict['p']
        return self_dict

    def _get_from_child(self):
        """
        Block until the child sends the next item on cs_queue and return it.

        Raises RuntimeError when the child process is no longer alive and
        nothing is left in cs_queue (e.g. the wrapped generator raised).
        Without this check the caller would wait forever.
        """
        while True:
            try:
                return self.cs_queue.get(timeout=self._POLL_INTERVAL)
            except Queue.Empty:
                pass
            if not self.p.is_alive():
                # The child may have queued its final item just before exiting;
                # give the queue one more chance before giving up.
                try:
                    return self.cs_queue.get(timeout=self._POLL_INTERVAL)
                except Queue.Empty:
                    raise RuntimeError(
                        'SubprocessGenerator: child process exited unexpectedly '
                        '(exitcode={})'.format(self.p.exitcode)) from None

    def __next__(self):
        """
        Return the next item produced by the child process.

        Raises StopIteration once the generator is exhausted, and again on
        every later call. Raises RuntimeError if the child process died
        before sending the end-of-stream marker.
        """
        if self._exhausted:
            raise StopIteration()
        self._start()
        gen_data = self._get_from_child()
        if gen_data is None:
            self._exhausted = True
            self.p.terminate()
            self.p.join()
            raise StopIteration()
        # Acknowledge the item so the child may produce another one.
        self.sc_queue.put(1)
        return gen_data
80
81