You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
273 lines
6.9 KiB
273 lines
6.9 KiB
package server
|
|
|
|
import (
	"context"
	"crypto/subtle"
	"errors"
	"log"
	"time"

	"github.com/google/uuid"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"

	"git.chrishayward.xyz/x/users/proto"
	"git.chrishayward.xyz/x/users/server/models"
)
|
|
|
|
type usersServer struct {
|
|
proto.UsersServer
|
|
secret string
|
|
db *gorm.DB
|
|
}
|
|
|
|
// NewUsersServer builds a UsersServer that stores its state in db and
// guards the administrative endpoints with secret.
//
// It auto-migrates the User, Role, Session and PasswordToken models before
// returning. A migration failure is logged rather than silently discarded.
// The server is still returned so the constructor's signature is unchanged.
func NewUsersServer(secret string, db *gorm.DB) proto.UsersServer {
	if err := db.AutoMigrate(&models.User{}, &models.Role{}, &models.Session{}, &models.PasswordToken{}); err != nil {
		log.Printf("Failed to migrate database: %v", err)
	}

	return &usersServer{
		secret: secret,
		db:     db,
	}
}
|
|
|
|
// Register creates a new user account from the submitted form.
//
// The form must contain matching Password and PasswordAgain values, and no
// user with the same email may already exist. The password is stored as a
// bcrypt hash, and the user is assigned a random UUID. An empty response is
// returned on success.
func (m *usersServer) Register(ctx context.Context, in *proto.RegisterRequest) (*proto.RegisterResponse, error) {
	// Make sure both passwords are included and match. A missing form is
	// treated the same as missing passwords instead of panicking.
	if in.Form == nil || in.Form.Password == nil || in.Form.PasswordAgain == nil {
		return nil, errors.New("Must include password(s).")
	}
	if *in.Form.Password != *in.Form.PasswordAgain {
		return nil, errors.New("Passwords do not match.")
	}

	db := m.db.WithContext(ctx)

	// Check for an existing user. Where() on its own only builds a query;
	// Find must actually run it before RowsAffected means anything.
	// NOTE(review): check-then-insert is racy under concurrent registrations;
	// a unique index on email is presumably needed — confirm in models.User.
	var existing models.User
	lookup := db.Limit(1).Find(&existing, "email = ?", in.Form.Email)
	if lookup.Error != nil {
		log.Printf("Failed to look up user: %v", lookup.Error)
		return nil, errors.New("Failed to save user.")
	}
	if lookup.RowsAffected > 0 {
		return nil, errors.New("User already exists.")
	}

	// Hash the password. bcrypt.MaxCost means 2^31 rounds, which is far too
	// slow for a request path; DefaultCost is the library's recommendation.
	hash, err := bcrypt.GenerateFromPassword([]byte(*in.Form.Password), bcrypt.DefaultCost)
	if err != nil {
		// Fail only this request. log.Fatalf would terminate the whole
		// server (e.g. recent x/crypto versions reject passwords > 72 bytes).
		log.Printf("Failed to encode password: %v", err)
		return nil, errors.New("Failed to encode password.")
	}

	// Create the new user.
	u := models.User{
		UUID:     uuid.NewString(),
		Email:    in.Form.Email,
		Password: string(hash),
	}
	if res := db.Create(&u); res.Error != nil || res.RowsAffected == 0 {
		log.Printf("Failed to save user: %v", res.Error)
		return nil, errors.New("Failed to save user.")
	}

	return &proto.RegisterResponse{}, nil
}
|
|
|
|
// Login verifies an email/password pair and opens a new session.
//
// On success, a Session with a random token is stored. Its Expires field is
// set one day ahead, as UnixNano. The token and its expiry are returned to
// the caller.
func (m *usersServer) Login(ctx context.Context, in *proto.LoginRequest) (*proto.LoginResponse, error) {
	// Make sure the password is included (a nil form counts as missing).
	if in.Form == nil || in.Form.Password == nil {
		return nil, errors.New("Password must be included.")
	}

	db := m.db.WithContext(ctx)

	// Find the user. The original used Where() without executing it and an
	// inverted condition, so the user was never loaded and every login failed.
	// NOTE(review): distinct "not found" / "do not match" errors allow
	// account enumeration; consider a single generic message.
	var u models.User
	if db.First(&u, "email = ?", in.Form.Email).RowsAffected == 0 {
		return nil, errors.New("User not found.")
	}

	// Compare the stored hash against the submitted password.
	if err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(*in.Form.Password)); err != nil {
		return nil, errors.New("Passwords do not match.")
	}

	// Create and save the session. s is already a pointer, so it is passed
	// to Create directly rather than as a pointer-to-pointer.
	s := &models.Session{
		Token:   uuid.NewString(),
		Expires: time.Now().AddDate(0, 0, 1).UnixNano(),
		UserID:  u.ID,
	}
	if res := db.Create(s); res.Error != nil || res.RowsAffected == 0 {
		return nil, errors.New("Failed to create session.")
	}

	return &proto.LoginResponse{
		Token: &proto.UserToken{
			Token:   s.Token,
			Expires: &s.Expires,
		},
	}, nil
}
|
|
|
|
// Logout closes the session identified by the request token.
//
// The session row is kept, but its Expires field is set to the current time
// (UnixNano), so later expiry checks treat the token as expired.
func (m *usersServer) Logout(ctx context.Context, in *proto.LogoutRequest) (*proto.LogoutResponse, error) {
	// A missing token cannot match any session; reject it instead of
	// dereferencing nil.
	if in.Token == nil {
		return nil, errors.New("Failed to find session.")
	}

	db := m.db.WithContext(ctx)

	// Find the session.
	var s models.Session
	if db.First(&s, "token = ?", in.Token.Token).RowsAffected == 0 {
		return nil, errors.New("Failed to find session.")
	}

	// Expire the token.
	s.Expires = time.Now().UnixNano()
	if res := db.Save(&s); res.Error != nil || res.RowsAffected == 0 {
		return nil, errors.New("Failed to close session.")
	}

	return &proto.LogoutResponse{}, nil
}
|
|
|
|
// Authorize resolves a session token to its user and that user's roles.
//
// The caller must present the server's shared secret. The session must exist,
// and the current time must not be past its Expires timestamp (UnixNano).
// The response carries the user's ID and UUID plus every preloaded role.
func (m *usersServer) Authorize(ctx context.Context, in *proto.AuthorizeRequest) (*proto.AuthorizeResponse, error) {
	// Compare in constant time so response latency does not reveal how much
	// of the secret matched.
	if subtle.ConstantTimeCompare([]byte(in.Secret), []byte(m.secret)) != 1 {
		return nil, errors.New("Secrets do not match.")
	}

	// A missing token cannot match any session.
	if in.Token == nil {
		return nil, errors.New("Session not found.")
	}

	db := m.db.WithContext(ctx)

	// Find the session.
	var s models.Session
	if db.First(&s, "token = ?", in.Token.Token).RowsAffected == 0 {
		return nil, errors.New("Session not found.")
	}

	// Make sure the session hasn't expired.
	if time.Now().UnixNano() > s.Expires {
		return nil, errors.New("Token is expired.")
	}

	// Find the user together with their roles.
	var u models.User
	if db.Preload("Roles").First(&u, "id = ?", s.UserID).RowsAffected == 0 {
		return nil, errors.New("Failed to load roles.")
	}

	res := &proto.AuthorizeResponse{
		User: &proto.UserInfo{
			Id:   int64(u.ID),
			Uuid: u.UUID,
		},
		Roles: make([]*proto.UserRole, 0, len(u.Roles)),
	}
	for _, r := range u.Roles {
		res.Roles = append(res.Roles, &proto.UserRole{
			Id:   int64(r.ID),
			Name: r.Name,
		})
	}

	return res, nil
}
|
|
|
|
// ResetPassword issues a single-use password-reset token for the user with
// the submitted email.
//
// The token is valid for one hour. The original stored it already expired,
// which only went unnoticed because ChangePassword never checked expiry.
// NOTE(review): the token is returned directly to whoever knows the email,
// so anyone can reset any account; it should presumably be delivered
// out-of-band (e.g. by email) instead — confirm the intended flow.
func (m *usersServer) ResetPassword(ctx context.Context, in *proto.ResetPasswordRequest) (*proto.ResetPasswordResponse, error) {
	// How long a reset token may be redeemed by ChangePassword.
	const tokenLifetime = time.Hour

	if in.Form == nil {
		return nil, errors.New("Email must be included.")
	}

	db := m.db.WithContext(ctx)

	// Find the user.
	var u models.User
	if db.First(&u, "email = ?", in.Form.Email).RowsAffected == 0 {
		return nil, errors.New("User not found.")
	}

	// Generate and save the reset token.
	rt := &models.PasswordToken{
		UserID:  u.ID,
		Token:   uuid.NewString(),
		Expires: time.Now().Add(tokenLifetime).UnixNano(),
	}
	if res := db.Create(rt); res.Error != nil || res.RowsAffected == 0 {
		return nil, errors.New("Failed to create token.")
	}

	return &proto.ResetPasswordResponse{
		Token: &proto.UserToken{
			Token: rt.Token,
		},
	}, nil
}

// ChangePassword sets a new password for the user who owns the given reset
// token, then expires that token so it cannot be used again.
//
// The token must exist and must not be past its Expires timestamp (UnixNano).
// The password update and the token expiry are committed in one transaction,
// so a failure cannot leave a changed password with a still-usable token.
func (m *usersServer) ChangePassword(ctx context.Context, in *proto.ChangePasswordRequest) (*proto.ChangePasswordResponse, error) {
	// Reject incomplete requests instead of dereferencing nil.
	if in.Token == nil {
		return nil, errors.New("Token not found.")
	}
	if in.Form == nil || in.Form.Password == nil {
		return nil, errors.New("Password must be included.")
	}

	db := m.db.WithContext(ctx)

	// Find the reset token and make sure it is still valid; without this
	// check, used or stale tokens could change the password indefinitely.
	var rt models.PasswordToken
	if db.First(&rt, "token = ?", in.Token.Token).RowsAffected == 0 {
		return nil, errors.New("Token not found.")
	}
	if time.Now().UnixNano() > rt.Expires {
		return nil, errors.New("Token is expired.")
	}

	// Find the user.
	var u models.User
	if db.First(&u, "id = ?", rt.UserID).RowsAffected == 0 {
		return nil, errors.New("User not found.")
	}

	// Hash the new password (DefaultCost; MaxCost means 2^31 rounds). Errors
	// fail only this request instead of killing the server via log.Fatalf.
	hash, err := bcrypt.GenerateFromPassword([]byte(*in.Form.Password), bcrypt.DefaultCost)
	if err != nil {
		log.Printf("Failed to encode password: %v", err)
		return nil, errors.New("Failed to encode password.")
	}

	// Update the password and expire the token atomically. Save receives
	// pointers so GORM can write back hook-managed fields.
	if err := db.Transaction(func(tx *gorm.DB) error {
		u.Password = string(hash)
		res := tx.Save(&u)
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected == 0 {
			return errors.New("user row not updated")
		}

		rt.Expires = time.Now().UnixNano()
		res = tx.Save(&rt)
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected == 0 {
			return errors.New("token row not updated")
		}
		return nil
	}); err != nil {
		log.Printf("Failed to update password: %v", err)
		return nil, errors.New("Failed to update password.")
	}

	// Return a real (empty) message rather than nil.
	return &proto.ChangePasswordResponse{}, nil
}
|
|
|
|
// ListRoles returns every role stored in the database.
//
// The caller must present the server's shared secret. An empty role table
// yields an empty list; only an actual query failure is reported as an error.
func (m *usersServer) ListRoles(ctx context.Context, in *proto.ListRolesRequest) (*proto.ListRolesResponse, error) {
	// Compare in constant time so response latency does not reveal how much
	// of the secret matched.
	if subtle.ConstantTimeCompare([]byte(in.Secret), []byte(m.secret)) != 1 {
		return nil, errors.New("Secrets do not match.")
	}

	// Get all of the available roles.
	var roles []models.Role
	if err := m.db.WithContext(ctx).Find(&roles).Error; err != nil {
		return nil, errors.New("Failed to find roles.")
	}

	res := &proto.ListRolesResponse{
		Roles: make([]*proto.UserRole, 0, len(roles)),
	}
	for _, r := range roles {
		res.Roles = append(res.Roles, &proto.UserRole{
			Id:   int64(r.ID),
			Name: r.Name,
		})
	}

	return res, nil
}
|
|
|
|
// SetRoles attaches the requested roles to a user.
//
// The caller must present the server's shared secret. Role IDs that do not
// exist are skipped. Roles are appended to the user's (not preloaded)
// association and saved. NOTE(review): with GORM's Save, existing role
// associations are presumably left in place, so this adds rather than
// replaces; use Association("Roles").Replace if replacement is intended.
func (m *usersServer) SetRoles(ctx context.Context, in *proto.SetRolesRequest) (*proto.SetRolesResponse, error) {
	// Compare in constant time so response latency does not reveal how much
	// of the secret matched.
	if subtle.ConstantTimeCompare([]byte(in.Secret), []byte(m.secret)) != 1 {
		return nil, errors.New("Secrets do not match.")
	}

	if in.User == nil {
		return nil, errors.New("User not found.")
	}

	db := m.db.WithContext(ctx)

	// Find the user.
	var u models.User
	if db.First(&u, "id = ?", in.User.Id).RowsAffected == 0 {
		return nil, errors.New("User not found.")
	}

	// Add the roles. Each lookup gets a fresh Role: reusing one variable made
	// every appended pointer alias the same value, and GORM adds a populated
	// primary key to First's conditions, so later lookups never matched.
	for _, x := range in.Roles {
		if x == nil {
			continue
		}
		r := new(models.Role)
		if db.First(r, "id = ?", x.Id).RowsAffected != 0 {
			u.Roles = append(u.Roles, r)
		}
	}

	// Save the user (and with it the role associations).
	if tx := db.Save(&u); tx.Error != nil || tx.RowsAffected == 0 {
		return nil, errors.New("Failed to add roles.")
	}

	// Return a real (empty) message rather than nil.
	return &proto.SetRolesResponse{}, nil
}
|