From 90c8bf2795a3b16229b8c715322ec2c0b01a23f5 Mon Sep 17 00:00:00 2001 From: Christopher James Hayward Date: Mon, 17 Jul 2023 16:00:17 -0400 Subject: [PATCH] Use chained calls for query --- server/server.go | 44 +++++++++++++++++--------------------------- 1 file changed, 17 insertions(+), 27 deletions(-) diff --git a/server/server.go b/server/server.go index c12456f..5125143 100644 --- a/server/server.go +++ b/server/server.go @@ -40,8 +40,7 @@ func (m *usersServer) Register(ctx context.Context, in *proto.RegisterRequest) ( // Check for an existing user. var u models.User - tx := m.db.First(&u, "email = ?", in.Form.Email) - if tx.RowsAffected != 0 { + if m.db.Where(&u, "email = ?", in.Form.Email).RowsAffected > 0 { return nil, errors.New("User already exists.") } @@ -56,8 +55,7 @@ func (m *usersServer) Register(ctx context.Context, in *proto.RegisterRequest) ( u.UUID = uuid.NewString() u.Email = in.Form.Email u.Password = string(bytes) - tx = m.db.Create(&u) - if tx.RowsAffected == 0 { + if m.db.Create(&u).RowsAffected == 0 { log.Fatalf("Failed to save user: %v", err) return nil, errors.New("Failed to save user.") } @@ -74,8 +72,7 @@ func (m *usersServer) Login(ctx context.Context, in *proto.LoginRequest) (*proto // Find the user. var u models.User - tx := m.db.First(&u, "email = ?", in.Form.Email) - if tx.RowsAffected == 0 { + if m.db.Where(&u, "email = ?", in.Form.Email).RowsAffected > 0 { return nil, errors.New("User not found.") } @@ -92,8 +89,7 @@ func (m *usersServer) Login(ctx context.Context, in *proto.LoginRequest) (*proto } // Save the token. - tx = m.db.Create(&s) - if tx.RowsAffected == 0 { + if m.db.Create(&s).RowsAffected == 0 { return nil, errors.New("Failed to create session.") } @@ -109,13 +105,13 @@ func (m *usersServer) Login(ctx context.Context, in *proto.LoginRequest) (*proto func (m *usersServer) Logout(ctx context.Context, in *proto.LogoutRequest) (*proto.LogoutResponse, error) { // Find the session. var s models.Session - if tx := m.db.First(&s, "token = ?", in.Token.Token); tx.RowsAffected == 0 { + if m.db.First(&s, "token = ?", in.Token.Token).RowsAffected == 0 { return nil, errors.New("Failed to find session.") } // Expire the token. s.Expires = time.Now().UnixNano() - if tx := m.db.Save(&s); tx.RowsAffected == 0 { + if m.db.Save(&s).RowsAffected == 0 { return nil, errors.New("Failed to close session.") } @@ -130,8 +126,7 @@ func (m *usersServer) Authorize(ctx context.Context, in *proto.AuthorizeRequest) // Find the session. var s models.Session - tx := m.db.First(&s, "token = ?", in.Token.Token) - if tx.RowsAffected == 0 { + if m.db.First(&s, "token = ?", in.Token.Token).RowsAffected == 0 { return nil, errors.New("Session not found.") } @@ -142,8 +137,7 @@ func (m *usersServer) Authorize(ctx context.Context, in *proto.AuthorizeRequest) // Find the user. var u models.User - tx = m.db.Model(&models.User{}).Preload("Roles").First(&u, "id = ?", s.UserID) - if tx.RowsAffected == 0 { + if m.db.Model(&models.User{}).Preload("Roles").First(&u, "id = ?", s.UserID).RowsAffected == 0 { return nil, errors.New("Failed to load roles.") } @@ -168,8 +162,7 @@ func (m *usersServer) Authorize(ctx context.Context, in *proto.AuthorizeRequest) func (m *usersServer) ResetPassword(ctx context.Context, in *proto.ResetPasswordRequest) (*proto.ResetPasswordResponse, error) { // Find the u. var u models.User - tx := m.db.First(&u, "email = ?", in.Form.Email) - if tx.RowsAffected == 0 { + if m.db.First(&u, "email = ?", in.Form.Email).RowsAffected == 0 { return nil, errors.New("User not found.") } @@ -181,8 +174,7 @@ func (m *usersServer) ResetPassword(ctx context.Context, in *proto.ResetPassword } // Save the token. - tx = m.db.Create(rt) - if tx.RowsAffected == 0 { + if m.db.Create(rt).RowsAffected == 0 { return nil, errors.New("Failed to create token.") } @@ -197,15 +189,13 @@ func (m *usersServer) ResetPassword(ctx context.Context, in *proto.ResetPassword func (m *usersServer) ChangePassword(ctx context.Context, in *proto.ChangePasswordRequest) (*proto.ChangePasswordResponse, error) { // Find the reset token. var rt models.PasswordToken - tx := m.db.First(&rt, "token = ?", in.Token.Token) - if tx.RowsAffected == 0 { + if m.db.First(&rt, "token = ?", in.Token.Token).RowsAffected == 0 { return nil, errors.New("Token not found.") } // Find the user. var u models.User - tx = m.db.First(&u, "id = ?", rt.UserID) - if tx.RowsAffected == 0 { + if m.db.First(&u, "id = ?", rt.UserID).RowsAffected == 0 { return nil, errors.New("User not found.") } @@ -217,13 +207,13 @@ func (m *usersServer) ChangePassword(ctx context.Context, in *proto.ChangePasswo } u.Password = string(bytes) - if tx = m.db.Save(u); tx.RowsAffected == 0 { + if m.db.Save(u).RowsAffected == 0 { return nil, errors.New("Failed to update password.") } // Expire current token. rt.Expires = time.Now().UnixNano() - if tx = m.db.Save(&rt); tx.RowsAffected == 0 { + if m.db.Save(&rt).RowsAffected == 0 { return nil, errors.New("Failed to update password.") } @@ -239,7 +229,7 @@ func (m *usersServer) ListRoles(ctx context.Context, in *proto.ListRolesRequest) // Get all of the available roles. var roles []models.Role - if tx := m.db.Find(&roles); tx.RowsAffected == 0 { + if m.db.Find(&roles).RowsAffected == 0 { return nil, errors.New("Failed to find roles.") } @@ -262,14 +252,14 @@ func (m *usersServer) SetRoles(ctx context.Context, in *proto.SetRolesRequest) ( // Find the user. var u models.User - if tx := m.db.First(&u, "id = ?", in.User.Id); tx.RowsAffected == 0 { + if m.db.First(&u, "id = ?", in.User.Id).RowsAffected == 0 { return nil, errors.New("User not found.") } // Add the roles. var r models.Role for _, x := range in.Roles { - if tx := m.db.First(&r, "id = ?", x.Id); tx.RowsAffected != 0 { + if m.db.First(&r, "id = ?", x.Id).RowsAffected != 0 { u.Roles = append(u.Roles, &r) } }