# Source: ninjaneural/webui — misc/direct/v1.5.2/directui.py (webui launcher fork)
1
from __future__ import annotations
2
3
import os
4
import sys
5
import time
6
import importlib
7
import signal
8
import re
9
import warnings
10
import json
11
from threading import Thread
12
from typing import Iterable
13
14
from fastapi import FastAPI
15
from fastapi.middleware.cors import CORSMiddleware
16
from fastapi.middleware.gzip import GZipMiddleware
17
from packaging import version
18
19
import logging

# The command line has not been parsed yet (cmd_opts is unavailable at this
# point), so the desired verbosity can only come from the environment.
log_level = os.environ.get("SD_WEBUI_LOG_LEVEL")
if log_level:
    # Unrecognized level names quietly fall back to INFO.
    log_level = getattr(logging, log_level.upper(), None) or logging.INFO
    logging.basicConfig(
        format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=log_level,
    )


def _drop_missing_triton_warnings(record):
    # xformers emits this warning on every start when Triton is absent; hide
    # just that one message and let everything else through.
    return 'A matching Triton is not available' not in record.getMessage()


# Quiet down particularly chatty third-party loggers.
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR)  # sshh...
logging.getLogger("xformers").addFilter(_drop_missing_triton_warnings)
33
34
# ---------------------------------------------------------------------------
# Heavyweight imports.  Each phase is recorded on startup_timer so slow
# startup can be attributed to a specific import.
# ---------------------------------------------------------------------------

from modules import timer

startup_timer = timer.startup_timer
startup_timer.record("launcher")

import torch
import pytorch_lightning  # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
startup_timer.record("import torch")

import gradio  # noqa: F401
startup_timer.record("import gradio")

from modules import paths, timer, import_hook, errors, devices  # noqa: F401
startup_timer.record("setup paths")

import ldm.modules.encoders.modules  # noqa: F401
startup_timer.record("import ldm")

from modules import extra_networks
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock  # noqa: F401

# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
    # Keep the full version string around for display/debugging purposes.
    torch.__long_version__ = torch.__version__
    torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)

from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.img2img

import modules.lowvram
import modules.scripts
import modules.sd_hijack
import modules.sd_hijack_optimizations
import modules.sd_models
import modules.sd_vae
import modules.sd_unet
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
import modules.progress

import modules.ui
from modules import modelloader
from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork

startup_timer.record("other imports")
85
86
87
# Bind address: an explicit --server-name wins; otherwise --listen means
# "all interfaces" and the default (None) lets gradio pick localhost.
server_name = cmd_opts.server_name or ("0.0.0.0" if cmd_opts.listen else None)
91
92
93
def fix_asyncio_event_loop_policy():
    """
    Install an asyncio event loop policy that creates loops on any thread.

    The stock policy only auto-creates event loops in the main thread;
    elsewhere, `asyncio.get_event_loop` (and therefore `.IOLoop.current`)
    fails unless a loop was set explicitly.  The policy installed here lazily
    creates and registers a loop whenever one is requested, matching the
    behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
    """
    import asyncio

    # "Any thread" and "selector" should be orthogonal, but there's no clean
    # interface for composing policies, so pick the right base per platform.
    if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
        base_policy = asyncio.WindowsSelectorEventLoopPolicy  # type: ignore
    else:
        base_policy = asyncio.DefaultEventLoopPolicy

    class AnyThreadEventLoopPolicy(base_policy):  # type: ignore
        """Event loop policy that allows loop creation on any thread.

        Usage::

            asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
        """

        def get_event_loop(self) -> asyncio.AbstractEventLoop:
            try:
                return super().get_event_loop()
            except (RuntimeError, AssertionError):
                # "There is no current event loop in thread %r" — raised as
                # AssertionError in Python 3.4.2 and RuntimeError from 3.4.3
                # onwards.  Create and register a fresh loop instead.
                created = self.new_event_loop()
                self.set_event_loop(created)
                return created

    asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
131
132
133
def check_versions():
    """Print an explanation (via errors.print_error_explanation) when the
    installed torch or xformers is older than the version the program is
    tested with.  Disabled entirely by --skip-version-check."""
    if shared.cmd_opts.skip_version_check:
        return

    expected_torch_version = "2.0.0"
    if version.parse(torch.__version__) < version.parse(expected_torch_version):
        errors.print_error_explanation(f"""
You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
Beware that this will cause a lot of large files to be downloaded, as well as
there are reports of issues with training tab on the latest version.

Use --skip-version-check commandline argument to disable this check.
""".strip())

    if not shared.xformers_available:
        return

    # xformers is optional; only import it once we know it is present.
    import xformers

    expected_xformers_version = "0.0.20"
    if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
        errors.print_error_explanation(f"""
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.

Use --skip-version-check commandline argument to disable this check.
""".strip())
162
163
164
def restore_config_state_file():
    """Re-apply a saved extension-config backup when shared.opts requests one,
    clearing the request first so the restore happens exactly once."""
    state_path = shared.opts.restore_config_state_file
    if state_path == "":
        return

    # Reset and persist the option immediately, making the restore one-shot
    # even if the restore itself fails.
    shared.opts.restore_config_state_file = ""
    shared.opts.save(shared.config_filename)

    if not os.path.isfile(state_path):
        if state_path:
            print(f"!!! Config state backup not found: {state_path}")
        return

    print(f"*** About to restore extension state from file: {state_path}")
    with open(state_path, "r", encoding="utf-8") as f:
        config_states.restore_extension_config(json.load(f))
    startup_timer.record("restore extension config")
180
181
182
def validate_tls_options():
    """Sanity-check the --tls-keyfile/--tls-certfile pair; on unusable values
    disable TLS (set both to None) instead of crashing at launch."""
    keyfile = cmd_opts.tls_keyfile
    certfile = cmd_opts.tls_certfile
    if not (keyfile and certfile):
        return

    try:
        if not os.path.exists(keyfile):
            print("Invalid path to TLS keyfile given")
        if not os.path.exists(certfile):
            print(f"Invalid path to TLS certfile: '{certfile}'")
    except TypeError:
        # os.path.exists raises TypeError for values that aren't path-like;
        # fall back to running without TLS rather than aborting.
        cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
        print("TLS setup invalid, running webui without TLS")
    else:
        print("Running with TLS")
    startup_timer.record("TLS")
197
198
199
def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
    """
    Convert the gradio_auth and gradio_auth_path commandline arguments into
    an iterable of (username, password) tuples.

    --gradio-auth holds comma-separated "user:pass" entries; the file named
    by --gradio-auth-path uses the same format, one or more entries per line.
    Blank entries are skipped.
    """
    def process_credential_line(s) -> tuple[str, ...] | None:
        # "user:pass" -> ("user", "pass"); split only once so passwords may
        # themselves contain ':'.
        s = s.strip()
        if not s:
            return None
        return tuple(s.split(':', 1))

    if cmd_opts.gradio_auth:
        for cred in cmd_opts.gradio_auth.split(','):
            cred = process_credential_line(cred)
            if cred:
                yield cred

    if cmd_opts.gradio_auth_path:
        with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
            # Iterate the file lazily instead of materializing every line
            # up front with readlines().
            for line in file:
                for cred in line.strip().split(','):
                    cred = process_credential_line(cred)
                    if cred:
                        yield cred
223
224
225
def configure_sigint_handler():
    """Make ctrl+c terminate the process immediately instead of waiting for
    anything to finish."""
    if os.environ.get("COVERAGE_RUN"):
        # Under coverage, keep the default handler so the coverage report
        # still gets generated on interrupt.
        return

    def sigint_handler(sig, frame):
        print(f'Interrupted with signal {sig} in {frame}')
        # Hard exit on purpose: skip atexit/finally cleanup entirely.
        os._exit(0)

    signal.signal(signal.SIGINT, sigint_handler)
235
236
237
def configure_opts_onchange():
    """Register shared.opts change handlers for settings that must take effect
    live.  Model/VAE/optimizer reloads are queued (serialized behind the GPU
    lock) and deferred (call=False); tmpdir and theme changes run directly."""
    # (option key, handler, wrap in queue + defer initial call?)
    registrations = [
        ("sd_model_checkpoint", lambda: modules.sd_models.reload_model_weights(), True),
        ("sd_vae", lambda: modules.sd_vae.reload_vae_weights(), True),
        ("sd_vae_as_default", lambda: modules.sd_vae.reload_vae_weights(), True),
        ("temp_dir", ui_tempdir.on_tmpdir_changed, False),
        ("gradio_theme", shared.reload_gradio_theme, False),
        ("cross_attention_optimization", lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model), True),
    ]
    for key, handler, queued in registrations:
        if queued:
            shared.opts.onchange(key, wrap_queued_call(handler), call=False)
        else:
            shared.opts.onchange(key, handler)
    startup_timer.record("opts onchange")
245
246
247
def initialize():
    """One-time process startup: event-loop/TLS/signal plumbing and version
    checks first, then face-restoration model scaffolding, then everything
    shared with UI reloads via initialize_rest().  Order is significant."""
    fix_asyncio_event_loop_policy()
    validate_tls_options()
    configure_sigint_handler()
    check_versions()
    modelloader.cleanup_models()
    configure_opts_onchange()

    modules.sd_models.setup_model()
    startup_timer.record("setup SD model")

    codeformer.setup_model(cmd_opts.codeformer_models_path)
    startup_timer.record("setup codeformer")

    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
    startup_timer.record("setup gfpgan")

    # Everything below this point is also re-run when the UI restarts.
    initialize_rest(reload_script_modules=False)
265
266
267
def initialize_rest(*, reload_script_modules=False):
    """
    Called both from initialize() and when reloading the webui.

    Lists models/extensions/localizations, loads scripts and upscalers, and
    kicks off background threads for model loading.  With
    reload_script_modules=True (UI restart), the modules.ui* Python modules
    are re-imported as well.
    """
    sd_samplers.set_samplers()
    extensions.list_extensions()
    startup_timer.record("list extensions")

    restore_config_state_file()

    # Debug mode: stub out upscalers, load scripts, and skip model setup.
    if cmd_opts.ui_debug_mode:
        shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
        modules.scripts.load_scripts()
        return

    modules.sd_models.list_models()
    startup_timer.record("list SD models")

    localization.list_localizations(cmd_opts.localizations_dir)

    with startup_timer.subcategory("load scripts"):
        modules.scripts.load_scripts()

    if reload_script_modules:
        # Re-import every already-loaded modules.ui* module so UI code
        # changes take effect on restart.
        for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
            importlib.reload(module)
        startup_timer.record("reload script modules")

    modelloader.load_upscalers()
    startup_timer.record("load upscalers")

    modules.sd_vae.refresh_vae_list()
    startup_timer.record("refresh VAE")
    modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
    startup_timer.record("refresh textual inversion templates")

    modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
    modules.sd_hijack.list_optimizers()
    startup_timer.record("scripts list_optimizers")

    modules.sd_unet.list_unets()
    startup_timer.record("scripts list_unets")

    def load_model():
        """
        Accesses shared.sd_model property to load model.
        After it's available, if it has been loaded before this access by some extension,
        its optimization may be None because the list of optimizers has not been filled
        by that time, so we apply optimization again.
        """
        shared.sd_model  # noqa: B018

        if modules.sd_hijack.current_optimizer is None:
            modules.sd_hijack.apply_optimizations()

    # Load the model and warm up the device in the background so the UI can
    # come up without waiting for them.
    Thread(target=load_model).start()

    Thread(target=devices.first_time_calculation).start()

    shared.reload_hypernetworks()
    startup_timer.record("reload hypernetworks")

    ui_extra_networks.initialize()
    ui_extra_networks.register_default_pages()

    extra_networks.initialize()
    extra_networks.register_default_extra_networks()
    startup_timer.record("initialize extra networks")
336
337
338
def setup_middleware(app):
    """Install gzip (responses over 1000 bytes) and our CORS policy on *app*,
    rebuilding its middleware stack in place."""
    app.middleware_stack = None  # reset current middleware to allow modifying user provided list
    app.add_middleware(GZipMiddleware, minimum_size=1000)
    configure_cors_middleware(app)
    app.build_middleware_stack()  # rebuild middleware stack on-the-fly
343
344
345
def configure_cors_middleware(app):
    """Attach CORSMiddleware to *app*: any method/header with credentials,
    origins restricted to the --cors-allow-origins / --cors-allow-origins-regex
    command-line values (no origins allowed when neither flag is set)."""
    cors_options = dict(
        allow_methods=["*"],
        allow_headers=["*"],
        allow_credentials=True,
    )
    if cmd_opts.cors_allow_origins:
        cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
    if cmd_opts.cors_allow_origins_regex:
        cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
    app.add_middleware(CORSMiddleware, **cors_options)
356
357
358
def create_api(app):
    """Mount the JSON API (modules.api.api.Api) on *app*, sharing the global
    queue_lock with the UI so API and UI jobs serialize on the GPU."""
    from modules.api.api import Api
    api = Api(app, queue_lock)
    return api
362
363
364
def api_only():
    """Run the backend API without the gradio UI (--nowebui): build a bare
    FastAPI app, mount the API on it, and block in api.launch()."""
    initialize()

    app = FastAPI()
    setup_middleware(app)
    api = create_api(app)

    # No gradio demo in API-only mode, hence None for the first argument.
    modules.script_callbacks.app_started_callback(None, app)

    print(f"Startup time: {startup_timer.summary()}.")
    api.launch(
        server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
        port=cmd_opts.port if cmd_opts.port else 7861,
        root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else ""
    )
379
380
381
def webui():
    """Launch the full gradio UI.  Loops forever: each iteration builds the UI,
    serves it, then waits for a server command — "restart" rebuilds the UI,
    "stop" (or ctrl+c) shuts down and returns."""
    launch_api = cmd_opts.api
    initialize()

    while 1:
        if shared.opts.clean_temp_dir_at_start:
            ui_tempdir.cleanup_tmpdr()
            startup_timer.record("cleanup temp dir")

        modules.script_callbacks.before_ui_callback()

        shared.demo = modules.ui.create_ui()

        app, local_url, share_url = shared.demo.launch(
            height=3000,
            prevent_thread_lock=True
        )

        # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
        # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
        # running web ui and do whatever the attacker wants, including installing an extension and
        # running its code. We disable this here. Suggested by RyotaK.
        app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']

        # Re-install our own (restrictive) CORS + gzip middleware.
        setup_middleware(app)

        modules.progress.setup_progress_api(app)
        modules.ui.setup_ui_api(app)

        if launch_api:
            create_api(app)

        ui_extra_networks.add_pages_to_demo(app)

        startup_timer.record("add APIs")

        with startup_timer.subcategory("app_started_callback"):
            modules.script_callbacks.app_started_callback(shared.demo, app)

        timer.startup_record = startup_timer.dump()
        print(f"Startup time: {startup_timer.summary()}.")

        # Poll for a "stop"/"restart" command issued through shared.state.
        try:
            while True:
                server_command = shared.state.wait_for_server_command(timeout=5)
                if server_command:
                    if server_command in ("stop", "restart"):
                        break
                    else:
                        print(f"Unknown server command: {server_command}")
        except KeyboardInterrupt:
            print('Caught KeyboardInterrupt, stopping...')
            server_command = "stop"

        if server_command == "stop":
            print("Stopping server...")
            # If we catch a keyboard interrupt, we want to stop the server and exit.
            shared.demo.close()
            break

        # "restart": tear the UI down, notify scripts, and rebuild on the
        # next loop iteration with script modules re-imported.
        print('Restarting UI...')
        shared.demo.close()
        time.sleep(0.5)
        startup_timer.reset()
        modules.script_callbacks.app_reload_callback()
        startup_timer.record("app reload callback")
        modules.script_callbacks.script_unloaded_callback()
        startup_timer.record("scripts unloaded callback")
        initialize_rest(reload_script_modules=True)
450
451
452
if __name__ == "__main__":
    # --nowebui serves only the API; otherwise launch the full gradio UI.
    entry_point = api_only if cmd_opts.nowebui else webui
    entry_point()
457
458