Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ninjaneural
GitHub Repository: ninjaneural/webui
Path: blob/master/misc/direct/v1.3.2/directui.py
3275 views
1
from __future__ import annotations
2
3
import os
4
import sys
5
import time
6
import importlib
7
import signal
8
import re
9
import warnings
10
import json
11
from threading import Thread
12
from typing import Iterable
13
14
from fastapi import FastAPI, Response
15
from fastapi.middleware.cors import CORSMiddleware
16
from fastapi.middleware.gzip import GZipMiddleware
17
from packaging import version
18
19
import logging
20
21
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
22
23
from modules import paths, timer, import_hook, errors # noqa: F401
24
25
startup_timer = timer.Timer()
26
27
import torch
28
import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
29
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
30
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
31
32
33
startup_timer.record("import torch")
34
35
import gradio
36
startup_timer.record("import gradio")
37
38
import ldm.modules.encoders.modules # noqa: F401
39
startup_timer.record("import ldm")
40
41
from modules import extra_networks
42
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
43
44
# Nightly/local PyTorch builds carry a ".dev" or "+git" marker in their version
# string, which causes exceptions with CodeFormer or Safetensors. Keep the full
# string in torch.__long_version__ and expose only the leading numeric part.
if any(marker in torch.__version__ for marker in (".dev", "+git")):
    torch.__long_version__ = torch.__version__
    torch.__version__ = re.search(r'[\d.]+[\d]', torch.__long_version__).group(0)
48
49
from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
50
import modules.codeformer_model as codeformer
51
import modules.face_restoration
52
import modules.gfpgan_model as gfpgan
53
import modules.img2img
54
55
import modules.lowvram
56
import modules.scripts
57
import modules.sd_hijack
58
import modules.sd_hijack_optimizations
59
import modules.sd_models
60
import modules.sd_vae
61
import modules.txt2img
62
import modules.script_callbacks
63
import modules.textual_inversion.textual_inversion
64
import modules.progress
65
66
import modules.ui
67
from modules import modelloader
68
from modules.shared import cmd_opts
69
import modules.hypernetworks.hypernetwork
70
71
startup_timer.record("other imports")
72
73
74
# An explicit --server-name wins; otherwise bind to all interfaces when
# --listen is set, and leave the choice to the server (None) when it is not.
server_name = cmd_opts.server_name or ("0.0.0.0" if cmd_opts.listen else None)
78
79
80
def fix_asyncio_event_loop_policy():
    """
    Install an asyncio policy that creates event loops on demand in any thread.

    The default `asyncio` event loop policy only creates event loops
    automatically in the main thread. Other threads must create their loop
    explicitly, otherwise `asyncio.get_event_loop` (and therefore
    `.IOLoop.current`) fails. The policy installed here creates and registers
    a new loop whenever the lookup fails, matching the behavior of Tornado
    versions prior to 5.0 (or 5.0 on Python 2).
    """

    import asyncio

    # "Any thread" and "selector" should be orthogonal, but there is no clean
    # interface for composing policies, so pick the right base class instead.
    on_windows = sys.platform == "win32"
    if on_windows and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
        base_policy = asyncio.WindowsSelectorEventLoopPolicy  # type: ignore
    else:
        base_policy = asyncio.DefaultEventLoopPolicy

    class AnyThreadEventLoopPolicy(base_policy):  # type: ignore
        """Event loop policy that allows loop creation on any thread.

        Usage::

            asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
        """

        def get_event_loop(self) -> asyncio.AbstractEventLoop:
            try:
                existing_loop = super().get_event_loop()
            except (RuntimeError, AssertionError):
                # "There is no current event loop in thread %r".
                # This was an AssertionError in python 3.4.2 (which ships with
                # debian jessie) and changed to a RuntimeError in 3.4.3.
                fresh_loop = self.new_event_loop()
                self.set_event_loop(fresh_loop)
                return fresh_loop
            return existing_loop

    asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
118
119
120
def check_versions():
    """
    Warn when the installed torch or xformers is older than the tested version.

    Does nothing when `shared.cmd_opts.skip_version_check` is set. The
    xformers check only runs when `shared.xformers_available` is true.
    Findings are reported through `errors.print_error_explanation`; nothing is
    raised or returned.
    """
    if shared.cmd_opts.skip_version_check:
        return

    expected_torch_version = "2.0.0"
    torch_is_outdated = version.parse(torch.__version__) < version.parse(expected_torch_version)

    if torch_is_outdated:
        errors.print_error_explanation(f"""
You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
Beware that this will cause a lot of large files to be downloaded, as well as
there are reports of issues with training tab on the latest version.

Use --skip-version-check commandline argument to disable this check.
        """.strip())

    if not shared.xformers_available:
        return

    # Imported lazily: xformers is only touched when it is reported available.
    import xformers

    expected_xformers_version = "0.0.17"
    xformers_is_outdated = version.parse(xformers.__version__) < version.parse(expected_xformers_version)

    if xformers_is_outdated:
        errors.print_error_explanation(f"""
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.

Use --skip-version-check commandline argument to disable this check.
        """.strip())
149
150
151
def restore_config_state_file():
    """
    Restore extension state from the backup named in the options, if any.

    Reads `shared.opts.restore_config_state_file`; an empty string means there
    is nothing to do. Otherwise the option is reset to "" and the options are
    saved before the backup is processed. If the file exists its JSON content
    is handed to `config_states.restore_extension_config`; if not, a message
    is printed.
    """
    backup_path = shared.opts.restore_config_state_file
    if backup_path == "":
        return

    # Clear and persist the request before acting on it.
    shared.opts.restore_config_state_file = ""
    shared.opts.save(shared.config_filename)

    if not os.path.isfile(backup_path):
        if backup_path:
            print(f"!!! Config state backup not found: {backup_path}")
        return

    print(f"*** About to restore extension state from file: {backup_path}")
    with open(backup_path, "r", encoding="utf-8") as backup_file:
        saved_state = json.load(backup_file)
        config_states.restore_extension_config(saved_state)
    startup_timer.record("restore extension config")
167
168
169
def validate_tls_options():
    """
    Report on the TLS key/certificate paths given on the command line.

    Returns immediately unless both `cmd_opts.tls_keyfile` and
    `cmd_opts.tls_certfile` are set. A path that does not exist only produces
    a printed message; "Running with TLS" is still printed afterwards. Only a
    `TypeError` from the existence checks clears both options, so that the
    web UI runs without TLS.
    """
    keyfile = cmd_opts.tls_keyfile
    certfile = cmd_opts.tls_certfile
    if not keyfile or not certfile:
        return

    try:
        keyfile_exists = os.path.exists(keyfile)
        if not keyfile_exists:
            print("Invalid path to TLS keyfile given")
        certfile_exists = os.path.exists(certfile)
        if not certfile_exists:
            print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
    except TypeError:
        cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
        print("TLS setup invalid, running webui without TLS")
    else:
        print("Running with TLS")
    startup_timer.record("TLS")
184
185
186
def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
    """
    Convert the gradio_auth and gradio_auth_path commandline arguments into
    an iterable of (username, password) tuples.

    Credentials from `cmd_opts.gradio_auth` are yielded first, followed by
    those read from the file at `cmd_opts.gradio_auth_path`. In both sources
    entries are separated by commas; the file may also spread them over
    several lines. Blank entries are skipped. Each entry is split on its first
    ':' only, so an entry without a ':' yields a one-element tuple.

    Raises:
        OSError: if `cmd_opts.gradio_auth_path` is set but cannot be opened.
    """
    def process_credential_line(s) -> tuple[str, ...] | None:
        # Split "user:password" on the first ':' so passwords may contain ':'.
        s = s.strip()
        if not s:
            return None
        return tuple(s.split(':', 1))

    def credentials_in(text):
        # Yield every non-blank credential from a comma-separated string.
        for entry in text.split(','):
            cred = process_credential_line(entry)
            if cred:
                yield cred

    if cmd_opts.gradio_auth:
        yield from credentials_in(cmd_opts.gradio_auth)

    if cmd_opts.gradio_auth_path:
        with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
            # Iterate the file lazily instead of loading every line up front
            # with readlines(); each entry is stripped individually, so the
            # trailing newline needs no separate handling.
            for line in file:
                yield from credentials_in(line)
210
211
212
def configure_sigint_handler():
    """
    Make Ctrl+C terminate the process immediately, without waiting for anything.

    The handler is not installed when the COVERAGE_RUN environment variable is
    set, because an immediate exit would prevent the coverage report from
    being generated.
    """
    if os.environ.get("COVERAGE_RUN"):
        return

    def exit_immediately(signum, stack_frame):
        print(f'Interrupted with signal {signum} in {stack_frame}')
        # os._exit terminates the process right here.
        os._exit(0)

    signal.signal(signal.SIGINT, exit_immediately)
222
223
224
def configure_opts_onchange():
    """
    Register the callbacks that react to changes of individual options.

    Checkpoint, VAE and cross-attention changes are routed through
    `wrap_queued_call` and registered with `call=False` (presumably "do not
    invoke at registration time" -- confirm against `shared.opts.onchange`).
    The temp-dir and theme callbacks are registered with the default `call`.
    """
    def reload_checkpoint():
        return modules.sd_models.reload_model_weights()

    def reload_vae():
        return modules.sd_vae.reload_vae_weights()

    def redo_attention_hijack():
        return modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)

    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(reload_checkpoint), call=False)
    shared.opts.onchange("sd_vae", wrap_queued_call(reload_vae), call=False)
    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(reload_vae), call=False)
    shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
    shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
    shared.opts.onchange("cross_attention_optimization", wrap_queued_call(redo_attention_hijack), call=False)
    startup_timer.record("opts onchange")
232
233
234
def initialize():
    """
    Run the one-time startup sequence.

    Prepares the process (event loop policy, TLS options, SIGINT handler,
    version check, model cleanup, option change callbacks), sets up the
    Stable Diffusion, CodeFormer and GFPGAN models, and then hands over to
    `initialize_rest` for the steps that are repeated on every UI reload.
    """
    # Process-level preparation.
    fix_asyncio_event_loop_policy()
    validate_tls_options()
    configure_sigint_handler()
    check_versions()
    modelloader.cleanup_models()
    configure_opts_onchange()

    # Model setup, timed step by step.
    modules.sd_models.setup_model()
    startup_timer.record("setup SD model")

    codeformer.setup_model(cmd_opts.codeformer_models_path)
    startup_timer.record("setup codeformer")

    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
    startup_timer.record("setup gfpgan")

    # First start: nothing has to be re-imported.
    initialize_rest(reload_script_modules=False)
252
253
254
def initialize_rest(*, reload_script_modules=False):
    """
    Perform the initialization steps shared by startup and UI reloads.

    Called both from initialize() and when reloading the webui.

    When `reload_script_modules` is true, every already imported module whose
    name starts with "modules.ui" is reloaded after the scripts are loaded.
    With `cmd_opts.ui_debug_mode` set, only the Lanczos upscalers and the
    scripts are loaded and everything else is skipped.
    """
    sd_samplers.set_samplers()
    extensions.list_extensions()
    startup_timer.record("list extensions")

    restore_config_state_file()

    if cmd_opts.ui_debug_mode:
        shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
        modules.scripts.load_scripts()
        return

    modules.sd_models.list_models()
    startup_timer.record("list SD models")

    localization.list_localizations(cmd_opts.localizations_dir)

    modules.scripts.load_scripts()
    startup_timer.record("load scripts")

    if reload_script_modules:
        # Take a snapshot first so sys.modules is not iterated while reloading.
        ui_modules = [
            loaded for module_name, loaded in sys.modules.items()
            if module_name.startswith("modules.ui")
        ]
        for ui_module in ui_modules:
            importlib.reload(ui_module)
        startup_timer.record("reload script modules")

    modelloader.load_upscalers()
    startup_timer.record("load upscalers")

    modules.sd_vae.refresh_vae_list()
    startup_timer.record("refresh VAE")
    modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
    startup_timer.record("refresh textual inversion templates")

    modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
    modules.sd_hijack.list_optimizers()
    startup_timer.record("scripts list_optimizers")

    def load_model():
        """
        Accesses shared.sd_model property to load model.
        After it's available, if it has been loaded before this access by some extension,
        its optimization may be None because the list of optimizers has not been filled
        by that time, so we apply optimization again.
        """

        shared.sd_model  # noqa: B018

        if modules.sd_hijack.current_optimizer is None:
            modules.sd_hijack.apply_optimizations()

    # The model is loaded on a separate thread while the remaining steps run.
    model_loader = Thread(target=load_model)
    model_loader.start()

    shared.reload_hypernetworks()
    startup_timer.record("reload hypernetworks")

    ui_extra_networks.initialize()
    ui_extra_networks.register_default_pages()

    extra_networks.initialize()
    extra_networks.register_default_extra_networks()
    startup_timer.record("initialize extra networks")
318
319
320
def setup_middleware(app):
    """
    Add gzip and CORS middleware to `app` and rebuild its middleware stack.

    The existing stack is reset first so the middleware list can still be
    modified, and it is rebuilt once both middlewares have been added.
    """
    # Reset the current middleware stack to allow modifying the user provided list.
    app.middleware_stack = None

    app.add_middleware(GZipMiddleware, minimum_size=1000)
    configure_cors_middleware(app)

    # Rebuild the middleware stack on the fly.
    app.build_middleware_stack()
325
326
327
def configure_cors_middleware(app):
    """
    Add a CORS middleware to `app` based on the commandline options.

    All methods and headers are allowed and credentials are enabled. Allowed
    origins come from `cmd_opts.cors_allow_origins` (comma-separated) and/or
    `cmd_opts.cors_allow_origins_regex`; when neither is set no origin option
    is passed to the middleware.
    """
    cors_options = dict(
        allow_methods=["*"],
        allow_headers=["*"],
        allow_credentials=True,
    )

    origin_list = cmd_opts.cors_allow_origins
    if origin_list:
        cors_options["allow_origins"] = origin_list.split(',')

    origin_pattern = cmd_opts.cors_allow_origins_regex
    if origin_pattern:
        cors_options["allow_origin_regex"] = origin_pattern

    app.add_middleware(CORSMiddleware, **cors_options)
338
339
340
def create_api(app):
    """Create the `Api` object for `app`, sharing the global `queue_lock`, and return it."""
    # Imported here rather than at module level, as in the original code.
    from modules.api.api import Api

    return Api(app, queue_lock)
344
345
346
def api_only():
    """
    Start the API without the gradio UI.

    Initializes the program, builds a FastAPI app with the usual middleware,
    notifies the app-started callbacks (with `None` in place of the gradio
    demo) and launches the API. It binds to 0.0.0.0 when `cmd_opts.listen` is
    set and to 127.0.0.1 otherwise, on `cmd_opts.port` or 7861 if no port was
    given.
    """
    initialize()

    app = FastAPI()
    setup_middleware(app)
    api = create_api(app)

    modules.script_callbacks.app_started_callback(None, app)

    print(f"Startup time: {startup_timer.summary()}.")

    bind_host = "0.0.0.0" if cmd_opts.listen else "127.0.0.1"
    bind_port = cmd_opts.port or 7861
    api.launch(server_name=bind_host, port=bind_port)
357
358
359
def stop_route(request):
    """
    Request a server stop and acknowledge it.

    Sets `shared.state.server_command` to "stop" and returns a `Response`
    with the body "Stopping.". The `request` argument is not used.
    """
    shared.state.server_command = "stop"
    acknowledgement = Response("Stopping.")
    return acknowledgement
362
363
364
def webui():
    """
    Run the gradio web UI until a stop is requested.

    After the one-time `initialize()`, each iteration of the outer loop builds
    the UI, launches it without blocking, replaces gradio's CORS middleware
    with the program's own middleware, registers the extra API routes and
    then waits for a server command. "stop" (or a KeyboardInterrupt) closes
    the UI and leaves the loop; "restart" closes the UI, runs the
    reload/unload callbacks and `initialize_rest(reload_script_modules=True)`,
    and starts the next iteration.
    """
    initialize()

    while True:
        modules.script_callbacks.before_ui_callback()

        shared.demo = modules.ui.create_ui()

        app, local_url, share_url = shared.demo.launch(
            height=3000,
            prevent_thread_lock=True
        )

        # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
        # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
        # running web ui and do whatever the attacker wants, including installing an extension and
        # running its code. We disable this here. Suggested by RyotaK.
        app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']

        setup_middleware(app)

        modules.progress.setup_progress_api(app)
        modules.ui.setup_ui_api(app)

        ui_extra_networks.add_pages_to_demo(app)

        modules.script_callbacks.app_started_callback(shared.demo, app)
        startup_timer.record("scripts app_started_callback")

        print(f"Startup time: {startup_timer.summary()}.")

        # Poll for a server command; the launch above does not block.
        try:
            while True:
                server_command = shared.state.wait_for_server_command(timeout=5)
                if server_command:
                    if server_command in ("stop", "restart"):
                        break
                    else:
                        print(f"Unknown server command: {server_command}")
        except KeyboardInterrupt:
            print('Caught KeyboardInterrupt, stopping...')
            server_command = "stop"

        if server_command == "stop":
            print("Stopping server...")
            # If we catch a keyboard interrupt, we want to stop the server and exit.
            shared.demo.close()
            break

        print('Restarting UI...')
        shared.demo.close()
        time.sleep(0.5)
        startup_timer.reset()
        modules.script_callbacks.app_reload_callback()
        startup_timer.record("app reload callback")
        modules.script_callbacks.script_unloaded_callback()
        startup_timer.record("scripts unloaded callback")
        initialize_rest(reload_script_modules=True)

        # NOTE(review): initialize_rest() already registers and lists the
        # optimizers; this repeats it. Kept to preserve behavior -- confirm
        # whether the second registration is intended.
        modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
        modules.sd_hijack.list_optimizers()
        startup_timer.record("scripts list_optimizers")
426
427
428
if __name__ == "__main__":
    # --nowebui serves the API alone; otherwise start the full gradio UI.
    (api_only if cmd_opts.nowebui else webui)()
433
434