Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/core/mplib/MPSharedList.py
628 views
1
import multiprocessing
2
import pickle
3
import struct
4
from core.joblib import Subprocessor
5
6
class MPSharedListIndexError(IndexError, ValueError):
    """Raised by ``MPSharedList.__getitem__`` for an out-of-range index.

    Subclasses both ``IndexError`` (the standard sequence-protocol error) and
    ``ValueError`` (what MPSharedList historically raised), so existing
    ``except ValueError`` handlers keep working.
    """


class MPSharedList():
    """Read-only list of constant objects stored pickled in shared memory.

    Every object is pickled (protocol 4) into a ``multiprocessing.RawArray``
    of bytes. Items are unpickled on every access, so the list cannot be
    mutated and each indexing operation returns a fresh copy of the object.

    Per the original author, this avoids the 4GB limit when handing the data
    to subprocesses.
    NOTE(review): RawArray buffers are normally shared with child processes by
    inheritance at process start; confirm this matches how callers pass it.

    Instances can be concatenated with ``+`` or ``sum()``. Concatenation only
    joins the per-segment bookkeeping lists; the shared buffers are not copied.
    """

    def __init__(self, obj_list):
        """Bake ``obj_list`` into a single shared-memory segment.

        Args:
            obj_list: a ``list`` of picklable objects, or ``None`` to build an
                empty shell whose segment lists the caller fills in itself
                (``__add__`` does this).

        Raises:
            ValueError: if ``obj_list`` is neither ``None`` nor a ``list``.
        """
        # The four attributes are parallel lists with one entry per segment.
        if obj_list is None:
            self.obj_counts = None
            self.table_offsets = None
            self.data_offsets = None
            self.sh_bs = None
        else:
            obj_count, table_offset, data_offset, sh_b = MPSharedList.bake_data(obj_list)

            self.obj_counts = [obj_count]
            self.table_offsets = [table_offset]
            self.data_offsets = [data_offset]
            self.sh_bs = [sh_b]

    def __add__(self, o):
        """Concatenate this list with another MPSharedList without copying data.

        An ``int`` operand is ignored and ``self`` is returned unchanged. This
        lets ``sum()`` work with its default start value of ``0``.

        Raises:
            ValueError: for any operand that is not an MPSharedList or an int.
        """
        if isinstance(o, MPSharedList):
            m = MPSharedList(None)
            m.obj_counts = self.obj_counts + o.obj_counts
            m.table_offsets = self.table_offsets + o.table_offsets
            m.data_offsets = self.data_offsets + o.data_offsets
            m.sh_bs = self.sh_bs + o.sh_bs
            return m
        if isinstance(o, int):
            return self
        raise ValueError(f"MPSharedList object of class {o.__class__} is not supported for __add__ operator.")

    def __radd__(self, o):
        """Support ``int + MPSharedList``, i.e. the first step of ``sum()``."""
        return self + o

    def __len__(self):
        """Return the total number of objects across all segments."""
        return sum(self.obj_counts)

    def __getitem__(self, key):
        """Unpickle and return the object at integer index ``key``.

        Negative indices count from the end, as for a regular list.

        Raises:
            MPSharedListIndexError: if ``key`` is out of range. This is both an
                ``IndexError`` and a ``ValueError``.
        """
        total = len(self)
        if key < 0:
            key += total
        if key < 0 or key >= total:
            raise MPSharedListIndexError("out of range")

        # Find the segment that holds `key` and rebase `key` to be relative to
        # it. The range check above guarantees the loop always hits `break`.
        for count, table_offset, data_offset, sh_b in zip(
                self.obj_counts, self.table_offsets, self.data_offsets, self.sh_bs):
            if key < count:
                break
            key -= count

        view = memoryview(sh_b).cast('B')
        # Entries `key` and `key+1` of the offset table bound this object's
        # pickle inside the data area.
        offset_start, offset_end = struct.unpack_from('<QQ', view, table_offset + key * 8)
        # pickle.loads accepts any bytes-like object, so the slice is not
        # copied into an intermediate bytes object first.
        return pickle.loads(view[data_offset + offset_start: data_offset + offset_end])

    def __iter__(self):
        """Yield every object in order, unpickling each one on demand."""
        for i in range(len(self)):
            yield self[i]

    @staticmethod
    def bake_data(obj_list):
        """Pickle ``obj_list`` into a new shared byte buffer.

        Buffer layout (all offsets in bytes):
            [table_offset, data_offset): ``obj_count + 1`` little-endian
                uint64 offsets, relative to ``data_offset``. Object ``i``
                occupies ``[offsets[i], offsets[i+1])``.
            [data_offset, end): the concatenated pickles.

        Returns:
            ``(obj_count, table_offset, data_offset, sh_b)``. For an empty
            list this is ``(0, 0, 0, None)`` and no buffer is allocated.

        Raises:
            ValueError: if ``obj_list`` is not a ``list``.
        """
        if not isinstance(obj_list, list):
            raise ValueError("MPSharedList: obj_list should be list type.")

        obj_count = len(obj_list)
        if obj_count == 0:
            return 0, 0, 0, None

        obj_pickled_ar = [pickle.dumps(o, 4) for o in obj_list]

        # Running start offsets of each pickle, plus the total size at the end.
        offsets = [0]
        for obj_pickled in obj_pickled_ar:
            offsets.append(offsets[-1] + len(obj_pickled))

        table_offset = 0
        table_size = (obj_count + 1) * 8
        data_offset = table_offset + table_size
        data_size = offsets[-1]

        sh_b = multiprocessing.RawArray('B', data_offset + data_size)
        sh_b_view = memoryview(sh_b).cast('B')

        struct.pack_into(f'<{obj_count + 1}Q', sh_b_view, table_offset, *offsets)

        for start, obj_pickled in zip(offsets, obj_pickled_ar):
            pos = data_offset + start
            sh_b_view[pos:pos + len(obj_pickled)] = obj_pickled

        return obj_count, table_offset, data_offset, sh_b
111
112
113