Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/core/leras/initializers/CA.py
628 views
1
import multiprocessing
2
from core.joblib import Subprocessor
3
import numpy as np
4
5
class CAInitializerSubprocessor(Subprocessor):
    """Generate Convolution Aware initial weights for a list of kernels.

    The constructor receives ``data_list``, a sequence of ``(shape, dtype)``
    pairs.  Work items are handed out one index at a time through
    ``get_data`` and the generated arrays are stored in ``self.result`` at
    the index of the request that produced them.

    NOTE(review): scheduling, worker start-up and the exact callback
    contract live in the ``Subprocessor`` base class, which is not part of
    this file -- confirm details there.
    """

    @staticmethod
    def generate(shape, dtype=np.float32, eps_std=0.05):
        """
        Super fast implementation of Convolution Aware Initialization for 4D shapes
        Convolution Aware Initialization https://arxiv.org/abs/1702.06295

        Args:
            shape: sequence of four ints, unpacked as
                ``(row, column, stack_size, filters_size)``.
            dtype: dtype the returned array is cast to.
            eps_std: standard deviation of the Gaussian noise added to the
                kernels (and of the basis itself when the FFT basis has a
                single element).

        Returns:
            Array of shape ``(row, column, stack_size, filters_size)`` whose
            overall variance has been rescaled to ``2 / fan_in``.

        Raises:
            ValueError: if ``shape`` does not have exactly four entries.
        """
        if len(shape) != 4:
            raise ValueError("only shape with rank 4 supported.")

        row, column, stack_size, filters_size = shape

        fan_in = stack_size * (row * column)

        kernel_shape = (row, column)

        # Shape of the real-FFT representation of a single kernel.
        kernel_fft_shape = np.fft.rfft2(np.zeros(kernel_shape)).shape

        basis_size = np.prod(kernel_fft_shape)
        if basis_size == 1:
            # A single frequency bin leaves nothing to orthogonalize:
            # draw the coefficients directly.
            x = np.random.normal(0.0, eps_std, (filters_size, stack_size, basis_size))
        else:
            # Number of (basis_size x basis_size) blocks needed so that the
            # stacked rows cover at least stack_size vectors.
            nbb = stack_size // basis_size + 1
            x = np.random.normal(0.0, 1.0, (filters_size, nbb, basis_size, basis_size))
            # Symmetrize each block (diagonal kept, off-diagonals summed).
            x = x + np.transpose(x, (0, 1, 3, 2)) * (1 - np.eye(basis_size))
            # Only the left singular vectors are used as the basis.
            u, _, _ = np.linalg.svd(x)
            x = np.transpose(u, (0, 1, 3, 2))
            x = np.reshape(x, (filters_size, -1, basis_size))
            # Drop the surplus vectors produced by rounding nbb up.
            x = x[:, :stack_size, :]

        x = np.reshape(x, (filters_size, stack_size) + kernel_fft_shape)

        # Back to the spatial domain, plus a small Gaussian perturbation.
        x = np.fft.irfft2(x, kernel_shape) \
            + np.random.normal(0, eps_std, (filters_size, stack_size) + kernel_shape)

        # Rescale so that the variance of the whole tensor is 2 / fan_in.
        x = x * np.sqrt((2 / fan_in) / np.var(x))
        # (filters, stack, row, column) -> (row, column, stack, filters)
        x = np.transpose(x, (2, 3, 1, 0))
        return x.astype(dtype)

    class Cli(Subprocessor.Cli):
        """Worker-side handler: builds the weights for one work item."""

        #override
        def process_data(self, data):
            """Generate weights for one ``(idx, shape, dtype)`` work item.

            Returns ``(idx, weights)`` so the host can store the array at
            the index of the originating request.
            """
            idx, shape, dtype = data
            weights = CAInitializerSubprocessor.generate(shape, dtype)
            return idx, weights

    #override
    def __init__(self, data_list):
        """Store the requests and prepare the pending-index and result lists.

        Args:
            data_list: sequence of ``(shape, dtype)`` pairs, one per weight
                tensor to generate.
        """
        self.data_list = data_list
        # Indexes into data_list that still have to be handed out.
        self.data_list_idxs = list(range(len(data_list)))
        # One slot per request, filled in by on_result.
        self.result = [None] * len(data_list)
        super().__init__('CAInitializerSubprocessor', CAInitializerSubprocessor.Cli)

    #override
    def process_info_generator(self):
        """Yield one ``('CPU<i>', {}, {})`` entry per worker slot.

        The number of entries is the smaller of the CPU count and the
        number of requests.
        """
        for i in range(min(multiprocessing.cpu_count(), len(self.data_list))):
            yield 'CPU%d' % (i), {}, {}

    #override
    def get_data(self, host_dict):
        """Return the next ``(idx, shape, dtype)`` work item, or ``None`` when
        no pending indexes remain."""
        if self.data_list_idxs:
            idx = self.data_list_idxs.pop(0)
            shape, dtype = self.data_list[idx]
            return idx, shape, dtype
        return None

    #override
    def on_data_return(self, host_dict, data):
        """Put a returned work item back at the front of the pending queue.

        ``data`` is expected to be the ``(idx, shape, dtype)`` tuple produced
        by ``get_data`` (the same tuple ``Cli.process_data`` unpacks).  Only
        the index is requeued: ``data_list_idxs`` holds plain indexes, and
        storing the whole tuple would make the next ``get_data`` call fail
        when it indexes ``self.data_list`` with it.
        NOTE(review): confirm against Subprocessor that the returned item is
        exactly what get_data handed out.
        """
        idx = data[0]
        self.data_list_idxs.insert(0, idx)

    #override
    def on_result(self, host_dict, data, result):
        """Store the weights of a finished item at its request index."""
        idx, weights = result
        self.result[idx] = weights

    #override
    def get_result(self):
        """Return the list of generated weights, ordered like ``data_list``."""
        return self.result
83
84