Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/core/mplib/__init__.py
628 views
1
from .MPSharedList import MPSharedList
2
import multiprocessing
3
import threading
4
import time
5
6
import numpy as np
7
8
9
class IndexHost():
    """
    Hands out shuffled indexes to clients through queues.

    The constructor starts a daemon thread running host_thread. Clients made
    by create_cli post (client id, count) requests on one shared request
    queue and read the answer from their own reply queue.
    """
    def __init__(self, indexes_count, rnd_seed=None):
        self.sq = multiprocessing.Queue()   # shared request queue, clients -> host
        self.cqs = []                       # reply queues; a client's id is its position here
        self.clis = []
        self.thread = threading.Thread(target=self.host_thread,
                                       args=(indexes_count, rnd_seed))
        self.thread.daemon = True
        self.thread.start()

    def host_thread(self, indexes_count, rnd_seed):
        """
        Serve index requests forever.

        Indexes come from a pool holding range(indexes_count) in shuffled
        order. The pool is refilled and reshuffled each time it runs empty, so
        every index is handed out once before any index repeats.

        indexes_count   number of distinct indexes to serve
        rnd_seed        seed for a private RandomState; with None the global
                        np.random module does the shuffling
        """
        if rnd_seed is None:
            shuffler = np.random
        else:
            shuffler = np.random.RandomState(rnd_seed)

        every_index = list(range(indexes_count))
        pool = []
        requests = self.sq

        while True:
            # Drain all pending requests, then nap briefly before polling again.
            while not requests.empty():
                request = requests.get()
                client_id, amount = request[0], request[1]

                batch = []
                for _ in range(amount):
                    if not pool:
                        pool = every_index.copy()
                        shuffler.shuffle(pool)
                    batch.append(pool.pop())
                self.cqs[client_id].put(batch)

            time.sleep(0.001)

    def create_cli(self):
        """Register a new reply queue and return the client handle bound to it."""
        reply_queue = multiprocessing.Queue()
        self.cqs.append(reply_queue)
        return IndexHost.Cli(self.sq, reply_queue, len(self.cqs) - 1)

    # disable pickling: no instance state is exported
    def __getstate__(self):
        return {}

    def __setstate__(self, d):
        self.__dict__.update(d)

    class Cli():
        """Client handle: sends requests on sq and polls cq for the reply."""
        def __init__(self, sq, cq, cq_id):
            self.sq = sq
            self.cq = cq
            self.cq_id = cq_id

        def multi_get(self, count):
            """Request `count` indexes and wait until the host replies with the list."""
            self.sq.put((self.cq_id, count))

            while self.cq.empty():
                time.sleep(0.001)
            return self.cq.get()
68
69
class Index2DHost():
    """
    Hands out randomly drawn items of a 2D (list of lists) collection to
    clients through queues.

    The constructor starts a daemon thread running host_thread. Clients made
    by create_cli post (client id, count) requests on one shared request
    queue and read the answer from their own reply queue.
    """
    def __init__(self, indexes2D):
        self.sq = multiprocessing.Queue()   # shared request queue, clients -> host
        self.cqs = []                       # reply queues; a client's id is its position here
        self.clis = []
        self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,))
        self.thread.daemon = True
        self.thread.start()

    def host_thread(self, indexes2D):
        """
        Serve item requests forever.

        Each draw first takes a row from a shuffled pool of row numbers, then
        takes a column from that row's own shuffled pool, and yields
        indexes2D[row][column]. Every pool is refilled and reshuffled when it
        runs empty. Shuffling uses the global np.random module.

        Rows without items are never drawn. When no row has any item, each
        request is answered with an empty list; popping from an empty pool
        would otherwise raise inside this thread and leave clients waiting
        for a reply that never comes.
        """
        row_count = len(indexes2D)
        column_ids = [list(range(len(indexes2D[row]))) for row in range(row_count)]

        # Only rows that can actually yield an item take part in the draw.
        drawable_rows = [row for row in range(row_count) if column_ids[row]]

        row_pool = []
        column_pools = [[] for _ in range(row_count)]
        requests = self.sq

        while True:
            # Drain all pending requests, then nap briefly before polling again.
            while not requests.empty():
                request = requests.get()
                client_id, amount = request[0], request[1]

                batch = []
                if drawable_rows:
                    for _ in range(amount):
                        if not row_pool:
                            row_pool = drawable_rows.copy()
                            np.random.shuffle(row_pool)
                        row = row_pool.pop()

                        if not column_pools[row]:
                            # A pool is shuffled once per refill; popping from
                            # it afterwards is sampling without replacement.
                            column_pools[row] = column_ids[row].copy()
                            np.random.shuffle(column_pools[row])
                        column = column_pools[row].pop()

                        batch.append(indexes2D[row][column])

                self.cqs[client_id].put(batch)

            time.sleep(0.001)

    def create_cli(self):
        """Register a new reply queue and return the client handle bound to it."""
        reply_queue = multiprocessing.Queue()
        self.cqs.append(reply_queue)
        return Index2DHost.Cli(self.sq, reply_queue, len(self.cqs) - 1)

    # disable pickling: no instance state is exported
    def __getstate__(self):
        return dict()

    def __setstate__(self, d):
        self.__dict__.update(d)

    class Cli():
        """Client handle: sends requests on sq and polls cq for the reply."""
        def __init__(self, sq, cq, cq_id):
            self.sq = sq
            self.cq = cq
            self.cq_id = cq_id

        def multi_get(self, count):
            """Request `count` items and wait until the host replies with the list."""
            self.sq.put((self.cq_id, count))

            while True:
                if not self.cq.empty():
                    return self.cq.get()
                time.sleep(0.001)
157
158
class ListHost():
    """
    Serves a list to clients through queues.

    The constructor starts a daemon thread running host_thread. Clients made
    by create_cli post tuples of the form (client id, command, *args) on one
    shared request queue. Commands understood by the host:

        0               reply with len(list)
        1, idx          reply with list[idx]
        2, keys         reply with [list[k] for k in keys]
        3, idx, item    list.insert(idx, item), no reply
        4, item         list.append(item), no reply
        5, items        list.extend(items), no reply
    """
    def __init__(self, list_):
        self.sq = multiprocessing.Queue()   # shared request queue, clients -> host
        self.cqs = []                       # reply queues; a client's id is its position here
        self.clis = []
        self.m_list = list_
        self.thread = threading.Thread(target=self.host_thread)
        self.thread.daemon = True
        self.thread.start()

    def host_thread(self):
        """Serve list commands forever, polling the request queue every 5 ms."""
        sq = self.sq
        while True:
            while not sq.empty():
                obj = sq.get()
                cq_id, cmd = obj[0], obj[1]

                if cmd == 0:
                    self.cqs[cq_id].put(len(self.m_list))
                elif cmd == 1:
                    self.cqs[cq_id].put(self.m_list[obj[2]])
                elif cmd == 2:
                    self.cqs[cq_id].put([self.m_list[key] for key in obj[2]])
                elif cmd == 3:
                    self.m_list.insert(obj[2], obj[3])
                elif cmd == 4:
                    self.m_list.append(obj[2])
                elif cmd == 5:
                    self.m_list.extend(obj[2])

            time.sleep(0.005)

    def create_cli(self):
        """Register a new reply queue and return the client handle bound to it."""
        cq = multiprocessing.Queue()
        self.cqs.append(cq)
        cq_id = len(self.cqs) - 1
        return ListHost.Cli(self.sq, cq, cq_id)

    def get_list(self):
        """Return the list object this host serves (the one given to the constructor)."""
        # The list is stored as m_list; no attribute named list_ is ever set.
        return self.m_list

    # disable pickling: no instance state is exported
    def __getstate__(self):
        return dict()

    def __setstate__(self, d):
        self.__dict__.update(d)

    class Cli():
        """
        Client handle offering a list-like view of the hosted list.

        Reading methods wait for the host's reply. insert, append and extend
        only post the command and return at once.
        """
        def __init__(self, sq, cq, cq_id):
            self.sq = sq
            self.cq = cq
            self.cq_id = cq_id

        def _wait_reply(self):
            # Poll the reply queue until the host has answered.
            while True:
                if not self.cq.empty():
                    return self.cq.get()
                time.sleep(0.001)

        def __len__(self):
            self.sq.put((self.cq_id, 0))
            return self._wait_reply()

        def __getitem__(self, key):
            self.sq.put((self.cq_id, 1, key))
            return self._wait_reply()

        def multi_get(self, keys):
            """Return the items at every position in `keys`, in the same order."""
            self.sq.put((self.cq_id, 2, keys))
            return self._wait_reply()

        def insert(self, index, item):
            self.sq.put((self.cq_id, 3, index, item))

        def append(self, item):
            self.sq.put((self.cq_id, 4, item))

        def extend(self, items):
            self.sq.put((self.cq_id, 5, items))
248
249
250
251
class DictHost():
    """
    Serves read-only lookups on a dict to a fixed number of users.

    Every user owns one request queue and one reply queue. The constructor
    starts a daemon thread running host_thread and builds one DictHostCli per
    user; get_cli returns the handle of a given user.
    """
    def __init__(self, d, num_users):
        self.sqs = []   # request queues, one per user
        self.cqs = []   # reply queues, one per user
        for _ in range(num_users):
            self.sqs.append(multiprocessing.Queue())
        for _ in range(num_users):
            self.cqs.append(multiprocessing.Queue())

        self.thread = threading.Thread(target=self.host_thread, args=(d,))
        self.thread.daemon = True
        self.thread.start()

        self.clis = [DictHostCli(requests, replies)
                     for requests, replies in zip(self.sqs, self.cqs)]

    def host_thread(self, d):
        """
        Serve lookups forever.

        On each pass at most one request per user is handled, then the thread
        sleeps 5 ms. Command 0 replies with d[key]; command 1 replies with the
        list of keys.
        """
        while True:
            for requests, replies in zip(self.sqs, self.cqs):
                if requests.empty():
                    continue

                message = requests.get()
                opcode = message[0]
                if opcode == 0:
                    replies.put(d[message[1]])
                elif opcode == 1:
                    replies.put(list(d.keys()))

            time.sleep(0.005)

    def get_cli(self, n_user):
        """Return the client handle built for user number n_user."""
        return self.clis[n_user]

    # disable pickling: no instance state is exported
    def __getstate__(self):
        return {}

    def __setstate__(self, d):
        self.__dict__.update(d)
284
285
class DictHostCli():
    """
    Client handle for a hosted dict.

    Requests are tuples put on sq: (0, key) asks for the value stored under
    key, (1,) asks for the list of keys. The answer is read from cq.
    """
    def __init__(self, sq, cq):
        self.sq = sq
        self.cq = cq

    def _ask(self, message):
        # Post the request, then poll the reply queue until the answer arrives.
        self.sq.put(message)
        while self.cq.empty():
            time.sleep(0.001)
        return self.cq.get()

    def __getitem__(self, key):
        return self._ask((0, key))

    def keys(self):
        """Return the keys reply sent by the host."""
        return self._ask((1,))
304
305