141 lines
3.7 KiB
Python
141 lines
3.7 KiB
Python
|
from flask import Blueprint, jsonify, request
|
||
|
from flask_login import login_required, login_user, logout_user
|
||
|
from flask_wtf import FlaskForm
|
||
|
from flask_wtf.csrf import generate_csrf
|
||
|
from models import User
|
||
|
from server import db, login_manager
|
||
|
from sqlalchemy.orm import Session
|
||
|
from werkzeug.datastructures import ImmutableMultiDict
|
||
|
from wtforms import (BooleanField, PasswordField, StringField, SubmitField,
|
||
|
validators)
|
||
|
|
||
|
# Blueprint that groups the authentication endpoints defined in this module.
auth = Blueprint("auth", __name__)

# NOTE(review): ``db_session`` is not referenced anywhere in this module, and
# ``db`` is handed to the Session as its ``bind``. Confirm ``db`` is an
# Engine/Connection here; if it is a Flask-SQLAlchemy extension object this
# session would not be able to execute statements.
db_session = Session(bind=db, future=True)
|
||
|
|
||
|
|
||
|
@login_manager.user_loader
def load_user(user_id):
    """Flask-Login user loader: turn the id stored in the session into a ``User``.

    Args:
        user_id: identifier Flask-Login kept for the logged-in user.

    Returns:
        The ``User`` whose primary key matches ``user_id``, or ``None`` if
        there is no such row.
    """
    # NOTE(review): ``Query.get`` is a legacy API under SQLAlchemy 2.0
    # (``Session.get(User, user_id)`` is the replacement); kept unchanged here.
    user = User.query.get(user_id)
    return user
|
||
|
|
||
|
|
||
|
@auth.route("/csrf", methods=["GET"])
def csrf():
    """Hand the client a CSRF token to send back with later form submissions.

    Returns:
        A JSON body ``{"csrfToken": <token>}`` with status 200.
    """
    token = generate_csrf()
    return jsonify(csrfToken=token), 200
|
||
|
|
||
|
|
||
|
def str_form_errors(form_errors):
    """Flatten a WTForms-style ``errors`` mapping into one readable string.

    Args:
        form_errors: mapping of field name -> iterable of error messages. The
            ``None`` key (form-level errors) is rendered with the label "Error".

    Returns:
        Every message formatted as ``"<field>: <message>"``, joined by ", ",
        in mapping order. An empty mapping yields "".
    """
    return ", ".join(
        f"{'Error' if field is None else field}: {message}"
        for field, messages in form_errors.items()
        for message in messages
    )
|
||
|
|
||
|
|
||
|
class LoginForm(FlaskForm):
    """Login form that also checks the submitted credentials against ``User``.

    After :meth:`validate` returns True, ``_user`` holds the authenticated
    ``User``. In every other case it is ``None``.
    """

    username = StringField(
        label="Username",
        validators=[
            validators.InputRequired(),
        ],
        id="username",
        # NOTE(review): demo credentials as field defaults -- confirm they are
        # intended; they would pre-fill the field if the form were rendered.
        default="user",
        name="username",
    )
    password = PasswordField(
        label="Password",
        validators=[
            validators.InputRequired(),
        ],
        id="password",
        default="password",
        name="password",
    )
    remember_me = BooleanField(
        "Remember me",
    )

    # One generic message for both "unknown user" and "bad password", so the
    # error text does not reveal which usernames exist.
    _fail_message = "wrong credentials"

    # Defined up front so ``_user`` can be read safely even when validate()
    # was never called or stopped before the credential check.
    _user = None

    def validate(self, extra_validators=None):
        """Run the field validators, then verify the username/password pair.

        Args:
            extra_validators: forwarded unchanged to ``FlaskForm.validate``.

        Returns:
            True only if the field validators pass, a ``User`` with the given
            username exists, and ``User.verify`` accepts the password. On a
            credential failure ``_fail_message`` is appended to
            ``form_errors`` and False is returned.
        """
        # Never leave a previously looked-up (unauthenticated) user behind.
        self._user = None
        if not super().validate(extra_validators=extra_validators):
            return False

        user = User.query.filter(User.username == self.username.data).first()
        # NOTE(review): an unknown username returns without calling verify(),
        # so response timing may still hint whether the username exists.
        # Consider verifying against a dummy hash in that case.
        if user is None or not user.verify(self.password.data):
            self.form_errors.append(self._fail_message)
            return False

        self._user = user
        return True
|
||
|
|
||
|
|
||
|
@auth.route("/login", methods=["GET", "POST"])
def login():
    """Authenticate a user from a JSON request body and start a session.

    The body is read as a JSON object whose keys are the ``LoginForm`` field
    names (``username``, ``password``, ``remember_me``).

    Returns:
        ``({"ok": True}, 200)`` after ``login_user`` succeeds, otherwise
        ``({"ok": False, "errors": "<flattened form errors>"}, 400)``.
    """
    # ``silent=True``: a GET, a non-JSON content type or malformed JSON gives
    # None instead of aborting with a non-JSON 400/415 error page. A JSON
    # value that is not an object cannot become form data (it would crash
    # ImmutableMultiDict), so treat it as empty and let the validators
    # report the missing fields.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    form = LoginForm(ImmutableMultiDict(payload))
    if form.validate_on_submit():
        login_user(form._user, remember=form.remember_me.data)
        return jsonify({"ok": True}), 200
    return jsonify({"ok": False, "errors": str_form_errors(form.errors)}), 400
|
||
|
|
||
|
|
||
|
# Endpoint name Flask-Login uses when an anonymous user reaches a
# ``@login_required`` view (such as ``auth.logout`` in this module).
# NOTE(review): ``auth.login`` answers with JSON, so a redirect there may not
# suit API clients -- confirm whether a 401 handler is wanted instead.
login_manager.login_view = "auth.login"
|
||
|
|
||
|
|
||
|
def username_does_not_exist_validator(form, field):
    """WTForms inline validator that rejects a username that is already taken.

    Args:
        form: the form being validated (unused).
        field: the username field; its ``data`` is looked up via ``User.exists``.

    Returns:
        True when the username is free.

    Raises:
        validators.ValidationError: if a user with that username already exists.
    """
    taken = User.exists(username=field.data)
    if not taken:
        return True
    raise validators.ValidationError("username already exists")
|
||
|
|
||
|
|
||
|
class RegisterForm(FlaskForm):
    """Sign-up form: a new username plus a password entered twice.

    Only field-level rules live here. The ``register`` view creates the account.
    """

    # Required, at least 3 characters, and not already taken.
    username = StringField(
        label="Username",
        validators=[
            validators.DataRequired(),
            validators.Length(min=3),
            username_does_not_exist_validator,
        ],
    )
    # Required, at least 8 characters.
    password = PasswordField(
        label="Password",
        validators=[validators.DataRequired(), validators.Length(min=8)],
    )
    # Required and must match ``password``.
    confirm = PasswordField(
        label="Repeat password",
        validators=[
            validators.DataRequired(),
            validators.EqualTo("password", message="passwords do not match"),
        ],
    )
|
||
|
|
||
|
|
||
|
@auth.route("/register", methods=["GET", "POST"])
def register():
    """Create a new account from a JSON request body.

    The body is read as a JSON object whose keys are the ``RegisterForm``
    field names (``username``, ``password``, ``confirm``). The new user is
    not logged in by this view.

    Returns:
        ``({"ok": True}, 200)`` after ``User.register`` is called, otherwise
        ``({"ok": False, "errors": "<flattened form errors>"}, 400)``.
    """
    # ``silent=True``: a GET, a non-JSON content type or malformed JSON gives
    # None instead of aborting with a non-JSON 400/415 error page. A JSON
    # value that is not an object cannot become form data (it would crash
    # ImmutableMultiDict), so treat it as empty and let the validators
    # report the missing fields.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    form = RegisterForm(ImmutableMultiDict(payload))
    if form.validate_on_submit():
        # NOTE(review): uniqueness is checked by the form before this insert;
        # two concurrent requests could both pass unless the database also
        # enforces a unique username -- confirm.
        User.register(
            username=form.username.data,
            password=form.password.data,
        )
        return jsonify({"ok": True}), 200
    return jsonify({"ok": False, "errors": str_form_errors(form.errors)}), 400
|
||
|
|
||
|
|
||
|
@auth.route("/logout", methods=["GET", "POST"])
@login_required
def logout():
    """End the current user's Flask-Login session.

    Only reachable by a logged-in user (``@login_required``).

    Returns:
        A JSON body ``{"ok": true}`` with status 200.
    """
    # NOTE(review): logout is also accepted via GET, so a cross-site link can
    # log a user out -- confirm this is acceptable.
    logout_user()
    response = jsonify(ok=True)
    return response, 200
|
||
|
|
||
|
|
||
|
# @login_manager.unauthorized_handler
|
||
|
# def unauthorized():
|
||
|
# return abort(401)
|