Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ninjaneural
GitHub Repository: ninjaneural/webui
Path: blob/master/misc/directui.py
3275 views
1
import os
2
import sys
3
import time
4
import importlib
5
import signal
6
import re
7
import warnings
8
from fastapi import FastAPI
9
from fastapi.middleware.cors import CORSMiddleware
10
from fastapi.middleware.gzip import GZipMiddleware
11
from packaging import version
12
13
import logging
14
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
15
16
from modules import paths, timer, import_hook, errors
17
18
startup_timer = timer.Timer()
19
20
import torch
21
import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
22
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
23
startup_timer.record("import torch")
24
25
import gradio
26
startup_timer.record("import gradio")
27
28
import ldm.modules.encoders.modules
29
startup_timer.record("import ldm")
30
31
from modules import extra_networks, ui_extra_networks_checkpoints
32
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
33
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
34
35
# Nightly/local PyTorch builds carry ".dev"/"+git" suffixes in their version
# string, which would otherwise cause exceptions in CodeFormer or Safetensors.
# Keep the full string in torch.__long_version__ and expose only the leading
# numeric part as torch.__version__.
if any(tag in torch.__version__ for tag in (".dev", "+git")):
    torch.__long_version__ = torch.__version__
    torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
39
40
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
41
import modules.codeformer_model as codeformer
42
import modules.face_restoration
43
import modules.gfpgan_model as gfpgan
44
import modules.img2img
45
46
import modules.lowvram
47
import modules.scripts
48
import modules.sd_hijack
49
import modules.sd_models
50
import modules.sd_vae
51
import modules.txt2img
52
import modules.script_callbacks
53
import modules.textual_inversion.textual_inversion
54
import modules.progress
55
56
import modules.ui
57
from modules import modelloader
58
from modules.shared import cmd_opts
59
import modules.hypernetworks.hypernetwork
60
61
startup_timer.record("other imports")
62
63
64
# Bind address: an explicit --server-name wins; otherwise 0.0.0.0 when
# --listen is given, else None.
# NOTE(review): nothing visible in this file reads `server_name` (webui() does
# not pass it to demo.launch) — confirm whether it should be used.
server_name = cmd_opts.server_name or ("0.0.0.0" if cmd_opts.listen else None)
68
69
70
def check_versions():
    """Warn when installed torch / xformers are older than the tested versions.

    Returns immediately when --skip-version-check is set. Otherwise it prints an
    explanation through ``errors.print_error_explanation`` for each outdated
    library; it never aborts startup. xformers is only checked when
    ``shared.xformers_available`` is true.
    """
    if shared.cmd_opts.skip_version_check:
        return

    tested_torch = "1.13.1"
    if version.parse(torch.__version__) < version.parse(tested_torch):
        errors.print_error_explanation(f"""
You are running torch {torch.__version__}.
The program is tested to work with torch {tested_torch}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
Beware that this will cause a lot of large files to be downloaded, as well as
there are reports of issues with training tab on the latest version.

Use --skip-version-check commandline argument to disable this check.
""".strip())

    tested_xformers = "0.0.16rc425"
    if not shared.xformers_available:
        return

    # Imported lazily: xformers is optional and only present when available.
    import xformers

    if version.parse(xformers.__version__) < version.parse(tested_xformers):
        errors.print_error_explanation(f"""
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {tested_xformers}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.

Use --skip-version-check commandline argument to disable this check.
""".strip())
99
100
101
def initialize():
    """Run the one-time startup sequence shared by the web UI and API-only modes.

    Steps, each recorded on ``startup_timer``:
    - check library versions and list extensions/localizations;
    - set up SD models, CodeFormer, GFPGAN, upscalers, scripts, VAEs and
      textual-inversion templates;
    - load the SD checkpoint, exiting the process with status 1 on failure;
    - register option-change callbacks, hypernetworks and extra networks;
    - validate the TLS options and install a SIGINT handler.

    With --ui-debug-mode only extensions, localizations, the Lanczos upscaler
    and scripts are set up, and the function returns early. In that case the
    TLS check and the SIGINT handler are skipped too.

    Raises:
        SystemExit: if the Stable Diffusion model fails to load.
    """
    check_versions()

    extensions.list_extensions()
    localization.list_localizations(cmd_opts.localizations_dir)
    startup_timer.record("list extensions")

    if cmd_opts.ui_debug_mode:
        shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
        modules.scripts.load_scripts()
        return

    modelloader.cleanup_models()
    modules.sd_models.setup_model()
    startup_timer.record("list SD models")

    codeformer.setup_model(cmd_opts.codeformer_models_path)
    startup_timer.record("setup codeformer")

    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
    startup_timer.record("setup gfpgan")

    modelloader.list_builtin_upscalers()
    startup_timer.record("list builtin upscalers")

    modules.scripts.load_scripts()
    startup_timer.record("load scripts")

    modelloader.load_upscalers()
    startup_timer.record("load upscalers")

    modules.sd_vae.refresh_vae_list()
    startup_timer.record("refresh VAE")

    modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
    startup_timer.record("refresh textual inversion templates")

    try:
        modules.sd_models.load_model()
    except Exception as e:
        errors.display(e, "loading stable diffusion model")
        print("", file=sys.stderr)
        print("Stable diffusion model failed to load, exiting", file=sys.stderr)
        # sys.exit instead of the site-module exit(), which is intended for the
        # interactive interpreter and is absent when Python runs with -S.
        sys.exit(1)
    startup_timer.record("load SD checkpoint")

    shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title

    # The lambdas resolve the reload functions at call time, not registration time.
    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
    startup_timer.record("opts onchange")

    shared.reload_hypernetworks()
    startup_timer.record("reload hypernets")

    ui_extra_networks.intialize()
    ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
    ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
    ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())

    extra_networks.initialize()
    extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
    startup_timer.record("extra networks")

    keyfile, certfile = cmd_opts.tls_keyfile, cmd_opts.tls_certfile
    # Validate whenever either TLS option was supplied (the key file was
    # previously tested twice, so a lone --tls-certfile went unchecked).
    if keyfile is not None or certfile is not None:
        if keyfile is None or certfile is None:
            # TLS needs both a key and a certificate; with only one of them,
            # clear both and run without TLS.
            cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
            print("TLS setup invalid, running webui without TLS")
        else:
            # Nonexistent paths are only reported here; the options are kept.
            if not os.path.exists(keyfile):
                print("Invalid path to TLS keyfile given")
            if not os.path.exists(certfile):
                print(f"Invalid path to TLS certfile: '{certfile}'")
            print("Running with TLS")
        startup_timer.record("TLS")

    # make the program just exit at ctrl+c without waiting for anything
    def sigint_handler(sig, frame):
        print(f'Interrupted with signal {sig} in {frame}')
        # os._exit skips cleanup handlers and waiting on other threads.
        os._exit(0)

    signal.signal(signal.SIGINT, sigint_handler)
187
188
189
def setup_middleware(app):
    """Install GZip and (optionally) CORS middleware on a FastAPI/Starlette app.

    CORS is added only when --cors-allow-origins (a comma-separated list) and/or
    --cors-allow-origins-regex is set. Either way, all methods and headers are
    allowed, with credentials. The middleware stack is cleared first and
    rebuilt at the end, so this can be applied to an app that has already
    been started.
    """
    # Reset the built stack so the user-provided middleware list may be modified.
    app.middleware_stack = None
    app.add_middleware(GZipMiddleware, minimum_size=1000)

    cors_origin_options = {}
    if cmd_opts.cors_allow_origins:
        cors_origin_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
    if cmd_opts.cors_allow_origins_regex:
        cors_origin_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex

    if cors_origin_options:
        app.add_middleware(
            CORSMiddleware,
            **cors_origin_options,
            allow_methods=['*'],
            allow_credentials=True,
            allow_headers=['*'],
        )

    # Rebuild the middleware stack on the fly.
    app.build_middleware_stack()
199
200
201
def create_api(app):
    """Attach the web UI's REST API to *app* and return the ``Api`` object.

    The API shares ``queue_lock`` with the UI's queued calls.
    """
    # Imported here rather than at module load; presumably to keep the API
    # module off the startup path — NOTE(review): confirm intent.
    from modules.api.api import Api

    return Api(app, queue_lock)
205
206
207
def wait_on_server(demo=None):
    """Block until a UI restart is requested, then close *demo* and return.

    Polls ``shared.state.need_restart`` every 0.5 s. When it becomes true, the
    flag is cleared, *demo* (if not None) is closed with 0.5 s pauses before
    and after, and control returns to the caller so it can rebuild the UI.

    Args:
        demo: object with a ``close()`` method (the gradio app), or None.
    """
    while True:
        time.sleep(0.5)
        if not shared.state.need_restart:
            continue

        shared.state.need_restart = False
        time.sleep(0.5)
        # demo defaults to None; guard so that default does not raise AttributeError.
        if demo is not None:
            demo.close()
        time.sleep(0.5)
        return
216
217
218
def api_only():
    """Start only the HTTP API (no gradio UI) after running ``initialize()``.

    Builds a bare FastAPI app, installs middleware and the API routes, fires
    the ``app_started`` script callbacks (with no gradio demo), and launches.

    The bind address is --server-name when given; otherwise it is 0.0.0.0
    with --listen, or 127.0.0.1. The port is --port, or 7861 when unset.
    """
    initialize()

    app = FastAPI()
    setup_middleware(app)
    api = create_api(app)

    modules.script_callbacks.app_started_callback(None, app)

    print(f"Startup time: {startup_timer.summary()}.")

    # Honour an explicit --server-name; without it the previous listen/loopback
    # choice is unchanged.
    host = cmd_opts.server_name or ("0.0.0.0" if cmd_opts.listen else "127.0.0.1")
    api.launch(server_name=host, port=cmd_opts.port or 7861)
229
230
231
def webui():
    """Run the gradio web UI, rebuilding it in-process whenever a restart is requested.

    After ``initialize()``, each loop iteration does the following:
    - builds and launches the UI;
    - replaces gradio's CORS middleware with the configured one;
    - registers the progress, UI and extra-network routes;
    - blocks in ``wait_on_server()`` until a restart is requested.

    It then reloads samplers, extensions, localizations, scripts, upscalers,
    the ``modules.ui*`` modules, SD models, hypernetworks and extra networks
    before building the UI again. The loop never exits on its own.
    """
    initialize()

    while True:
        if shared.opts.clean_temp_dir_at_start:
            ui_tempdir.cleanup_tmpdr()
            startup_timer.record("cleanup temp dir")

        modules.script_callbacks.before_ui_callback()
        startup_timer.record("scripts before_ui_callback")

        shared.demo = modules.ui.create_ui()
        startup_timer.record("create ui")

        # Only the FastAPI app is needed; the local/share URLs are unused.
        app, _, _ = shared.demo.launch(
            height=3000,
            prevent_thread_lock=True
        )
        # after initial launch, disable --autolaunch for subsequent restarts
        cmd_opts.autolaunch = False

        startup_timer.record("gradio launch")

        # Drop the CORS middleware gradio installed. setup_middleware() re-adds
        # CORS only for origins configured via --cors-allow-origins(-regex).
        app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
        # `initialize_util`, `progress` and `ui` are not bound in this file
        # (NameError at startup); use the local / imported equivalents instead.
        setup_middleware(app)

        modules.progress.setup_progress_api(app)

        # NOTE(review): setup_ui_api exists only in some versions of modules.ui,
        # so it is called only when present.
        setup_ui_api = getattr(modules.ui, "setup_ui_api", None)
        if setup_ui_api is not None:
            setup_ui_api(app)

        ui_extra_networks.add_pages_to_demo(app)

        modules.script_callbacks.app_started_callback(shared.demo, app)
        startup_timer.record("scripts app_started_callback")

        print(f"Startup time: {startup_timer.summary()}.")

        wait_on_server(shared.demo)
        print('Restarting UI...')

        startup_timer.reset()

        sd_samplers.set_samplers()

        modules.script_callbacks.script_unloaded_callback()
        extensions.list_extensions()
        startup_timer.record("list extensions")

        localization.list_localizations(cmd_opts.localizations_dir)

        modelloader.forbid_loaded_nonbuiltin_upscalers()
        modules.scripts.reload_scripts()
        startup_timer.record("load scripts")

        modules.script_callbacks.model_loaded_callback(shared.sd_model)
        startup_timer.record("model loaded callback")

        modelloader.load_upscalers()
        startup_timer.record("load upscalers")

        # Snapshot the matching modules first: reloading re-executes module code,
        # which may import new modules and mutate sys.modules mid-iteration.
        ui_modules = [mod for name, mod in sys.modules.items() if name.startswith("modules.ui")]
        for mod in ui_modules:
            importlib.reload(mod)
        startup_timer.record("reload script modules")

        modules.sd_models.list_models()
        startup_timer.record("list SD models")

        shared.reload_hypernetworks()
        startup_timer.record("reload hypernetworks")

        ui_extra_networks.intialize()
        ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
        ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
        ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())

        extra_networks.initialize()
        extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
        startup_timer.record("initialize extra networks")
308
309
310
if __name__ == "__main__":
    # --nowebui serves only the HTTP API; otherwise run the full gradio UI.
    entry_point = api_only if cmd_opts.nowebui else webui
    entry_point()
315
316