diff --git a/cmd/server/main.go b/cmd/server/main.go index 1308101..661423f 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -9,12 +9,22 @@ import ( "git.chrishayward.xyz/x/users/server" "github.com/google/uuid" "google.golang.org/grpc" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" ) var ( + port = flag.Uint("port", 8080, "--port=8080") secretDefault = uuid.NewString() secret = flag.String("secret", secretDefault, "--secret=SECRET") - port = flag.Uint("port", 8080, "--port=8080") + dbType = flag.String("dbType", "sqlite", "--dbType=sqlite,postgress") + dbFile = flag.String("dbFile", "users.db", "--dbFile=users.db") + dbHost = flag.String("dbHost", "localhost", "--dbHost=localhost") + dbPort = flag.Uint("dbPort", 5432, "--dbPort=5432") + dbName = flag.String("dbName", "postgres", "--dbName=postgres") + dbUser = flag.String("dbUser", "postgres", "--dbUser=postgres") + dbPass = flag.String("dbPass", "postgres", "--dbPass=postgres") ) func main() { @@ -26,6 +36,19 @@ func main() { fmt.Printf("SECRET=%s\n", secretDefault) } + // Initialize the database. + var db *gorm.DB + var err error + switch *dbType { + case "postgres": + db, _ = gorm.Open(postgres.Open(fmt.Sprintf( + "host=%s user=%s password=%s dbname=%s port=%d sslmode=disable", + *dbHost, *dbUser, *dbPass, *dbName, *dbPort))) + case "sqlite": + default: + db, _ = gorm.Open(sqlite.Open(*dbFile), &gorm.Config{}) + } + // Create the network listener. lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port)) if err != nil { @@ -34,7 +57,7 @@ func main() { // Start listening for requests. srv := grpc.NewServer() - proto.RegisterUsersServer(srv, server.NewUsersServer(secret)) + proto.RegisterUsersServer(srv, server.NewUsersServer(secret, db)) fmt.Printf("Listening on :%d", *port) srv.Serve(lis) } diff --git a/go.mod b/go.mod index 453ed1b..2104887 100644 --- a/go.mod +++ b/go.mod @@ -7,10 +7,19 @@ require ( golang.org/x/crypto v0.10.0 google.golang.org/grpc v1.56.1 google.golang.org/protobuf v1.31.0 + gorm.io/driver/postgres v1.5.2 + gorm.io/driver/sqlite v1.5.2 + gorm.io/gorm v1.25.2 ) require ( github.com/golang/protobuf v1.5.3 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.3.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/mattn/go-sqlite3 v1.14.17 // indirect golang.org/x/net v0.10.0 // indirect golang.org/x/sys v0.9.0 // indirect golang.org/x/text v0.10.0 // indirect diff --git a/go.sum b/go.sum index 354baa0..807e326 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= @@ -5,6 +7,24 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= +github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= @@ -22,3 +42,12 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0= +gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8= +gorm.io/driver/sqlite v1.5.2 h1:TpQ+/dqCY4uCigCFyrfnrJnrW9zjpelWVoEVNy5qJkc= +gorm.io/driver/sqlite v1.5.2/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= +gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho= +gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= diff --git a/server/models/password_token.go b/server/models/password_token.go new file mode 100644 index 0000000..1dc5c78 --- /dev/null +++ b/server/models/password_token.go @@ -0,0 +1,7 @@ +package models + +type PasswordToken struct { + Token string + Expires int64 + UserID uint +} diff --git a/server/models/session.go b/server/models/session.go new file mode 100644 index 0000000..7d3576a --- /dev/null +++ b/server/models/session.go @@ -0,0 +1,10 @@ +package models + +import "gorm.io/gorm" + +type Session struct { + gorm.Model + Token string + Expires int64 + UserID uint +} diff --git a/server/models/user.go b/server/models/user.go new file mode 100644 index 0000000..96612ce --- /dev/null +++ b/server/models/user.go @@ -0,0 +1,10 @@ +package models + +import "gorm.io/gorm" + +type User struct { + gorm.Model + Email string + Password string + Sessions []Session +} diff --git a/server/server.go b/server/server.go index 6f6d623..3c9d124 100644 --- a/server/server.go +++ b/server/server.go @@ -8,24 +8,23 @@ import ( "github.com/google/uuid" "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" "git.chrishayward.xyz/x/users/proto" + "git.chrishayward.xyz/x/users/server/models" ) type usersServer struct { proto.UsersServer - secret *string - users UserDB - tokens TokenDB - resetTokens TokenDB + secret *string + db *gorm.DB } -func NewUsersServer(secret *string) proto.UsersServer { +func NewUsersServer(secret *string, db *gorm.DB) proto.UsersServer { + db.AutoMigrate(&models.User{}, &models.Session{}, &models.PasswordToken{}) return &usersServer{ - secret: secret, - users: newInMemoryUserDB(), - tokens: newInMemoryTokenDB(), - resetTokens: newInMemoryTokenDB(), + secret: secret, + db: db, } } @@ -40,7 +39,9 @@ func (m *usersServer) Register(ctx context.Context, in *proto.RegisterRequest) ( } // Check for an existing user. - if u, _ := m.users.FindByEmail(in.Form.Email); u != nil { + var user models.User + tx := m.db.First(&user, "email = ?", in.Form.Email) + if tx.RowsAffected != 0 { return nil, errors.New("User already exists.") } @@ -52,10 +53,10 @@ func (m *usersServer) Register(ctx context.Context, in *proto.RegisterRequest) ( } // Create the new user. - if err := m.users.Save(&User{ - Email: in.Form.Email, - Password: string(bytes), - }); err != nil { + user.Email = in.Form.Email + user.Password = string(bytes) + tx = m.db.Create(&user) + if tx.RowsAffected == 0 { log.Fatalf("Failed to save user: %v", err) return nil, errors.New("Failed to save user.") } @@ -71,9 +72,10 @@ func (m *usersServer) Login(ctx context.Context, in *proto.LoginRequest) (*proto } // Find the user. - user, err := m.users.FindByEmail(in.Form.Email) - if err != nil { - return nil, err + var user models.User + tx := m.db.First(&user, "email = ?", in.Form.Email) + if tx.RowsAffected == 0 { + return nil, errors.New("User not found.") } // Compare the passwords. @@ -81,23 +83,24 @@ func (m *usersServer) Login(ctx context.Context, in *proto.LoginRequest) (*proto return nil, errors.New("Passwords do not match.") } - // Create a token. - expires := time.Now().AddDate(0, 0, 1) - token := &Token{ - UserID: user.ID, + // Create a session. + session := &models.Session{ Token: uuid.NewString(), - Expires: &expires, + Expires: time.Now().AddDate(0, 0, 1).UnixNano(), + UserID: user.ID, } // Save the token. - m.tokens.Save(token) + tx = m.db.Create(&session) + if tx.RowsAffected == 0 { + return nil, errors.New("Failed to create session.") + } // Return the response. - expiresNano := expires.UnixNano() return &proto.LoginResponse{ Token: &proto.UserToken{ - Token: token.Token, - Expires: &expiresNano, + Token: session.Token, + Expires: &session.Expires, }, }, nil } @@ -108,64 +111,68 @@ func (m *usersServer) Authorize(ctx context.Context, in *proto.AuthorizeRequest) return nil, errors.New("Secrets do not match.") } - // Find the token. - token, err := m.tokens.FindByToken(in.Token.Token) - if err != nil { - return nil, err + // Find the session. + var session models.Session + tx := m.db.First(&session, "token = ?", in.Token.Token) + if tx.RowsAffected == 0 { + return nil, errors.New("Session not found.") } - // Make sure the token hasn't expired. - if token.Expires.After(time.Now()) { + // Make sure the session hasn't expired. + if time.Now().UnixNano() > session.Expires { return nil, errors.New("Token is expired.") } // Return the user ID. return &proto.AuthorizeResponse{ User: &proto.UserInfo{ - Id: int64(token.UserID), + Id: int64(session.UserID), }, }, nil } func (m *usersServer) ResetPassword(ctx context.Context, in *proto.ResetPasswordRequest) (*proto.ResetPasswordResponse, error) { // Find the user. - user, err := m.users.FindByEmail(in.Form.Email) - if err != nil { - return nil, err + var user models.User + tx := m.db.First(&user, "email = ?", in.Form.Email) + if tx.RowsAffected == 0 { + return nil, errors.New("User not found.") } // Generate a reset token. - expires := time.Now().AddDate(0, 0, 1) - token := &Token{ + resetToken := &models.PasswordToken{ UserID: user.ID, Token: uuid.NewString(), - Expires: &expires, + Expires: time.Now().UnixNano(), } // Save the token. - if err := m.resetTokens.Save(token); err != nil { - return nil, err + tx = m.db.Create(resetToken) + if tx.RowsAffected == 0 { + return nil, errors.New("Failed to create token.") } // Return the response. return &proto.ResetPasswordResponse{ Token: &proto.UserToken{ - Token: token.Token, + Token: resetToken.Token, }, }, nil } func (m *usersServer) ChangePassword(ctx context.Context, in *proto.ChangePasswordRequest) (*proto.ChangePasswordResponse, error) { // Find the reset token. - resetToken, err := m.resetTokens.FindByToken(in.Token.Token) - if err != nil { - return nil, err + var resetToken models.PasswordToken + tx := m.db.First(&resetToken, "token = ?", in.Token.Token) + if tx.RowsAffected == 0 { + return nil, errors.New("Token not found.") } // Find the user. - user, err := m.users.FindByID(resetToken.UserID) - if err != nil { - return nil, err + var user models.User + tx = m.db.First(&user, "id = ?", resetToken.UserID) + if tx.RowsAffected == 0 { + return nil, errors.New("User not found.") } // Update the password. @@ -176,15 +183,14 @@ func (m *usersServer) ChangePassword(ctx context.Context, in *proto.ChangePasswo } user.Password = string(bytes) - if err := m.users.Save(user); err != nil { - return nil, err + if tx = m.db.Save(user); tx.RowsAffected == 0 { + return nil, errors.New("Failed to update password.") } // Expire current token. - if token, err := m.tokens.FindByUserID(user.ID); token != nil && err == nil { - expires := time.Now() - token.Expires = &expires - _ = m.tokens.Save(token) + resetToken.Expires = time.Now().UnixNano() + if tx = m.db.Save(&resetToken); tx.RowsAffected == 0 { + return nil, errors.New("Failed to update password.") } // Return the response. diff --git a/server/token_db.go b/server/token_db.go deleted file mode 100644 index d16acb8..0000000 --- a/server/token_db.go +++ /dev/null @@ -1,60 +0,0 @@ -package server - -import ( - "errors" - "time" -) - -type Token struct { - UserID int64 - Token string - Expires *time.Time -} - -type TokenDB interface { - FindByUserID(id int64) (*Token, error) - FindByToken(token string) (*Token, error) - Save(token *Token) error -} - -type inMemoryTokenDB struct { - tokens []*Token -} - -func newInMemoryTokenDB() *inMemoryTokenDB { - return &inMemoryTokenDB{ - tokens: make([]*Token, 0), - } -} - -func (m *inMemoryTokenDB) FindByUserID(id int64) (*Token, error) { - for _, t := range m.tokens { - if t.UserID == id { - return t, nil - } - } - - return nil, errors.New("Token not found.") -} - -func (m *inMemoryTokenDB) FindByToken(token string) (*Token, error) { - for _, t := range m.tokens { - if t.Token == token { - return t, nil - } - } - - return nil, errors.New("Token not found.") -} - -func (m *inMemoryTokenDB) Save(token *Token) error { - for i, t := range m.tokens { - if t.UserID == token.UserID || t.Token == token.Token { - m.tokens[i] = token - return nil - } - } - - m.tokens = append(m.tokens, token) - return nil -} diff --git a/server/user_db.go b/server/user_db.go deleted file mode 100644 index 3fa86ef..0000000 --- a/server/user_db.go +++ /dev/null @@ -1,62 +0,0 @@ -package server - -import "errors" - -type User struct { - ID int64 - Email string - Password string -} - -type UserDB interface { - FindByID(id int64) (*User, error) - FindByEmail(email string) (*User, error) - Save(*User) error -} - -type inMemoryUserDB struct { - UserDB - nextID int64 - users []*User -} - -func newInMemoryUserDB() *inMemoryUserDB { - return &inMemoryUserDB{ - nextID: 1, - users: make([]*User, 0), - } -} - -func (m *inMemoryUserDB) FindByID(id int64) (*User, error) { - for _, u := range m.users { - if u.ID == id { - return u, nil - } - } - - return nil, errors.New("User not found.") -} - -func (m *inMemoryUserDB) FindByEmail(email string) (*User, error) { - for _, u := range m.users { - if u.Email == email { - return u, nil - } - } - - return nil, errors.New("User not found.") -} - -func (m *inMemoryUserDB) Save(user *User) error { - for i, u := range m.users { - if u.ID == user.ID || u.Email == user.Email { - m.users[i] = user - return nil - } - } - - user.ID = m.nextID - m.users = append(m.users, user) - m.nextID++ - return nil -}