diff --git a/cmd/gateway/main.go b/cmd/gateway/main.go index f360824..e936c03 100644 --- a/cmd/gateway/main.go +++ b/cmd/gateway/main.go @@ -44,10 +44,9 @@ func main() { // Setup HTTP endpoints. http.HandleFunc("/register", gateway.Register(client)) http.HandleFunc("/login", gateway.Login(client)) - http.HandleFunc("/reset_password", gateway.Authorize(client, serverSecret, - gateway.ResetPassword(client, fmt.Sprintf("%s, %d", *domain, *port)))) - http.HandleFunc("/change_password", gateway.Authorize(client, serverSecret, - gateway.ChangePassword(client))) + http.HandleFunc("/logout", gateway.Authorize(client, serverSecret, gateway.Logout(client))) + http.HandleFunc("/reset_password", gateway.ResetPassword(client, fmt.Sprintf("%s, %d", *domain, *port))) + http.HandleFunc("/change_password", gateway.ChangePassword(client)) // Listen for requests. log.Printf("Forwarding from :%d to %s:%d", *port, *serverAddr, *serverPort) diff --git a/gateway/gateway.go b/gateway/gateway.go index c6e9b16..e9588c8 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -53,6 +53,21 @@ func Login(client proto.UsersClient) http.HandlerFunc { }) } +func Logout(client proto.UsersClient) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := client.Logout(r.Context(), &proto.LogoutRequest{ + Token: r.Context().Value("token").(*proto.UserToken), + }) + + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) + } + + w.WriteHeader(http.StatusOK) + }) +} + func Authorize(client proto.UsersClient, serverSecret *string, next http.HandlerFunc) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { res, err := client.Authorize(r.Context(), &proto.AuthorizeRequest{ @@ -69,6 +84,9 @@ func Authorize(client proto.UsersClient, serverSecret *string, next http.Handler ctx := context.WithValue(r.Context(), "user", res.User) ctx = context.WithValue(ctx, "roles", res.Roles) + ctx = context.WithValue(ctx, "token", &proto.UserToken{ + Token: r.URL.Query().Get("token"), + }) next(w, r.WithContext(ctx)) }) @@ -90,7 +108,7 @@ func ResetPassword(client proto.UsersClient, endpoint string) http.HandlerFunc { w.WriteHeader(http.StatusOK) w.Write([]byte(fmt.Sprintf( "Please follow this link to update your password: %s/change_password?token=%s\n", - *endpoint, res.Token.Token))) + endpoint, res.Token.Token))) }) }