Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
iperov
GitHub Repository: iperov/deepfacelab
Path: blob/master/core/leras/nn.py
628 views
1
"""
2
Leras.
3
4
like lighter keras.
5
This is my lightweight neural network library written from scratch
6
based on pure tensorflow without keras.
7
8
Provides:
9
+ full freedom of tensorflow operations without keras model's restrictions
10
+ easy model operations like in PyTorch, but in graph mode (no eager execution)
11
+ convenient and understandable logic
12
13
Reasons why we cannot import tensorflow or any tensorflow.sub modules right here:
14
1) program is changing env variables based on DeviceConfig before import tensorflow
15
2) multiprocesses will import tensorflow every spawn
16
17
NCHW speed up training for 10-20%.
18
"""
19
20
import os
21
import sys
22
import warnings
23
# Silence every FutureWarning for the whole process. This runs at module import
# time, i.e. before nn.initialize() imports tensorflow.
warnings.simplefilter(action='ignore', category=FutureWarning)
24
from pathlib import Path
25
import numpy as np
26
from core.interact import interact as io
27
from .device import Devices
28
29
30
class nn():
    """
    Process-wide Leras namespace: TensorFlow module, shared session and layout settings.

    All state lives in class attributes and every method is a @staticmethod,
    so the class is used as a namespace rather than instantiated.
    """

    # DeviceConfig chosen for this process (see initialize / setCurrentDeviceConfig).
    current_DeviceConfig = None

    # TensorFlow API actually used (tensorflow.compat.v1 on TF 2.x); set by initialize().
    tf = None
    # Shared tf.Session and the ConfigProto it is (re)created with.
    tf_sess = None
    tf_sess_config = None
    # '/CPU:0' when no devices are configured, else '/<tf_dev_type of first device>:0'.
    tf_default_device_name = None

    # "NHWC" or "NCHW", plus the channel / spatial axis indices derived from it.
    data_format = None
    conv2d_ch_axis = None
    conv2d_spatial_axes = None

    # Default float dtype for layers whose dtype is None (initialize() sets
    # tf.float32 or tf.float16).
    floatx = None

    @staticmethod
    def initialize(device_config=None, floatx="float32", data_format="NHWC"):
        """
        Import and configure TensorFlow, create the shared session and set defaults.

        TensorFlow is imported lazily here (not at module import) because
        environment variables such as CUDA_CACHE_PATH must be set before the
        import. The import/configuration step runs only once per process: on
        later calls `device_config` is ignored, a session is created only if
        none exists, and floatx/data_format are re-applied.

        Args:
            device_config: nn.DeviceConfig to use. Defaults to the current one,
                falling back to nn.DeviceConfig.BestGPU().
            floatx: "float32" or "float16".
            data_format: "NHWC" or "NCHW".

        Raises:
            ValueError: on an unsupported floatx or data_format. Both are
                checked before any environment change, import or session work.
        """
        if floatx not in ("float32", "float16"):
            raise ValueError(f"unsupported floatx {floatx}")
        if data_format not in ("NHWC", "NCHW"):
            raise ValueError(f"unsupported data_format {data_format}")

        if nn.tf is None:
            if device_config is None:
                device_config = nn.getCurrentDeviceConfig()
            nn.setCurrentDeviceConfig(device_config)

            # Manipulate environment variables before import tensorflow
            first_run = False
            if len(device_config.devices) != 0 and sys.platform.startswith('win'):
                first_run = nn._configure_windows_compute_cache(device_config.devices)
            if first_run:
                io.log_info("Caching GPU kernels...")

            # nn.tf must be set before the framework modules below are imported.
            nn.tf = tf = nn._import_tensorflow()

            # Initialize framework
            import core.leras.ops
            import core.leras.layers
            import core.leras.initializers
            import core.leras.optimizers
            import core.leras.models
            import core.leras.archis

            nn.tf_sess_config, nn.tf_default_device_name = nn._build_session_config(tf, device_config.devices)

        if nn.tf_sess is None:
            nn.tf_sess = nn.tf.Session(config=nn.tf_sess_config)

        nn.set_floatx(nn.tf.float32 if floatx == "float32" else nn.tf.float16)
        nn.set_data_format(data_format)

    @staticmethod
    def _configure_windows_compute_cache(devices):
        """
        Point CUDA_CACHE_PATH at a per-device-set folder under %APPDATA%/NVIDIA.

        Returns True when that folder did not exist before this call, False
        when it already existed or when APPDATA is not set (nothing is changed
        in that case).
        """
        appdata = os.environ.get('APPDATA')
        if appdata is None:
            return False

        names = [device.name for device in devices]
        if all(name == names[0] for name in names):
            # A set of identical devices shares one cache folder.
            names = names[:1]
        devices_str = ''.join('_' + name.replace(' ', '_') for name in names)

        compute_cache_path = Path(appdata) / 'NVIDIA' / ('ComputeCache' + devices_str)
        first_run = not compute_cache_path.exists()
        compute_cache_path.mkdir(parents=True, exist_ok=True)
        os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path)
        return first_run

    @staticmethod
    def _import_tensorflow():
        """Import tensorflow and return the v1-style API (compat.v1 on TF 2.x)."""
        import logging
        import tensorflow

        tf_version = tensorflow.version.VERSION
        if tf_version.startswith('v'):
            tf_version = tf_version[1:]
        is_tf2 = tf_version.startswith('2')
        tf = tensorflow.compat.v1 if is_tf2 else tensorflow

        # Disable tensorflow warnings
        logging.getLogger('tensorflow').setLevel(logging.ERROR)

        if is_tf2:
            tf.disable_v2_behavior()
        return tf

    @staticmethod
    def _build_session_config(tf, devices):
        """Return (ConfigProto, default device name) for the given device list."""
        if len(devices) == 0:
            config = tf.ConfigProto(device_count={'GPU': 0})
            default_device_name = '/CPU:0'
        else:
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = ','.join(str(device.index) for device in devices)
            default_device_name = f'/{devices[0].tf_dev_type}:0'

        config.gpu_options.force_gpu_compatible = True
        config.gpu_options.allow_growth = True
        return config, default_device_name

    @staticmethod
    def initialize_main_env():
        """Delegate main-process environment setup to Devices.initialize_main_env()."""
        Devices.initialize_main_env()

    @staticmethod
    def set_floatx(tf_dtype):
        """
        set default float type for all layers when dtype is None for them
        """
        nn.floatx = tf_dtype

    @staticmethod
    def set_data_format(data_format):
        """
        Set the global tensor layout and the matching conv2d axis indices.

        Raises:
            ValueError: if data_format is neither "NHWC" nor "NCHW".
        """
        if data_format == "NHWC":
            ch_axis, spatial_axes = 3, [1, 2]
        elif data_format == "NCHW":
            ch_axis, spatial_axes = 1, [2, 3]
        else:
            raise ValueError(f"unsupported data_format {data_format}")

        nn.data_format = data_format
        nn.conv2d_ch_axis = ch_axis
        nn.conv2d_spatial_axes = spatial_axes

    @staticmethod
    def get4Dshape ( w, h, c ):
        """
        returns 4D shape based on current data_format
        (batch dimension is None; anything but "NHWC" yields the NCHW order)
        """
        if nn.data_format == "NHWC":
            return (None, h, w, c)
        return (None, c, h, w)

    @staticmethod
    def to_data_format( x, to_data_format, from_data_format):
        """
        Transpose a 4D numpy array between NHWC and NCHW layouts.

        Returns x itself when both formats are equal.

        Raises:
            ValueError: on an unsupported to_data_format.
        """
        if to_data_format == from_data_format:
            return x

        if to_data_format == "NHWC":
            return np.transpose(x, (0, 2, 3, 1))
        elif to_data_format == "NCHW":
            return np.transpose(x, (0, 3, 1, 2))
        else:
            raise ValueError(f"unsupported to_data_format {to_data_format}")

    @staticmethod
    def getCurrentDeviceConfig():
        """Return the current DeviceConfig, defaulting to nn.DeviceConfig.BestGPU()."""
        if nn.current_DeviceConfig is None:
            # DeviceConfig is nested inside nn; a bare `DeviceConfig` is not
            # visible from a method body and would raise NameError.
            nn.current_DeviceConfig = nn.DeviceConfig.BestGPU()
        return nn.current_DeviceConfig

    @staticmethod
    def setCurrentDeviceConfig(device_config):
        """Replace the process-wide current DeviceConfig."""
        nn.current_DeviceConfig = device_config

    @staticmethod
    def reset_session():
        """Clear the default graph and replace an open session with a fresh one."""
        if nn.tf is not None:
            if nn.tf_sess is not None:
                nn.tf.reset_default_graph()
                nn.tf_sess.close()
                nn.tf_sess = nn.tf.Session(config=nn.tf_sess_config)

    @staticmethod
    def close_session():
        """Clear the default graph and close the session (initialize() can recreate it)."""
        if nn.tf_sess is not None:
            nn.tf.reset_default_graph()
            nn.tf_sess.close()
            nn.tf_sess = None

    @staticmethod
    def ask_choose_device_idxs(choose_only_one=False, allow_cpu=True, suggest_best_multi_gpu=False, suggest_all_gpu=False):
        """
        Interactively ask the user which device index(es) to use.

        Args:
            choose_only_one: accept exactly one index (disables both suggest_* flags).
            allow_cpu: also accept the answer "cpu" (case-insensitive).
            suggest_best_multi_gpu: default answer = devices equal to the best one.
            suggest_all_gpu: default answer = every device.

        Returns:
            List of chosen device indexes, each one an existing device index.
            Empty when the user picks CPU or when no devices are found.
        """
        devices = Devices.getDevices()
        if len(devices) == 0:
            return []

        all_devices_indexes = [device.index for device in devices]

        if choose_only_one:
            suggest_best_multi_gpu = False
            suggest_all_gpu = False

        if suggest_all_gpu:
            best_device_indexes = all_devices_indexes
        elif suggest_best_multi_gpu:
            best_device_indexes = [device.index for device in devices.get_equal_devices(devices.get_best_device())]
        else:
            best_device_indexes = [devices.get_best_device().index]
        best_device_indexes = ",".join(str(x) for x in best_device_indexes)

        io.log_info ("")
        if choose_only_one:
            io.log_info ("Choose one GPU idx.")
        else:
            io.log_info ("Choose one or several GPU idxs (separated by comma).")
        io.log_info ("")

        if allow_cpu:
            io.log_info ("[CPU] : CPU")
        for device in devices:
            io.log_info (f" [{device.index}] : {device.name}")

        io.log_info ("")

        prompt = "Which GPU index to choose?" if choose_only_one else "Which GPU indexes to choose?"
        while True:
            choosed_idxs = io.input_str(prompt, best_device_indexes)

            if allow_cpu and choosed_idxs.lower() == "cpu":
                choosed_idxs = []
                break

            try:
                choosed_idxs = [int(x) for x in choosed_idxs.split(',')]
            except ValueError:
                # Not a comma-separated list of integers: ask again.
                continue

            if choose_only_one and len(choosed_idxs) != 1:
                continue
            # Every chosen index must name an existing device (single-choice included).
            if all(idx in all_devices_indexes for idx in choosed_idxs):
                break
        io.log_info ("")

        return choosed_idxs

    class DeviceConfig():
        """Devices TensorFlow should use; an empty device list means CPU only."""

        @staticmethod
        def ask_choose_device(*args, **kwargs):
            """Prompt via nn.ask_choose_device_idxs(*args, **kwargs) and wrap the result."""
            return nn.DeviceConfig.GPUIndexes(nn.ask_choose_device_idxs(*args, **kwargs))

        def __init__ (self, devices=None):
            """
            Args:
                devices: a Devices instance, or anything Devices() accepts
                    (None/empty -> CPU only).
            """
            devices = devices or []

            if not isinstance(devices, Devices):
                devices = Devices(devices)

            self.devices = devices
            self.cpu_only = len(devices) == 0

        @staticmethod
        def BestGPU():
            """Config with the best available device, or CPU when none exist."""
            devices = Devices.getDevices()
            if len(devices) == 0:
                return nn.DeviceConfig.CPU()

            return nn.DeviceConfig([devices.get_best_device()])

        @staticmethod
        def WorstGPU():
            """Config with the worst available device, or CPU when none exist."""
            devices = Devices.getDevices()
            if len(devices) == 0:
                return nn.DeviceConfig.CPU()

            return nn.DeviceConfig([devices.get_worst_device()])

        @staticmethod
        def GPUIndexes(indexes):
            """Config with the devices at the given indexes; an empty list means CPU."""
            if len(indexes) != 0:
                devices = Devices.getDevices().get_devices_from_index_list(indexes)
            else:
                devices = []

            return nn.DeviceConfig(devices)

        @staticmethod
        def CPU():
            """CPU-only config (no devices)."""
            return nn.DeviceConfig([])
302