143 lines
4.9 KiB
Python
143 lines
4.9 KiB
Python
import importlib
|
|
import logging
|
|
import os
|
|
from inspect import isclass
|
|
from pkgutil import iter_modules
|
|
from typing import final
|
|
|
|
from flask import Blueprint, Flask, request
|
|
from flask_assets import Bundle, Environment
|
|
from flask_babel import Babel
|
|
from flask_wtf import CSRFProtect
|
|
from livereload import Server
|
|
|
|
from starfall.config import Config
|
|
from starfall.db import db, load_schema
|
|
from starfall.types import SnapshotQueue
|
|
from starfall.web.blueprints.base import BaseBlueprint
|
|
from starfall.web.controllers.user import login_manager
|
|
|
|
|
|
@final
class WebUI:
    """Flask-based web front end.

    ``run`` builds the Flask application, wires up the extensions the UI
    uses (``db``, Babel, webassets, CSRF protection, ``login_manager``),
    discovers and instantiates ``BaseBlueprint`` subclasses below the
    ``blueprints`` package, registers the root blueprint and finally hands
    control to a livereload ``Server``.
    """

    # Locales matched against the request's Accept-Language header.
    SUPPORTED_LOCALES: tuple[str, ...] = ("en", "de")

    def __init__(self):
        # Everything except the root blueprint is created in run().
        self.app: Flask | None = None
        self.assets: Environment | None = None
        self.babel: Babel | None = None
        self.blueprint: Blueprint = Blueprint("starfall", __name__)
        self.config: Config | None = None
        self.queue: SnapshotQueue | None = None
        self.server: Server | None = None
        self.csrf: CSRFProtect | None = None
        # BaseBlueprint subclasses already instantiated by import_blueprints(),
        # so a class reachable from several scanned modules is set up once.
        self._registered_blueprints: set[type] = set()

    def select_locale(self):
        """Babel locale selector for the current request.

        Returns the best match between the request's Accept-Language header
        and ``SUPPORTED_LOCALES`` (presumably ``None`` when nothing matches,
        which is werkzeug's ``best_match`` default).
        """
        # TODO: prefer a per-user locale once users carry one, e.g.
        #   user = getattr(g, "user", None)
        #   if user is not None: return user.locale
        return request.accept_languages.best_match(list(self.SUPPORTED_LOCALES))

    def select_timezone(self):
        """Babel timezone selector for the current request.

        Always returns ``None`` for now, which presumably lets Flask-Babel
        fall back to its configured default timezone.
        """
        # TODO: return a per-user timezone once users carry one, e.g.
        #   user = getattr(g, "user", None)
        #   if user is not None: return user.timezone
        return None

    def run(self, config: Config, queue: SnapshotQueue):
        """Build the Flask app with all extensions and serve it via livereload.

        Args:
            config: Settings source; ``web.secret_key``, ``web.database_url``,
                ``web.host`` and ``web.port`` are read from it.
            queue: Snapshot queue, stored on the instance and passed in the
                root blueprint's registration options.

        Raises:
            RuntimeError: If ``web.secret_key`` is not configured.
        """
        log = logging.getLogger("web")

        self.config = config
        self.queue = queue

        log.debug("Hello from %r", type(self))

        # Paths are resolved against the current working directory, so the
        # process is expected to be started from the project root.
        self.app = Flask(
            import_name=__name__,
            root_path=os.path.realpath("."),
            static_folder=os.path.realpath("./web/static"),
            template_folder="web/templates",
        )

        secret_key = config.get("web.secret_key")
        if secret_key is None:
            # str(None) would silently produce the guessable key "None",
            # which is what signs sessions and CSRF tokens.
            raise RuntimeError("web.secret_key must be configured")

        self.app.config.update(
            SECRET_KEY=str(secret_key),
            SQLALCHEMY_DATABASE_URI=str(config.get("web.database_url")),
            TEMPLATES_AUTO_RELOAD=True,
        )
        # Re-read templates on change so edits show up without a restart.
        self.app.jinja_env.auto_reload = True
        log.debug("flask initialized")

        db.init_app(self.app)
        load_schema()
        # create_all() needs an application context to find the engine.
        with self.app.app_context():
            db.create_all()
        log.debug("db initialized")

        self.babel = Babel(
            self.app,
            locale_selector=self.select_locale,
            timezone_selector=self.select_timezone,
        )
        log.debug("babel initialized")

        self.assets = Environment(self.app)
        self.assets.load_path = ["web/static/scss", "web/static/css"]
        scss = Bundle(
            "bootstrap/bootstrap.scss",
            "bootstrap/bootstrap-utilities.scss",
            "bootstrap/bootstrap-reboot.scss",
            "bootstrap/bootstrap-grid.scss",
            "bootstrap/bootstrap-icons.scss",
            "starfall/main.scss",
            filters="libsass",
            output="css/main.css",
            # Rebuild when any SCSS partial changes, not only the entry files.
            depends=["**/*.scss"],
        )
        _ = self.assets.register("scss", scss)
        log.debug("assets initialized")

        self.csrf = CSRFProtect(self.app)
        log.debug("csrf connected")

        login_manager.init_app(self.app)

        # Blueprint classes receive self.blueprint (presumably to attach their
        # routes), so they must all run before it is registered on the app.
        self.import_blueprints("secure", "secure")
        self.import_blueprints("public", "public")
        self.import_blueprints()
        # NOTE(review): Flask's register_blueprint takes options as **kwargs,
        # so readers find this under state.options["options"]["queue"].
        self.app.register_blueprint(
            self.blueprint,
            options={"queue": self.queue},
        )
        log.debug("blueprints initialized")

        self.server = Server(self.app.wsgi_app)
        template_dir = str(self.app.template_folder)
        static_dir = str(self.app.static_folder)
        self.server.watch(os.path.join(template_dir, "**/*.jinja"))
        self.server.watch(os.path.join(static_dir, "**/*.js"))
        self.server.watch(os.path.join(static_dir, "**/*.scss"))
        log.debug("livereload initialized")

        self.server.serve(
            host=config.get("web.host"),
            port=config.get("web.port"),
        )

    def import_blueprints(self, path_suffix: str = "", module_suffix: str = ""):
        """Import the modules of a blueprints package and set up their classes.

        Every module found by ``pkgutil.iter_modules`` in the package is
        imported, except modules whose final name component is ``base``. Each
        ``BaseBlueprint`` subclass found in a module is instantiated with
        ``(self.blueprint, self.assets, self.app)`` at most once per
        ``WebUI`` instance, and is also bound as a global of this module.

        Args:
            path_suffix: Subdirectory of ``blueprints`` to scan; empty scans
                the ``blueprints`` directory itself.
            module_suffix: Dotted sub-package matching ``path_suffix``, used
                to build importable module names.
        """
        log = logging.getLogger("web")

        path = os.path.realpath(os.path.join(os.path.dirname(__file__), "blueprints"))
        if path_suffix:
            path = os.path.join(path, path_suffix)

        prefix = f"{__name__}.blueprints."
        if module_suffix:
            prefix += module_suffix + "."

        for _, module_name, _ in iter_modules([path], prefix):
            # Exact match on the last component: a plain endswith("base")
            # would also skip modules such as "...blueprints.database".
            if module_name.rpartition(".")[2] == "base":
                continue

            log.debug("Parsing module: %s", module_name)
            module = importlib.import_module(module_name)

            for attribute_name in dir(module):
                attribute = getattr(module, attribute_name)
                if not (isclass(attribute) and issubclass(attribute, BaseBlueprint)):
                    continue
                # dir() includes imported names too: skip the base class
                # itself and classes already set up from another module.
                if attribute is BaseBlueprint or attribute in self._registered_blueprints:
                    continue

                self._registered_blueprints.add(attribute)
                # Preserved behavior: expose the class as a global of this module.
                globals()[attribute_name] = attribute
                attribute(self.blueprint, self.assets, self.app)