# GitHub Repository: iperov/deepfacelab
# Path: blob/master/core/leras/models/ModelBase.py
import types
import numpy as np
from core.interact import interact as io
from core.leras import nn
tf = nn.tf

class ModelBase(nn.Saveable):
    def __init__(self, *args, name=None, **kwargs):
        super().__init__(name=name)
        self.layers = []
        self.layers_by_name = {}
        self.built = False
        self.args = args
        self.kwargs = kwargs
        self.run_placeholders = None

    def _build_sub(self, layer, name):
        # Recursively register an attribute that holds a layer/model
        # (or a list/dict of them), name it, and build its weights.
        if isinstance(layer, list):
            for i, sublayer in enumerate(layer):
                self._build_sub(sublayer, f"{name}_{i}")
        elif isinstance(layer, dict):
            for subname, sublayer in layer.items():
                self._build_sub(sublayer, f"{name}_{subname}")
        elif isinstance(layer, (nn.LayerBase, ModelBase)):
            if layer.name is None:
                layer.name = name

            if isinstance(layer, nn.LayerBase):
                with tf.variable_scope(layer.name):
                    layer.build_weights()
            elif isinstance(layer, ModelBase):
                layer.build()

            self.layers.append(layer)
            self.layers_by_name[layer.name] = layer

    def xor_list(self, lst1, lst2):
        return [ value for value in lst1+lst2 if (value not in lst1) or (value not in lst2) ]

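    # Note: xor_list() is effectively an order-preserving symmetric difference;
    # build() uses it to find attribute names created since the previous pass.
    # Illustrative values (not from this file): xor_list(['a','b'], ['b','c']) -> ['a','c'].
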
    def build(self):
        with tf.variable_scope(self.name):
            current_vars = []
            generator = None
            while True:
                if generator is None:
                    generator = self.on_build(*self.args, **self.kwargs)
                    if not isinstance(generator, types.GeneratorType):
                        generator = None

                if generator is not None:
                    try:
                        next(generator)
                    except StopIteration:
                        generator = None

                # Any attributes created by on_build() since the last pass
                # are new layers/models: build and register them now.
                v = vars(self)
                new_vars = self.xor_list(current_vars, list(v.keys()))

                for name in new_vars:
                    self._build_sub(v[name], name)

                current_vars += new_vars

                if generator is None:
                    break

        self.built = True

    #override
    def get_weights(self):
        if not self.built:
            self.build()

        weights = []
        for layer in self.layers:
            weights += layer.get_weights()
        return weights

    def get_layer_by_name(self, name):
        return self.layers_by_name.get(name, None)

    def get_layers(self):
        if not self.built:
            self.build()
        layers = []
        for layer in self.layers:
            if isinstance(layer, nn.LayerBase):
                layers.append(layer)
            else:
                layers += layer.get_layers()
        return layers

    #override
    def on_build(self, *args, **kwargs):
        """
        Initialize the model's layers here.

        This method may be written as a generator: 'yield' when the build is
        not yet finished, so that layers/models created so far are built
        before the remaining (dependent) ones are initialized.
        """
        pass

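    # A minimal sketch of the staged-build pattern on_build() enables.
    # Hypothetical subclass; the Dense layer and its (in_ch, out_ch) arguments
    # are illustrative assumptions about core.leras, not part of this file:
    #
    #   class Encoder(ModelBase):
    #       def on_build(self, in_ch, out_ch):
    #           self.dense1 = nn.Dense(in_ch, 256)
    #           yield                      # dense1 is built before we continue
    #           self.dense2 = nn.Dense(256, out_ch)
    #
    # Each yield returns control to build(), which scans vars(self) for newly
    # created layers/models and builds them before on_build() resumes.
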
    #override
    def forward(self, *args, **kwargs):
        # chain layers/models over the input tensors here and return the result
        pass

    def __call__(self, *args, **kwargs):
        if not self.built:
            self.build()

        return self.forward(*args, **kwargs)

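    # Note: calling the model (model(x)) builds it on first use and returns the
    # symbolic result of forward(); to execute the graph with concrete input
    # arrays, use build_for_run() and run() below.
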
    # def compute_output_shape(self, shapes):
    #     if not self.built:
    #         self.build()

    #     not_list = False
    #     if not isinstance(shapes, list):
    #         not_list = True
    #         shapes = [shapes]

    #     with tf.device('/CPU:0'):
    #         # CPU tensors will not impact performance, only cause slight RAM "leakage"
    #         phs = []
    #         for dtype, sh in shapes:
    #             phs += [ tf.placeholder(dtype, sh) ]

    #         result = self.__call__(phs[0] if not_list else phs)

    #     if not isinstance(result, list):
    #         result = [result]

    #     result_shapes = []

    #     for t in result:
    #         result_shapes += [ t.shape.as_list() ]

    #     return result_shapes[0] if not_list else result_shapes

    def build_for_run(self, shapes_list):
        if not isinstance(shapes_list, list):
            raise ValueError("shapes_list must be a list of (dtype, shape) tuples.")

        self.run_placeholders = []
        for dtype, sh in shapes_list:
            self.run_placeholders.append( tf.placeholder(dtype, sh) )

        self.run_output = self.__call__(self.run_placeholders)

    def run(self, inputs):
        if self.run_placeholders is None:
            raise Exception("The model was not built for run. Call build_for_run() first.")

        if len(inputs) != len(self.run_placeholders):
            raise ValueError("len(inputs) != len(self.run_placeholders)")

        feed_dict = {}
        for ph, inp in zip(self.run_placeholders, inputs):
            feed_dict[ph] = inp

        return nn.tf_sess.run(self.run_output, feed_dict=feed_dict)

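    # Typical inference usage, as a sketch (the shape, dtype and input array
    # below are illustrative assumptions, not taken from this file):
    #   model.build_for_run([ (tf.float32, (None, 64, 64, 3)) ])
    #   out = model.run([ np.zeros((1, 64, 64, 3), np.float32) ])
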
    def summary(self):
        layers = self.get_layers()
        layers_names = []
        layers_params = []

        max_len_str = 0
        max_len_param_str = 0
        delim_str = "-"

        total_params = 0

        #Get layer names and the max name length for the delimiter
        for l in layers:
            if len(str(l)) > max_len_str:
                max_len_str = len(str(l))
            layers_names += [str(l).capitalize()]

        #Get the parameter count of each layer
        layers_params = [ int(np.sum([ np.prod(w.shape) for w in l.get_weights() ])) for l in layers ]
        total_params = np.sum(layers_params)

        #Get the max parameter-count length for the delimiter
        for p in layers_params:
            if len(str(p)) > max_len_param_str:
                max_len_param_str = len(str(p))

        #Set delimiter
        for _ in range(max_len_str + max_len_param_str + 3):
            delim_str += "-"

        output = "\n" + delim_str + "\n"

        #Format model name row
        model_name_str = "| " + self.name.capitalize()
        len_model_name_str = len(model_name_str)
        for i in range(len(delim_str) - len_model_name_str):
            model_name_str += " " if i != (len(delim_str) - len_model_name_str - 2) else " |"

        output += model_name_str + "\n"
        output += delim_str + "\n"

        #Format layers table
        for i in range(len(layers_names)):
            output += delim_str + "\n"

            l_name = layers_names[i]
            l_param = str(layers_params[i])
            l_param_str = ""
            if len(l_name) <= max_len_str:
                for _ in range(max_len_str - len(l_name)):
                    l_name += " "

            if len(l_param) <= max_len_param_str:
                for _ in range(max_len_param_str - len(l_param)):
                    l_param_str += " "

            l_param_str += l_param

            output += "| " + l_name + "|" + l_param_str + "| \n"

        output += delim_str + "\n"

        #Format total parameter count row
        total_params_str = "| Total params count: " + str(total_params)
        len_total_params_str = len(total_params_str)
        for i in range(len(delim_str) - len_total_params_str):
            total_params_str += " " if i != (len(delim_str) - len_total_params_str - 2) else " |"

        output += total_params_str + "\n"
        output += delim_str + "\n"

        io.log_info(output)

nn.ModelBase = ModelBase
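
# End-to-end usage, as a minimal sketch. The SimpleModel subclass, the Dense
# layer and its (in_ch, out_ch) arguments are illustrative assumptions about
# core.leras, not part of this file:
#
#   class SimpleModel(nn.ModelBase):
#       def on_build(self):
#           self.dense1 = nn.Dense(128, 64)
#
#       def forward(self, x):
#           return self.dense1(x)
#
#   model = SimpleModel(name='simple_model')
#   model.build_for_run([ (tf.float32, (None, 128)) ])
#   out = model.run([ np.zeros((1, 128), np.float32) ])
#   model.summary()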