Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ninjaneural
GitHub Repository: ninjaneural/webui
Path: blob/master/misc/direct/v1.4.1/directui.py
3275 views
1
from __future__ import annotations
2
3
import os
4
import sys
5
import time
6
import importlib
7
import signal
8
import re
9
import warnings
10
import json
11
from threading import Thread
12
from typing import Iterable
13
14
from fastapi import FastAPI, Response
15
from fastapi.middleware.cors import CORSMiddleware
16
from fastapi.middleware.gzip import GZipMiddleware
17
from packaging import version
18
19
import logging
20
21
# Drop xformers log records whose message contains
# 'A matching Triton is not available'; all other xformers records pass through.
logging.getLogger("xformers").addFilter(
    lambda log_record: 'A matching Triton is not available' not in log_record.getMessage()
)
22
23
from modules import paths, timer, import_hook, errors, devices # noqa: F401
24
25
startup_timer = timer.startup_timer
26
27
import torch
28
import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
29
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
30
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
31
32
33
startup_timer.record("import torch")
34
35
import gradio
36
startup_timer.record("import gradio")
37
38
import ldm.modules.encoders.modules # noqa: F401
39
startup_timer.record("import ldm")
40
41
from modules import extra_networks
42
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
43
44
# Nightly/local PyTorch builds carry ".dev" or "+git" in their version string,
# which causes exceptions in CodeFormer and Safetensors. Keep the full string in
# torch.__long_version__ and reduce torch.__version__ to its leading dotted
# numeric part.
if any(marker in torch.__version__ for marker in (".dev", "+git")):
    torch.__long_version__ = torch.__version__
    torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
48
49
from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
50
import modules.codeformer_model as codeformer
51
import modules.face_restoration
52
import modules.gfpgan_model as gfpgan
53
import modules.img2img
54
55
import modules.lowvram
56
import modules.scripts
57
import modules.sd_hijack
58
import modules.sd_hijack_optimizations
59
import modules.sd_models
60
import modules.sd_vae
61
import modules.sd_unet
62
import modules.txt2img
63
import modules.script_callbacks
64
import modules.textual_inversion.textual_inversion
65
import modules.progress
66
67
import modules.ui
68
from modules import modelloader
69
from modules.shared import cmd_opts
70
import modules.hypernetworks.hypernetwork
71
72
startup_timer.record("other imports")
73
74
75
# An explicit --server-name wins; otherwise use "0.0.0.0" when --listen is set,
# and None when it is not.
server_name = cmd_opts.server_name or ("0.0.0.0" if cmd_opts.listen else None)
79
80
81
def fix_asyncio_event_loop_policy():
    """
    Install an asyncio event loop policy that creates event loops on demand in
    any thread.

    The default policy only creates a loop automatically in the main thread, so
    `asyncio.get_event_loop` (and therefore Tornado's `IOLoop.current`) fails in
    other threads. The policy installed here creates and sets a fresh loop in
    that case, matching Tornado's behavior before 5.0. On Windows the
    selector-based policy is used as the base class when it is available.
    """

    import asyncio

    # "Any thread" and "selector" ought to be independent choices, but asyncio
    # offers no clean way to compose policies, so choose the base class up front.
    use_selector = sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy")
    base_policy = asyncio.WindowsSelectorEventLoopPolicy if use_selector else asyncio.DefaultEventLoopPolicy  # type: ignore

    class AnyThreadEventLoopPolicy(base_policy):  # type: ignore
        """Event loop policy whose get_event_loop() creates a loop when the thread has none.

        Usage::

            asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
        """

        def get_event_loop(self) -> asyncio.AbstractEventLoop:
            try:
                return super().get_event_loop()
            except (RuntimeError, AssertionError):
                # RuntimeError: "There is no current event loop in thread %r".
                # Python 3.4.2 (shipped with Debian jessie) raised AssertionError
                # here instead; 3.4.3 switched to RuntimeError.
                pass
            new_loop = self.new_event_loop()
            self.set_event_loop(new_loop)
            return new_loop

    asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
119
120
121
def check_versions():
    """
    Report (via errors.print_error_explanation) when the installed torch, or
    xformers if shared.xformers_available, is older than the version this
    program is tested with.

    Does nothing when shared.cmd_opts.skip_version_check is set.
    """
    if shared.cmd_opts.skip_version_check:
        return

    def is_outdated(installed: str, tested: str) -> bool:
        return version.parse(installed) < version.parse(tested)

    expected_torch_version = "2.0.0"
    if is_outdated(torch.__version__, expected_torch_version):
        errors.print_error_explanation(f"""
You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
Beware that this will cause a lot of large files to be downloaded, as well as
there are reports of issues with training tab on the latest version.

Use --skip-version-check commandline argument to disable this check.
        """.strip())

    expected_xformers_version = "0.0.20"
    if not shared.xformers_available:
        return

    import xformers

    if is_outdated(xformers.__version__, expected_xformers_version):
        errors.print_error_explanation(f"""
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.

Use --skip-version-check commandline argument to disable this check.
        """.strip())
150
151
152
def restore_config_state_file():
    """
    Restore extension state from the backup file named by
    shared.opts.restore_config_state_file, if one is set.

    The option is cleared and the options file saved before the restore is
    attempted. A path that is not an existing file is reported and skipped.
    """
    state_path = shared.opts.restore_config_state_file
    if state_path == "":
        return

    # Clear and persist the option first, so the request is consumed even if
    # the restore below fails.
    shared.opts.restore_config_state_file = ""
    shared.opts.save(shared.config_filename)

    if not os.path.isfile(state_path):
        if state_path:
            print(f"!!! Config state backup not found: {state_path}")
        return

    print(f"*** About to restore extension state from file: {state_path}")
    with open(state_path, "r", encoding="utf-8") as state_file:
        saved_state = json.load(state_file)
    config_states.restore_extension_config(saved_state)
    startup_timer.record("restore extension config")
168
169
170
def validate_tls_options():
    """
    Sanity-check the TLS key/certificate options when both are given.

    Each configured path that does not exist is reported, including the
    offending path. "Running with TLS" is printed only when both files exist.
    If the paths cannot be checked at all (os.path.exists raises TypeError),
    both options are cleared and the webui runs without TLS.

    Missing files are only reported; the options are deliberately left set,
    rather than silently downgrading to plain HTTP.
    """
    if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
        return

    try:
        keyfile_ok = os.path.exists(cmd_opts.tls_keyfile)
        certfile_ok = os.path.exists(cmd_opts.tls_certfile)
    except TypeError:
        cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
        print("TLS setup invalid, running webui without TLS")
    else:
        if not keyfile_ok:
            print(f"Invalid path to TLS keyfile: '{cmd_opts.tls_keyfile}'")
        if not certfile_ok:
            print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
        # Previously this was printed unconditionally, even right after an
        # "Invalid path" message.
        if keyfile_ok and certfile_ok:
            print("Running with TLS")
    startup_timer.record("TLS")
185
186
187
def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
    """
    Yield credential tuples parsed from cmd_opts.gradio_auth and from the file
    named by cmd_opts.gradio_auth_path.

    Both sources hold comma-separated "username:password" entries; the file may
    spread them over several lines. Each entry is whitespace-stripped and blank
    entries are skipped. An entry is split on the first ':' only, so a password
    may itself contain ':'. An entry with no ':' yields a 1-tuple.
    """
    def parse_entry(raw):
        entry = raw.strip()
        return tuple(entry.split(':', 1)) if entry else None

    def raw_entries():
        if cmd_opts.gradio_auth:
            yield from cmd_opts.gradio_auth.split(',')
        if cmd_opts.gradio_auth_path:
            with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as auth_file:
                for line in auth_file.readlines():
                    yield from line.strip().split(',')

    for raw in raw_entries():
        parsed = parse_entry(raw)
        if parsed:
            yield parsed
211
212
213
def configure_sigint_handler():
    """
    Make Ctrl+C (SIGINT) exit the process immediately via os._exit(0), without
    waiting for anything.

    Skipped when the COVERAGE_RUN environment variable is set, because an
    immediate exit would prevent the coverage report from being written.
    """
    if os.environ.get("COVERAGE_RUN"):
        return

    def sigint_handler(sig, frame):
        print(f'Interrupted with signal {sig} in {frame}')
        os._exit(0)

    signal.signal(signal.SIGINT, sigint_handler)
223
224
225
def configure_opts_onchange():
    """
    Register callbacks on shared.opts for options that need work when changed:
    the SD checkpoint, the VAE settings, the temp dir, the Gradio theme and the
    cross-attention optimization.
    """
    register = shared.opts.onchange

    # The lambdas defer attribute lookups (modules.*, shared.sd_model) until the
    # callback actually fires; the heavy reloads are wrapped with wrap_queued_call.
    register("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
    register("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    register("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    register("temp_dir", ui_tempdir.on_tmpdir_changed)
    register("gradio_theme", shared.reload_gradio_theme)
    register("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)

    startup_timer.record("opts onchange")
233
234
235
def initialize():
    """
    One-time startup. This runs the environment and option setup steps, then
    sets up the SD, CodeFormer and GFPGAN models, then runs initialize_rest()
    without reloading UI modules.
    """
    for setup_step in (
        fix_asyncio_event_loop_policy,
        validate_tls_options,
        configure_sigint_handler,
        check_versions,
        modelloader.cleanup_models,
        configure_opts_onchange,
    ):
        setup_step()

    modules.sd_models.setup_model()
    startup_timer.record("setup SD model")

    codeformer.setup_model(cmd_opts.codeformer_models_path)
    startup_timer.record("setup codeformer")

    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
    startup_timer.record("setup gfpgan")

    initialize_rest(reload_script_modules=False)
253
254
255
def initialize_rest(*, reload_script_modules=False):
    """
    Initialization shared by first startup and UI restarts.

    initialize() calls this with reload_script_modules=False. webui() calls it
    on restart with reload_script_modules=True, which additionally re-imports
    every already-loaded module whose name starts with "modules.ui".

    With cmd_opts.ui_debug_mode, only samplers, extensions, the config-state
    restore, the Lanczos upscaler and scripts are set up; model listing and
    everything after it are skipped.
    """
    sd_samplers.set_samplers()
    extensions.list_extensions()
    startup_timer.record("list extensions")

    restore_config_state_file()

    if cmd_opts.ui_debug_mode:
        shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
        modules.scripts.load_scripts()
        return

    modules.sd_models.list_models()
    startup_timer.record("list SD models")

    localization.list_localizations(cmd_opts.localizations_dir)

    with startup_timer.subcategory("load scripts"):
        modules.scripts.load_scripts()

    if reload_script_modules:
        # Snapshot the matching modules first: reloading may import new modules
        # and mutate sys.modules while it is being iterated.
        ui_modules = [mod for mod_name, mod in sys.modules.items() if mod_name.startswith("modules.ui")]
        for ui_module in ui_modules:
            importlib.reload(ui_module)
        startup_timer.record("reload script modules")

    modelloader.load_upscalers()
    startup_timer.record("load upscalers")

    modules.sd_vae.refresh_vae_list()
    startup_timer.record("refresh VAE")
    modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
    startup_timer.record("refresh textual inversion templates")

    modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
    modules.sd_hijack.list_optimizers()
    startup_timer.record("scripts list_optimizers")

    modules.sd_unet.list_unets()
    startup_timer.record("scripts list_unets")

    def load_model():
        """
        Load the model by accessing the shared.sd_model property, then make sure
        an optimization is applied.

        If an extension triggered the load before the optimizer list was filled,
        modules.sd_hijack.current_optimizer can still be None here, so the
        optimizations are applied again.
        """
        shared.sd_model  # noqa: B018 -- the property access itself loads the model
        if modules.sd_hijack.current_optimizer is None:
            modules.sd_hijack.apply_optimizations()

    # Both run in background threads; startup continues without joining them.
    Thread(target=load_model).start()
    Thread(target=devices.first_time_calculation).start()

    shared.reload_hypernetworks()
    startup_timer.record("reload hypernetworks")

    ui_extra_networks.initialize()
    ui_extra_networks.register_default_pages()

    extra_networks.initialize()
    extra_networks.register_default_extra_networks()
    startup_timer.record("initialize extra networks")
324
325
326
def setup_middleware(app):
    """
    Add gzip compression (responses of at least 1000 bytes) and the configured
    CORS middleware to *app*.

    app.middleware_stack is reset to None first, so middleware can be added to
    the user-provided list even on an app whose stack was already built.
    """
    app.middleware_stack = None
    app.add_middleware(GZipMiddleware, minimum_size=1000)
    configure_cors_middleware(app)
    # NOTE(review): the stack returned here is not stored on app, so
    # app.middleware_stack stays None. Presumably the framework rebuilds it
    # lazily on the first request; confirm for the pinned Starlette/FastAPI version.
    app.build_middleware_stack()
331
332
333
def configure_cors_middleware(app):
    """
    Add a CORSMiddleware to *app*, configured from the command line.

    All methods and headers are allowed and credentials are permitted. Allowed
    origins come from cmd_opts.cors_allow_origins (a comma-separated list)
    and/or cmd_opts.cors_allow_origins_regex. Options that are not set are not
    passed, so CORSMiddleware's own defaults apply to them.
    """
    cors_options = {
        "allow_methods": ["*"],
        "allow_headers": ["*"],
        "allow_credentials": True,
    }
    if cmd_opts.cors_allow_origins:
        # Tolerate "a, b" and trailing commas. An origin with surrounding
        # whitespace, or an empty entry, can never equal a browser's Origin
        # header, so strip each entry and drop the empty ones.
        cors_options["allow_origins"] = [
            origin.strip()
            for origin in cmd_opts.cors_allow_origins.split(',')
            if origin.strip()
        ]
    if cmd_opts.cors_allow_origins_regex:
        cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
    app.add_middleware(CORSMiddleware, **cors_options)
344
345
346
def create_api(app):
    """Attach the HTTP API to *app*, passing the shared queue_lock, and return the Api object."""
    # Deferred import: modules.api.api is loaded only when an API is actually created.
    from modules.api.api import Api

    return Api(app, queue_lock)
350
351
352
def api_only():
    """
    Initialize everything and serve only the HTTP API, with no Gradio UI.

    Binds to cmd_opts.server_name when given. Otherwise it binds to "0.0.0.0"
    with --listen, or "127.0.0.1" without it. The port is cmd_opts.port when
    set, otherwise 7861.
    """
    initialize()

    app = FastAPI()
    setup_middleware(app)
    api = create_api(app)

    modules.script_callbacks.app_started_callback(None, app)

    print(f"Startup time: {startup_timer.summary()}.")
    # Honour an explicit --server-name, matching the module-level server_name
    # selection; fall back to the previous listen/localhost choice.
    host = cmd_opts.server_name or ("0.0.0.0" if cmd_opts.listen else "127.0.0.1")
    api.launch(server_name=host, port=cmd_opts.port or 7861)
363
364
365
def stop_route(request):
    """
    Request a server stop and acknowledge it.

    Sets shared.state.server_command to "stop". This is presumably what
    webui()'s shared.state.wait_for_server_command() call picks up. Returns a
    plain "Stopping." response. *request* is unused.
    """
    shared.state.server_command = "stop"
    acknowledgement = Response("Stopping.")
    return acknowledgement
368
369
370
def _wait_for_server_command():
    """
    Block until a "stop" or "restart" server command arrives, and return it.

    Polls shared.state.wait_for_server_command(timeout=5). Empty results are
    ignored; any other command is reported and ignored. Ctrl+C while waiting is
    treated as "stop".
    """
    try:
        while True:
            server_command = shared.state.wait_for_server_command(timeout=5)
            if not server_command:
                continue
            if server_command in ("stop", "restart"):
                return server_command
            print(f"Unknown server command: {server_command}")
    except KeyboardInterrupt:
        print('Caught KeyboardInterrupt, stopping...')
        return "stop"


def webui():
    """
    Run the Gradio UI until a "stop" command or Ctrl+C is received.

    initialize() runs once. Each pass of the loop then:
    - builds and launches the UI;
    - replaces Gradio's CORS middleware with the configured one;
    - adds the progress, UI, optional HTTP API and extra-network routes;
    - fires the app_started callbacks;
    - waits for a server command.

    On "stop" the UI is closed and the function returns. On "restart" the UI is
    closed, the reload/unload callbacks run, and
    initialize_rest(reload_script_modules=True) runs before the next pass.
    """
    launch_api = cmd_opts.api
    initialize()

    while True:
        if shared.opts.clean_temp_dir_at_start:
            ui_tempdir.cleanup_tmpdr()
            startup_timer.record("cleanup temp dir")

        modules.script_callbacks.before_ui_callback()

        shared.demo = modules.ui.create_ui()

        # prevent_thread_lock=True: launch() returns here instead of blocking,
        # so the setup below can continue on this thread.
        app, _local_url, _share_url = shared.demo.launch(
            height=3000,
            prevent_thread_lock=True
        )

        # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
        # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
        # running web ui and do whatever the attacker wants, including installing an extension and
        # running its code. We disable this here. Suggested by RyotaK.
        app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']

        setup_middleware(app)

        modules.progress.setup_progress_api(app)
        modules.ui.setup_ui_api(app)

        if launch_api:
            create_api(app)

        ui_extra_networks.add_pages_to_demo(app)

        startup_timer.record("add APIs")

        with startup_timer.subcategory("app_started_callback"):
            modules.script_callbacks.app_started_callback(shared.demo, app)

        timer.startup_record = startup_timer.dump()
        print(f"Startup time: {startup_timer.summary()}.")

        if cmd_opts.subpath:
            # NOTE(review): `redirector` is not served anywhere in this function,
            # so this mount seems to matter only for whatever mount_gradio_app does
            # to shared.demo itself. Confirm whether the subpath feature still works.
            redirector = FastAPI()
            gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")

        server_command = _wait_for_server_command()

        if server_command == "stop":
            # Reached on an explicit "stop" command or on Ctrl+C while waiting.
            print("Stopping server...")
            shared.demo.close()
            break

        print('Restarting UI...')
        shared.demo.close()
        time.sleep(0.5)
        startup_timer.reset()
        modules.script_callbacks.app_reload_callback()
        startup_timer.record("app reload callback")
        modules.script_callbacks.script_unloaded_callback()
        startup_timer.record("scripts unloaded callback")
        initialize_rest(reload_script_modules=True)
444
445
446
if __name__ == "__main__":
    # Default to the full Gradio UI; --nowebui serves only the HTTP API.
    if not cmd_opts.nowebui:
        webui()
    else:
        api_only()
451
452