Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/core/leras/layers/Saveable.py
628 views
1
import pickle
2
from pathlib import Path
3
from core import pathex
4
import numpy as np
5
6
from core.leras import nn
7
8
# Module-level shortcut for nn.tf (presumably the TensorFlow module configured
# by leras — confirm in core.leras.nn). Not referenced by the class below.
tf = nn.tf
9
10
class Saveable():
    """
    Base class for leras objects whose TF variables can be pickled to and
    restored from a file.

    Every weight returned by get_weights() must have a TF name whose first
    '/'-separated component equals self.name. The rest of that name is used
    as the key in the saved dict, so files stay independent of the top-level
    scope name.
    """

    def __init__(self, name=None):
        # Top-level scope name that every weight name must start with.
        # It is required by save_weights() and load_weights().
        self.name = name

    def get_weights(self):
        """Return the TF variables that should be initialized/loaded/saved.

        Subclasses override this. The base implementation has no weights.
        """
        return []

    def get_weights_np(self):
        """Return the current values of get_weights() from one nn.tf_sess.run call.

        Subclasses may override this. Returns [] without touching the session
        when there are no weights.
        """
        weights = self.get_weights()
        if not weights:
            return []
        return nn.tf_sess.run(weights)

    def set_weights(self, new_weights):
        """Assign new values to the weights returned by get_weights().

        new_weights: a sequence of array-likes, one per weight, in the same
            order as get_weights(). A value whose shape differs from its
            variable's shape is reshaped to it, so the element count must match.

        Raises ValueError if the number of values differs from the number of
        weights.
        """
        weights = self.get_weights()
        if len(weights) != len(new_weights):
            raise ValueError ('len of lists mismatch')

        tuples = []
        for w, new_w in zip(weights, new_weights):
            new_w = np.asarray(new_w)
            w_shape = tuple(w.shape.as_list())
            if new_w.shape != w_shape:
                new_w = new_w.reshape(w_shape)
            tuples.append((w, new_w))

        nn.batch_set_value(tuples)

    def _weight_sub_name(self, w):
        """Return w.name with its leading "<self.name>/" component removed.

        Returns "" when w.name contains no '/'.

        Raises Exception if the first '/'-separated component of w.name is
        not self.name.
        """
        first, _, rest = w.name.partition('/')
        if first != self.name:
            raise Exception("weight first name != Saveable.name")
        return rest

    def save_weights(self, filename, force_dtype=None):
        """Pickle all weights into `filename` via pathex.write_bytes_safe.

        The file holds a dict that maps each weight's name, with the leading
        "<self.name>/" removed, to a copy of its value. When force_dtype is
        given, each value is cast to that dtype first.

        Raises Exception if self.name is None or if a weight's name does not
        start with self.name.
        """
        if self.name is None:
            raise Exception("name must be defined.")

        weights = self.get_weights()
        # Validate every name before doing any session work.
        keys = [self._weight_sub_name(w) for w in weights]

        # Fetch every value in a single session call rather than one per weight.
        values = nn.tf_sess.run(weights) if weights else []

        d = {}
        for key, w_val in zip(keys, values):
            w_val = w_val.copy()
            if force_dtype is not None:
                w_val = w_val.astype(force_dtype)
            d[key] = w_val

        d_dumped = pickle.dumps (d, 4)
        pathex.write_bytes_safe ( Path(filename), d_dumped )

    def load_weights(self, filename):
        """
        Load weights previously written by save_weights().

        Returns False if the file does not exist, or if assigning the weights
        fails for any reason (e.g. a name or shape mismatch). Returns True
        otherwise.

        Raises Exception if the file exists but self.name is None. Errors from
        reading or unpickling the file are not caught.
        """
        filepath = Path(filename)
        if not filepath.exists():
            return False

        if self.name is None:
            raise Exception("name must be defined.")

        # NOTE(review): pickle.loads can execute arbitrary code; only load
        # weight files from trusted sources.
        d = pickle.loads(filepath.read_bytes())

        weights = self.get_weights()

        try:
            tuples = []
            for w in weights:
                w_val = d.get(self._weight_sub_name(w), None)
                if w_val is None:
                    # Missing from the file: hand the variable's initializer to
                    # batch_set_value (presumably re-initializes it; confirm
                    # in nn.batch_set_value).
                    tuples.append((w, w.initializer))
                else:
                    tuples.append((w, np.reshape(w_val, w.shape.as_list())))

            nn.batch_set_value(tuples)
        except Exception:
            # Best-effort load: report failure to the caller instead of
            # raising. BaseExceptions such as KeyboardInterrupt still propagate.
            return False

        return True

    def init_weights(self):
        """Initialize all weights from get_weights() via nn.init_weights."""
        nn.init_weights(self.get_weights())
108
# Register the class on the nn namespace so it is reachable as nn.Saveable.
nn.Saveable = Saveable
109
110