Browse Source

Use gorm db

master
parent
commit
f6346f5c7b
Signed by: chris GPG Key ID: 3025DCBD46F81C0F
  1. 27
      cmd/server/main.go
  2. 9
      go.mod
  3. 29
      go.sum
  4. 7
      server/models/password_token.go
  5. 10
      server/models/session.go
  6. 10
      server/models/user.go
  7. 114
      server/server.go
  8. 60
      server/token_db.go
  9. 62
      server/user_db.go

27
cmd/server/main.go

@ -9,12 +9,22 @@ import (
"git.chrishayward.xyz/x/users/server" "git.chrishayward.xyz/x/users/server"
"github.com/google/uuid" "github.com/google/uuid"
"google.golang.org/grpc" "google.golang.org/grpc"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
) )
var ( var (
port = flag.Uint("port", 8080, "--port=8080")
secretDefault = uuid.NewString() secretDefault = uuid.NewString()
secret = flag.String("secret", secretDefault, "--secret=SECRET") secret = flag.String("secret", secretDefault, "--secret=SECRET")
port = flag.Uint("port", 8080, "--port=8080")
dbType = flag.String("dbType", "sqlite", "--dbType=sqlite,postgress")
dbFile = flag.String("dbFile", "users.db", "--dbFile=users.db")
dbHost = flag.String("dbHost", "localhost", "--dbHost=localhost")
dbPort = flag.Uint("dbPort", 5432, "--dbPort=5432")
dbName = flag.String("dbName", "postgres", "--dbName=postgres")
dbUser = flag.String("dbUser", "postgres", "--dbUser=postgres")
dbPass = flag.String("dbPass", "postgres", "--dbPass=postgres")
) )
func main() { func main() {
@ -26,6 +36,19 @@ func main() {
fmt.Printf("SECRET=%s\n", secretDefault) fmt.Printf("SECRET=%s\n", secretDefault)
} }
// Initialize the database.
var db *gorm.DB
var err error
switch *dbType {
case "postgres":
db, _ = gorm.Open(postgres.Open(fmt.Sprintf(
"host=%s user=%s password=%s dbname=%s port=%d sslmode=disable",
*dbHost, *dbUser, *dbPass, *dbName, *dbPort)))
case "sqlite":
default:
db, _ = gorm.Open(sqlite.Open(*dbFile), &gorm.Config{})
}
// Create the network listener. // Create the network listener.
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port)) lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))
if err != nil { if err != nil {
@ -34,7 +57,7 @@ func main() {
// Start listening for requests. // Start listening for requests.
srv := grpc.NewServer() srv := grpc.NewServer()
proto.RegisterUsersServer(srv, server.NewUsersServer(secret))
proto.RegisterUsersServer(srv, server.NewUsersServer(secret, db))
fmt.Printf("Listening on :%d", *port) fmt.Printf("Listening on :%d", *port)
srv.Serve(lis) srv.Serve(lis)
} }

9
go.mod

@ -7,10 +7,19 @@ require (
golang.org/x/crypto v0.10.0 golang.org/x/crypto v0.10.0
google.golang.org/grpc v1.56.1 google.golang.org/grpc v1.56.1
google.golang.org/protobuf v1.31.0 google.golang.org/protobuf v1.31.0
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.5.2
gorm.io/gorm v1.25.2
) )
require ( require (
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-sqlite3 v1.14.17 // indirect
golang.org/x/net v0.10.0 // indirect golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.9.0 // indirect golang.org/x/sys v0.9.0 // indirect
golang.org/x/text v0.10.0 // indirect golang.org/x/text v0.10.0 // indirect

29
go.sum

@ -1,3 +1,5 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
@ -5,6 +7,24 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
@ -22,3 +42,12 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
gorm.io/driver/sqlite v1.5.2 h1:TpQ+/dqCY4uCigCFyrfnrJnrW9zjpelWVoEVNy5qJkc=
gorm.io/driver/sqlite v1.5.2/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=

7
server/models/password_token.go

@ -0,0 +1,7 @@
package models
type PasswordToken struct {
Token string
Expires int64
UserID uint
}

10
server/models/session.go

@ -0,0 +1,10 @@
package models
import "gorm.io/gorm"
type Session struct {
gorm.Model
Token string
Expires int64
UserID uint
}

10
server/models/user.go

@ -0,0 +1,10 @@
package models
import "gorm.io/gorm"
type User struct {
gorm.Model
Email string
Password string
Sessions []Session
}

114
server/server.go

@ -8,24 +8,23 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"git.chrishayward.xyz/x/users/proto" "git.chrishayward.xyz/x/users/proto"
"git.chrishayward.xyz/x/users/server/models"
) )
type usersServer struct { type usersServer struct {
proto.UsersServer proto.UsersServer
secret *string
users UserDB
tokens TokenDB
resetTokens TokenDB
secret *string
db *gorm.DB
} }
func NewUsersServer(secret *string) proto.UsersServer {
func NewUsersServer(secret *string, db *gorm.DB) proto.UsersServer {
db.AutoMigrate(&models.User{}, &models.Session{}, &models.PasswordToken{})
return &usersServer{ return &usersServer{
secret: secret,
users: newInMemoryUserDB(),
tokens: newInMemoryTokenDB(),
resetTokens: newInMemoryTokenDB(),
secret: secret,
db: db,
} }
} }
@ -40,7 +39,9 @@ func (m *usersServer) Register(ctx context.Context, in *proto.RegisterRequest) (
} }
// Check for an existing user. // Check for an existing user.
if u, _ := m.users.FindByEmail(in.Form.Email); u != nil {
var user models.User
tx := m.db.First(&user, "email = ?", in.Form.Email)
if tx.RowsAffected != 0 {
return nil, errors.New("User already exists.") return nil, errors.New("User already exists.")
} }
@ -52,10 +53,10 @@ func (m *usersServer) Register(ctx context.Context, in *proto.RegisterRequest) (
} }
// Create the new user. // Create the new user.
if err := m.users.Save(&User{
Email: in.Form.Email,
Password: string(bytes),
}); err != nil {
user.Email = in.Form.Email
user.Password = string(bytes)
tx = m.db.Create(&user)
if tx.RowsAffected == 0 {
log.Fatalf("Failed to save user: %v", err) log.Fatalf("Failed to save user: %v", err)
return nil, errors.New("Failed to save user.") return nil, errors.New("Failed to save user.")
} }
@ -71,9 +72,10 @@ func (m *usersServer) Login(ctx context.Context, in *proto.LoginRequest) (*proto
} }
// Find the user. // Find the user.
user, err := m.users.FindByEmail(in.Form.Email)
if err != nil {
return nil, err
var user models.User
tx := m.db.First(&user, "email = ?", in.Form.Email)
if tx.RowsAffected == 0 {
return nil, errors.New("User not found.")
} }
// Compare the passwords. // Compare the passwords.
@ -81,23 +83,24 @@ func (m *usersServer) Login(ctx context.Context, in *proto.LoginRequest) (*proto
return nil, errors.New("Passwords do not match.") return nil, errors.New("Passwords do not match.")
} }
// Create a token.
expires := time.Now().AddDate(0, 0, 1)
token := &Token{
UserID: user.ID,
// Create a session.
session := &models.Session{
Token: uuid.NewString(), Token: uuid.NewString(),
Expires: &expires,
Expires: time.Now().AddDate(0, 0, 1).UnixNano(),
UserID: user.ID,
} }
// Save the token. // Save the token.
m.tokens.Save(token)
tx = m.db.Create(&session)
if tx.RowsAffected == 0 {
return nil, errors.New("Failed to create session.")
}
// Return the response. // Return the response.
expiresNano := expires.UnixNano()
return &proto.LoginResponse{ return &proto.LoginResponse{
Token: &proto.UserToken{ Token: &proto.UserToken{
Token: token.Token,
Expires: &expiresNano,
Token: session.Token,
Expires: &session.Expires,
}, },
}, nil }, nil
} }
@ -108,64 +111,68 @@ func (m *usersServer) Authorize(ctx context.Context, in *proto.AuthorizeRequest)
return nil, errors.New("Secrets do not match.") return nil, errors.New("Secrets do not match.")
} }
// Find the token.
token, err := m.tokens.FindByToken(in.Token.Token)
if err != nil {
return nil, err
// Find the session.
var session models.Session
tx := m.db.First(&session, "token = ?", in.Token.Token)
if tx.RowsAffected == 0 {
return nil, errors.New("Session not found.")
} }
// Make sure the token hasn't expired.
if token.Expires.After(time.Now()) {
// Make sure the session hasn't expired.
if time.Now().UnixNano() > session.Expires {
return nil, errors.New("Token is expired.") return nil, errors.New("Token is expired.")
} }
// Return the user ID. // Return the user ID.
return &proto.AuthorizeResponse{ return &proto.AuthorizeResponse{
User: &proto.UserInfo{ User: &proto.UserInfo{
Id: int64(token.UserID),
Id: int64(session.UserID),
}, },
}, nil }, nil
} }
func (m *usersServer) ResetPassword(ctx context.Context, in *proto.ResetPasswordRequest) (*proto.ResetPasswordResponse, error) { func (m *usersServer) ResetPassword(ctx context.Context, in *proto.ResetPasswordRequest) (*proto.ResetPasswordResponse, error) {
// Find the user. // Find the user.
user, err := m.users.FindByEmail(in.Form.Email)
if err != nil {
return nil, err
var user models.User
tx := m.db.First(&user, "email = ?", in.Form.Email)
if tx.RowsAffected == 0 {
return nil, errors.New("User not found.")
} }
// Generate a reset token. // Generate a reset token.
expires := time.Now().AddDate(0, 0, 1)
token := &Token{
resetToken := &models.PasswordToken{
UserID: user.ID, UserID: user.ID,
Token: uuid.NewString(), Token: uuid.NewString(),
Expires: &expires,
Expires: time.Now().UnixNano(),
} }
// Save the token. // Save the token.
if err := m.resetTokens.Save(token); err != nil {
return nil, err
tx = m.db.Create(resetToken)
if tx.RowsAffected == 0 {
return nil, errors.New("Failed to create token.")
} }
// Return the response. // Return the response.
return &proto.ResetPasswordResponse{ return &proto.ResetPasswordResponse{
Token: &proto.UserToken{ Token: &proto.UserToken{
Token: token.Token,
Token: resetToken.Token,
}, },
}, nil }, nil
} }
func (m *usersServer) ChangePassword(ctx context.Context, in *proto.ChangePasswordRequest) (*proto.ChangePasswordResponse, error) { func (m *usersServer) ChangePassword(ctx context.Context, in *proto.ChangePasswordRequest) (*proto.ChangePasswordResponse, error) {
// Find the reset token. // Find the reset token.
resetToken, err := m.resetTokens.FindByToken(in.Token.Token)
if err != nil {
return nil, err
var resetToken models.PasswordToken
tx := m.db.First(&resetToken, "token = ?", in.Token.Token)
if tx.RowsAffected == 0 {
return nil, errors.New("Token not found.")
} }
// Find the user. // Find the user.
user, err := m.users.FindByID(resetToken.UserID)
if err != nil {
return nil, err
var user models.User
tx = m.db.First(&user, "id = ?", resetToken.UserID)
if tx.RowsAffected == 0 {
return nil, errors.New("User not found.")
} }
// Update the password. // Update the password.
@ -176,15 +183,14 @@ func (m *usersServer) ChangePassword(ctx context.Context, in *proto.ChangePasswo
} }
user.Password = string(bytes) user.Password = string(bytes)
if err := m.users.Save(user); err != nil {
return nil, err
if tx = m.db.Save(user); tx.RowsAffected == 0 {
return nil, errors.New("Failed to update password.")
} }
// Expire current token. // Expire current token.
if token, err := m.tokens.FindByUserID(user.ID); token != nil && err == nil {
expires := time.Now()
token.Expires = &expires
_ = m.tokens.Save(token)
resetToken.Expires = time.Now().UnixNano()
if tx = m.db.Save(&resetToken); tx.RowsAffected == 0 {
return nil, errors.New("Failed to update password.")
} }
// Return the response. // Return the response.

60
server/token_db.go

@ -1,60 +0,0 @@
package server
import (
"errors"
"time"
)
type Token struct {
UserID int64
Token string
Expires *time.Time
}
type TokenDB interface {
FindByUserID(id int64) (*Token, error)
FindByToken(token string) (*Token, error)
Save(token *Token) error
}
type inMemoryTokenDB struct {
tokens []*Token
}
func newInMemoryTokenDB() *inMemoryTokenDB {
return &inMemoryTokenDB{
tokens: make([]*Token, 0),
}
}
func (m *inMemoryTokenDB) FindByUserID(id int64) (*Token, error) {
for _, t := range m.tokens {
if t.UserID == id {
return t, nil
}
}
return nil, errors.New("Token not found.")
}
func (m *inMemoryTokenDB) FindByToken(token string) (*Token, error) {
for _, t := range m.tokens {
if t.Token == token {
return t, nil
}
}
return nil, errors.New("Token not found.")
}
func (m *inMemoryTokenDB) Save(token *Token) error {
for i, t := range m.tokens {
if t.UserID == token.UserID || t.Token == token.Token {
m.tokens[i] = token
return nil
}
}
m.tokens = append(m.tokens, token)
return nil
}

62
server/user_db.go

@ -1,62 +0,0 @@
package server
import "errors"
type User struct {
ID int64
Email string
Password string
}
type UserDB interface {
FindByID(id int64) (*User, error)
FindByEmail(email string) (*User, error)
Save(*User) error
}
type inMemoryUserDB struct {
UserDB
nextID int64
users []*User
}
func newInMemoryUserDB() *inMemoryUserDB {
return &inMemoryUserDB{
nextID: 1,
users: make([]*User, 0),
}
}
func (m *inMemoryUserDB) FindByID(id int64) (*User, error) {
for _, u := range m.users {
if u.ID == id {
return u, nil
}
}
return nil, errors.New("User not found.")
}
func (m *inMemoryUserDB) FindByEmail(email string) (*User, error) {
for _, u := range m.users {
if u.Email == email {
return u, nil
}
}
return nil, errors.New("User not found.")
}
func (m *inMemoryUserDB) Save(user *User) error {
for i, u := range m.users {
if u.ID == user.ID || u.Email == user.Email {
m.users[i] = user
return nil
}
}
user.ID = m.nextID
m.users = append(m.users, user)
m.nextID++
return nil
}
Loading…
Cancel
Save