diff --git a/cmd/gateway/main.go b/cmd/gateway/main.go index 2a7eaf2..f360824 100644 --- a/cmd/gateway/main.go +++ b/cmd/gateway/main.go @@ -13,11 +13,18 @@ import ( var ( port = flag.Int("port", 8081, "--port=8081") + domain = flag.String("domain", "http://localhost", "--domain=localhost") serverAddr = flag.String("serverAddr", "localhost", "--serverAddr=localhost") serverPort = flag.Int("serverPort", 8080, "--serverPort=8080") serverSecret = flag.String("serverSecret", "...", "--serverSecret=...") ) +func authorized(f http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + + } +} + func main() { // Parse the optional flags. flag.Parse() @@ -37,9 +44,10 @@ func main() { // Setup HTTP endpoints. http.HandleFunc("/register", gateway.Register(client)) http.HandleFunc("/login", gateway.Login(client)) - http.HandleFunc("/authorize", gateway.Authorize(client, serverSecret)) - http.HandleFunc("/reset_password", gateway.ResetPassword(client, port)) - http.HandleFunc("/change_password", gateway.ChangePassword(client)) + http.HandleFunc("/reset_password", gateway.Authorize(client, serverSecret, + gateway.ResetPassword(client, fmt.Sprintf("%s, %d", *domain, *port)))) + http.HandleFunc("/change_password", gateway.Authorize(client, serverSecret, + gateway.ChangePassword(client))) // Listen for requests. log.Printf("Forwarding from :%d to %s:%d", *port, *serverAddr, *serverPort) diff --git a/gateway/gateway.go b/gateway/gateway.go index ab17126..c6e9b16 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -1,6 +1,7 @@ package gateway import ( + "context" "fmt" "net/http" @@ -52,7 +53,7 @@ func Login(client proto.UsersClient) http.HandlerFunc { }) } -func Authorize(client proto.UsersClient, serverSecret *string) http.HandlerFunc { +func Authorize(client proto.UsersClient, serverSecret *string, next http.HandlerFunc) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { res, err := client.Authorize(r.Context(), &proto.AuthorizeRequest{ Secret: *serverSecret, @@ -66,12 +67,14 @@ func Authorize(client proto.UsersClient, serverSecret *string) http.HandlerFunc w.Write([]byte(err.Error())) } - w.WriteHeader(http.StatusOK) - w.Write([]byte(fmt.Sprintf("%d", res.User.Id))) + ctx := context.WithValue(r.Context(), "user", res.User) + ctx = context.WithValue(ctx, "roles", res.Roles) + + next(w, r.WithContext(ctx)) }) } -func ResetPassword(client proto.UsersClient, port *int) http.HandlerFunc { +func ResetPassword(client proto.UsersClient, endpoint string) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { res, err := client.ResetPassword(r.Context(), &proto.ResetPasswordRequest{ Form: &proto.UserForm{ @@ -86,8 +89,8 @@ func ResetPassword(client proto.UsersClient, port *int) http.HandlerFunc { w.WriteHeader(http.StatusOK) w.Write([]byte(fmt.Sprintf( - "Please follow this link to update your password: http://localhost:%d/change_password?token=%s\n", - *port, res.Token.Token))) + "Please follow this link to update your password: %s/change_password?token=%s\n", + *endpoint, res.Token.Token))) }) }