From 27e225795cf81386495fda05fa89a4dae3f5f51e Mon Sep 17 00:00:00 2001 From: Jan Dittberner Date: Thu, 31 Dec 2020 13:19:21 +0100 Subject: [PATCH] Refactor app, implement logout --- .gitignore | 3 +- app/handlers/after_logout.go | 33 +++ app/handlers/common.go | 49 ++++ app/handlers/index.go | 79 +++++++ app/handlers/oidc_callback.go | 117 ++++++++++ app/services/session.go | 25 ++ cmd/app/main.go | 422 ++++++++++------------------------ cmd/idp/main.go | 51 ++-- common/services/oidc.go | 51 ++++ common/services/security.go | 19 ++ idp/handlers/consent.go | 16 +- idp/handlers/error.go | 20 ++ idp/handlers/login.go | 53 +++-- idp/handlers/logout.go | 58 +++++ 14 files changed, 647 insertions(+), 349 deletions(-) create mode 100644 app/handlers/after_logout.go create mode 100644 app/handlers/common.go create mode 100644 app/handlers/index.go create mode 100644 app/handlers/oidc_callback.go create mode 100644 app/services/session.go create mode 100644 common/services/oidc.go create mode 100644 common/services/security.go create mode 100644 idp/handlers/error.go create mode 100644 idp/handlers/logout.go diff --git a/.gitignore b/.gitignore index dc675de..b63d656 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ -/.idea/ /*.toml +/.idea/ /certs/ +/sessions/ diff --git a/app/handlers/after_logout.go b/app/handlers/after_logout.go new file mode 100644 index 0000000..bb6650b --- /dev/null +++ b/app/handlers/after_logout.go @@ -0,0 +1,33 @@ +package handlers + +import ( + "net/http" + + "github.com/sirupsen/logrus" + + "git.cacert.org/oidc_login/app/services" +) + +type afterLogoutHandler struct { + logger *logrus.Logger +} + +func (h *afterLogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + session, err := services.GetSessionStore().Get(r, sessionName) + if err != nil { + h.logger.Errorf("could not get session: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + session.Options.MaxAge = -1 + if err = session.Save(r, w); err != nil { + h.logger.Errorf("could not save session: %v", err) + } + + w.Header().Set("Location", "/") + w.WriteHeader(http.StatusFound) +} + +func NewAfterLogoutHandler(logger *logrus.Logger) *afterLogoutHandler { + return &afterLogoutHandler{logger: logger} +} diff --git a/app/handlers/common.go b/app/handlers/common.go new file mode 100644 index 0000000..59a1a53 --- /dev/null +++ b/app/handlers/common.go @@ -0,0 +1,49 @@ +package handlers + +import ( + "encoding/base64" + "net/http" + "net/url" + + "golang.org/x/oauth2" + + "git.cacert.org/oidc_login/app/services" + commonServices "git.cacert.org/oidc_login/common/services" +) + +const sessionName = "resource_session" + +func Authenticate(oauth2Config *oauth2.Config, clientId string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + session, err := services.GetSessionStore().Get(r, sessionName) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if _, ok := session.Values[sessionKeyUserId]; ok { + next.ServeHTTP(w, r) + return + } + session.Values[sessionRedirectTarget] = r.URL.String() + if err = session.Save(r, w); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + var authUrl *url.URL + if authUrl, err = url.Parse(oauth2Config.Endpoint.AuthURL); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + queryValues := authUrl.Query() + queryValues.Set("client_id", clientId) + queryValues.Set("response_type", "code") + queryValues.Set("scope", "openid offline_access profile email") + queryValues.Set("state", base64.URLEncoding.EncodeToString(commonServices.GenerateKey(8))) + authUrl.RawQuery = queryValues.Encode() + + w.Header().Set("Location", authUrl.String()) + w.WriteHeader(http.StatusFound) + }) + } +} diff --git a/app/handlers/index.go b/app/handlers/index.go new file mode 100644 index 0000000..293bc9b --- /dev/null +++ b/app/handlers/index.go @@ -0,0 +1,79 @@ +package handlers + +import ( + "fmt" + "html/template" + "net/http" + "net/url" + + "git.cacert.org/oidc_login/app/services" +) + +type indexHandler struct { + logoutUrl string + serverAddr string +} + +func (h *indexHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + if request.Method != http.MethodGet { + http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return + } + if request.URL.Path != "/" { + http.NotFound(writer, request) + return + } + writer.WriteHeader(http.StatusOK) + + page, err := template.New("").Parse(` + + +Auth test + +

Hello {{ .User }}

+

This is an authorization protected resource

+Logout + + +`) + if err != nil { + http.Error(writer, err.Error(), http.StatusInternalServerError) + return + } + session, err := services.GetSessionStore().Get(request, sessionName) + if err != nil { + http.Error(writer, err.Error(), http.StatusInternalServerError) + return + } + + logoutUrl, err := url.Parse(h.logoutUrl) + if err != nil { + http.Error(writer, err.Error(), http.StatusInternalServerError) + return + } + var user string + var ok bool + if user, ok = session.Values[sessionKeyUsername].(string); ok { + + } + if idToken, ok := session.Values[sessionKeyIdToken].(string); ok { + logoutUrl.RawQuery = url.Values{ + "id_token_hint": []string{idToken}, + "post_logout_redirect_uri": []string{fmt.Sprintf("https://%s/after-logout", h.serverAddr)}, + }.Encode() + } + + writer.Header().Add("Content-Type", "text/html") + err = page.Execute(writer, map[string]interface{}{ + "User": user, + "LogoutURL": logoutUrl.String(), + }) + if err != nil { + http.Error(writer, err.Error(), http.StatusInternalServerError) + return + } +} + +func NewIndexHandler(logoutUrl string, serverAddr string) *indexHandler { + return &indexHandler{logoutUrl: logoutUrl, serverAddr: serverAddr} +} diff --git a/app/handlers/oidc_callback.go b/app/handlers/oidc_callback.go new file mode 100644 index 0000000..cf46dc1 --- /dev/null +++ b/app/handlers/oidc_callback.go @@ -0,0 +1,117 @@ +package handlers + +import ( + "context" + "net/http" + + "github.com/go-openapi/runtime/client" + "github.com/lestrrat-go/jwx/jwk" + "github.com/lestrrat-go/jwx/jwt" + "github.com/sirupsen/logrus" + "golang.org/x/oauth2" + + "git.cacert.org/oidc_login/app/services" +) + +const ( + sessionKeyAccessToken = iota + sessionKeyRefreshToken + sessionKeyIdToken + sessionKeyUserId + sessionKeyRoles + sessionKeyEmail + sessionKeyUsername + sessionRedirectTarget +) + +type oidcCallbackHandler struct { + keySet *jwk.Set + oauth2Config *oauth2.Config +} + +func (c *oidcCallbackHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + if request.Method != http.MethodGet { + http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return + } + if request.URL.Path != "/callback" { + http.NotFound(writer, request) + return + } + + code := request.URL.Query().Get("code") + + ctx := context.Background() + httpClient, err := client.TLSClient(client.TLSClientOptions{InsecureSkipVerify: true}) + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + + tok, err := c.oauth2Config.Exchange(ctx, code) + if err != nil { + logrus.Error(err) + http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + + session, err := services.GetSessionStore().Get(request, "resource_session") + if err != nil { + http.Error(writer, err.Error(), http.StatusInternalServerError) + return + } + + session.Values[sessionKeyAccessToken] = tok.AccessToken + session.Values[sessionKeyRefreshToken] = tok.RefreshToken + session.Values[sessionKeyIdToken] = tok.Extra("id_token").(string) + + idToken := tok.Extra("id_token") + if parsedIdToken, err := jwt.ParseString(idToken.(string), jwt.WithKeySet(c.keySet), jwt.WithOpenIDClaims()); err != nil { + logrus.Error(err) + http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } else { + logrus.Infof(` +ID Token +======== + +Subject: %s +Audience: %s +Issued at: %s +Issued by: %s +Not valid before: %s +Not valid after: %s + +`, + parsedIdToken.Subject(), + parsedIdToken.Audience(), + parsedIdToken.IssuedAt(), + parsedIdToken.Issuer(), + parsedIdToken.NotBefore(), + parsedIdToken.Expiration(), + ) + + session.Values[sessionKeyUserId] = parsedIdToken.Subject() + + if roles, ok := parsedIdToken.Get("Groups"); ok { + session.Values[sessionKeyRoles] = roles + } + if username, ok := parsedIdToken.Get("Username"); ok { + session.Values[sessionKeyUsername] = username + } + if email, ok := parsedIdToken.Get("Email"); ok { + session.Values[sessionKeyEmail] = email + } + } + if err = session.Save(request, writer); err != nil { + http.Error(writer, err.Error(), http.StatusInternalServerError) + } + if redirectTarget, ok := session.Values[sessionRedirectTarget]; ok { + writer.Header().Set("Location", redirectTarget.(string)) + } else { + writer.Header().Set("Location", "/") + } + + writer.WriteHeader(http.StatusFound) +} + +func NewCallbackHandler(keySet *jwk.Set, oauth2Config *oauth2.Config) *oidcCallbackHandler { + return &oidcCallbackHandler{keySet: keySet, oauth2Config: oauth2Config} +} diff --git a/app/services/session.go b/app/services/session.go new file mode 100644 index 0000000..04170b2 --- /dev/null +++ b/app/services/session.go @@ -0,0 +1,25 @@ +package services + +import ( + "os" + + "github.com/gorilla/sessions" + log "github.com/sirupsen/logrus" +) + +var store *sessions.FilesystemStore + +func InitSessionStore(logger *log.Logger, sessionPath string, keys ...[]byte) { + store = sessions.NewFilesystemStore(sessionPath, keys...) + if _, err := os.Stat(sessionPath); err != nil { + if os.IsNotExist(err) { + if err = os.MkdirAll(sessionPath, 0700); err != nil { + logger.Fatalf("could not create session store directory: %s", err) + } + } + } +} + +func GetSessionStore() *sessions.FilesystemStore { + return store +} diff --git a/cmd/app/main.go b/cmd/app/main.go index bce2ad8..7d7349d 100644 --- a/cmd/app/main.go +++ b/cmd/app/main.go @@ -1,111 +1,117 @@ package main import ( - "bytes" "context" - "crypto/rand" + "crypto/tls" "encoding/base64" - "encoding/json" - "html/template" + "fmt" "net/http" - "net/url" "os" + "os/signal" "strings" + "sync/atomic" + "time" - "github.com/go-openapi/runtime/client" - "github.com/gorilla/sessions" "github.com/knadh/koanf" - jsonParser "github.com/knadh/koanf/parsers/json" + "github.com/knadh/koanf/parsers/toml" "github.com/knadh/koanf/providers/confmap" "github.com/knadh/koanf/providers/env" "github.com/knadh/koanf/providers/file" + "github.com/knadh/koanf/providers/posflag" "github.com/lestrrat-go/jwx/jwk" - "github.com/lestrrat-go/jwx/jwt" log "github.com/sirupsen/logrus" + flag "github.com/spf13/pflag" "golang.org/x/oauth2" -) -type OpenIDConfiguration struct { - AuthorizationEndpoint string `json:"authorization_endpoint"` - TokenEndpoint string `json:"token_endpoint"` - JwksUri string `json:"jwks_uri"` - EndSessionEndpoint string `json:"end_session_endpoint"` -} - -var ( - sessionStore *sessions.FilesystemStore - k = koanf.New(".") -) - -const ( - sessionKeyAccessToken = iota - sessionKeyRefreshToken - sessionKeyIdToken - sessionKeyUserId - sessionKeyRoles - sessionKeyEmail - sessionKeyUsername - sessionRedirectTarget + "git.cacert.org/oidc_login/app/handlers" + "git.cacert.org/oidc_login/app/services" + commonHandlers "git.cacert.org/oidc_login/common/handlers" + commonServices "git.cacert.org/oidc_login/common/services" ) func main() { - if err := k.Load(file.Provider("resourceapp.json"), jsonParser.Parser()); err != nil && !os.IsNotExist(err) { + f := flag.NewFlagSet("config", flag.ContinueOnError) + f.Usage = func() { + fmt.Println(f.FlagUsages()) + os.Exit(0) + } + f.StringSlice("conf", []string{"resource_app.toml"}, "path to one or more .toml files") + logger := log.New() + var err error + + if err = f.Parse(os.Args[1:]); err != nil { + logger.Fatal(err) + } + + config := koanf.New(".") + + _ = config.Load(confmap.Provider(map[string]interface{}{ + "server.port": 4000, + "server.name": "app.cacert.localhost", + "server.key": "certs/app.cacert.localhost.key", + "server.certificate": "certs/app.cacert.localhost.crt.pem", + "oidc.server": "https://auth.cacert.localhost:4444/", + "session.path": "sessions/app", + }, "."), nil) + cFiles, _ := f.GetStringSlice("conf") + for _, c := range cFiles { + if err := config.Load(file.Provider(c), toml.Parser()); err != nil { + logger.Fatalf("error loading config file: %s", err) + } + } + if err := config.Load(posflag.Provider(f, ".", config), nil); err != nil { + logger.Fatalf("error loading configuration: %s", err) + } + if err := config.Load(file.Provider("resource_app.toml"), toml.Parser()); err != nil && !os.IsNotExist(err) { log.Fatalf("error loading config: %v", err) } - const prefix = "RESOURCEAPP_" - if err := k.Load(env.Provider(prefix, ".", func(s string) string { + const prefix = "RESOURCE_APP_" + if err := config.Load(env.Provider(prefix, ".", func(s string) string { return strings.Replace(strings.ToLower( strings.TrimPrefix(s, prefix)), "_", ".", -1) }), nil); err != nil { log.Fatalf("error loading config: %v", err) } - oidcServer := k.MustString("oidc.server") - oidcClientId := k.MustString("oidc.client-id") - oidcClientSecret := k.MustString("oidc.client-secret") + oidcServer := config.MustString("oidc.server") + oidcClientId := config.MustString("oidc.client-id") + oidcClientSecret := config.MustString("oidc.client-secret") - sessionPath := k.MustString("session.path") - sessionAuthKey, err := base64.StdEncoding.DecodeString(k.String("session.auth-key")) + sessionPath := config.MustString("session.path") + sessionAuthKey, err := base64.StdEncoding.DecodeString(config.String("session.auth-key")) if err != nil { log.Fatalf("could not decode session auth key: %s", err) } - sessionEncKey, err := base64.StdEncoding.DecodeString(k.String("session.enc-key")) + sessionEncKey, err := base64.StdEncoding.DecodeString(config.String("session.enc-key")) if err != nil { log.Fatalf("could not decode session encryption key: %s", err) } generated := false if len(sessionAuthKey) != 64 { - sessionAuthKey = generateKey(64) + sessionAuthKey = commonServices.GenerateKey(64) generated = true } if len(sessionEncKey) != 32 { - sessionEncKey = generateKey(32) + sessionEncKey = commonServices.GenerateKey(32) generated = true } if generated { - _ = k.Load(confmap.Provider(map[string]interface{}{ + _ = config.Load(confmap.Provider(map[string]interface{}{ "session.auth-key": sessionAuthKey, "session.enc-key": sessionEncKey, }, "."), nil) - jsonData, err := k.Marshal(jsonParser.Parser()) + tomlData, err := config.Marshal(toml.Parser()) if err != nil { log.Fatalf("could not encode session config") } - log.Infof("put the following in your resourceapp.json:\n%s", string(jsonData)) + log.Infof("put the following in your resource_app.toml:\n%s", string(tomlData)) } - var discoveryResponse OpenIDConfiguration - var discoveryUrl *url.URL - - if discoveryUrl, err = url.Parse(oidcServer); err != nil { - log.Fatalf("could not parse oidc.server parameter value %s: %s", oidcServer, err) - } else { - discoveryUrl.Path = "/.well-known/openid-configuration" - } - apiClient, err := client.TLSClient(client.TLSClientOptions{InsecureSkipVerify: true}) - if err := discoverOidc(discoveryUrl, apiClient, &discoveryResponse); err != nil { + var discoveryResponse *commonServices.OpenIDConfiguration + apiClient := &http.Client{} + if discoveryResponse, err = commonServices.DiscoverOIDC(logger, oidcServer, apiClient); err != nil { log.Fatalf("OpenID Connect discovery failed: %s", err) } oauth2Config := &oauth2.Config{ @@ -122,251 +128,69 @@ func main() { log.Fatalf("could not fetch JWKs: %s", err) } - if _, err = os.Stat(sessionPath); err != nil { - if os.IsNotExist(err) { - if err = os.MkdirAll(sessionPath, 0700); err != nil { - log.Fatalf("could not create session store directory: %s", err) - } + services.InitSessionStore(logger, sessionPath, sessionAuthKey, sessionEncKey) + + authMiddleware := handlers.Authenticate(oauth2Config, config.MustString("oidc.client-id")) + + serverAddr := fmt.Sprintf("%s:%d", config.String("server.name"), config.Int("server.port")) + indexHandler := handlers.NewIndexHandler(discoveryResponse.EndSessionEndpoint, serverAddr) + callbackHandler := handlers.NewCallbackHandler(keySet, oauth2Config) + afterLogoutHandler := handlers.NewAfterLogoutHandler(logger) + + router := http.NewServeMux() + router.Handle("/", authMiddleware(indexHandler)) + router.Handle("/callback", callbackHandler) + router.Handle("/after-logout", afterLogoutHandler) + + nextRequestId := func() string { + return fmt.Sprintf("%d", time.Now().UnixNano()) + } + + tracing := commonHandlers.Tracing(nextRequestId) + logging := commonHandlers.Logging(logger) + hsts := commonHandlers.EnableHSTS() + + tlsConfig := &tls.Config{ + ServerName: config.String("server.name"), + MinVersion: tls.VersionTLS12, + } + server := &http.Server{ + Addr: serverAddr, + Handler: tracing(logging(hsts(router))), + ReadTimeout: 5 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 15 * time.Second, + TLSConfig: tlsConfig, + } + + done := make(chan bool) + quit := make(chan os.Signal, 1) + signal.Notify(quit, os.Interrupt) + + go func() { + <-quit + logger.Infoln("Server is shutting down...") + atomic.StoreInt32(&commonHandlers.Healthy, 0) + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + server.SetKeepAlivesEnabled(false) + if err := server.Shutdown(ctx); err != nil { + logger.Fatalf("Could not gracefully shutdown the server: %v\n", err) } - } - sessionStore = sessions.NewFilesystemStore(sessionPath, sessionAuthKey, sessionEncKey) + close(done) + }() - http.Handle("/", authenticate(oauth2Config)(NewIndexPage(discoveryResponse.EndSessionEndpoint))) - http.Handle("/callback", NewCallbackHandler(keySet, oauth2Config)) - - err = http.ListenAndServe(":4000", http.DefaultServeMux) - if err != nil { - log.Fatal(err) + logger.Infof("Server is ready to handle requests at https://%s/", server.Addr) + atomic.StoreInt32(&commonHandlers.Healthy, 1) + if err := server.ListenAndServeTLS( + config.String("server.certificate"), config.String("server.key"), + ); err != nil && err != http.ErrServerClosed { + logger.Fatalf("Could not listen on %s: %v\n", server.Addr, err) } -} - -func generateKey(length int) []byte { - key := make([]byte, length) - read, err := rand.Read(key) - if err != nil { - log.Fatalf("could not generate key: %s", err) - } - if read != length { - log.Fatalf("read %d bytes, expected %d bytes", read, length) - } - return key -} - -func discoverOidc(discoveryUrl *url.URL, apiClient *http.Client, o *OpenIDConfiguration) error { - var body []byte - req, err := http.NewRequest(http.MethodGet, discoveryUrl.String(), bytes.NewBuffer(body)) - if err != nil { - return err - } - req.Header = map[string][]string{ - "Accept": {"application/json"}, - } - - resp, err := apiClient.Do(req) - if err != nil { - return err - } - - dec := json.NewDecoder(resp.Body) - err = dec.Decode(o) - if err != nil { - return err - } - - return nil -} - -type callbackHandler struct { - keySet *jwk.Set - oauth2Config *oauth2.Config -} - -func (c *callbackHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - if request.Method != http.MethodGet { - http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - return - } - if request.URL.Path != "/callback" { - http.NotFound(writer, request) - return - } - - code := request.URL.Query().Get("code") - - ctx := context.Background() - httpClient, err := client.TLSClient(client.TLSClientOptions{InsecureSkipVerify: true}) - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) - - tok, err := c.oauth2Config.Exchange(ctx, code) - if err != nil { - log.Error(err) - http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return - } - - session, err := sessionStore.Get(request, "resource_session") - if err != nil { - http.Error(writer, err.Error(), http.StatusInternalServerError) - return - } - - session.Values[sessionKeyAccessToken] = tok.AccessToken - session.Values[sessionKeyRefreshToken] = tok.RefreshToken - session.Values[sessionKeyIdToken] = tok.Extra("id_token").(string) - - idToken := tok.Extra("id_token") - if parsedIdToken, err := jwt.ParseString(idToken.(string), jwt.WithKeySet(c.keySet), jwt.WithOpenIDClaims()); err != nil { - log.Error(err) - http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return - } else { - log.Infof(` -ID Token -======== - -Subject: %s -Audience: %s -Issued at: %s -Issued by: %s -Not valid before: %s -Not valid after: %s - -`, - parsedIdToken.Subject(), - parsedIdToken.Audience(), - parsedIdToken.IssuedAt(), - parsedIdToken.Issuer(), - parsedIdToken.NotBefore(), - parsedIdToken.Expiration(), - ) - - session.Values[sessionKeyUserId] = parsedIdToken.Subject() - - if roles, ok := parsedIdToken.Get("Groups"); ok { - session.Values[sessionKeyRoles] = roles - } - if username, ok := parsedIdToken.Get("Username"); ok { - session.Values[sessionKeyUsername] = username - } - if email, ok := parsedIdToken.Get("Email"); ok { - session.Values[sessionKeyEmail] = email - } - } - if err = session.Save(request, writer); err != nil { - http.Error(writer, err.Error(), http.StatusInternalServerError) - } - if redirectTarget, ok := session.Values[sessionRedirectTarget]; ok { - writer.Header().Set("Location", redirectTarget.(string)) - } else { - writer.Header().Set("Location", "/") - } - - writer.WriteHeader(http.StatusFound) -} - -func NewCallbackHandler(keySet *jwk.Set, oauth2Config *oauth2.Config) *callbackHandler { - return &callbackHandler{keySet: keySet, oauth2Config: oauth2Config} -} - -func authenticate(oauth2Config *oauth2.Config) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - session, err := sessionStore.Get(r, "resource_session") - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if _, ok := session.Values[sessionKeyUserId]; ok { - next.ServeHTTP(w, r) - return - } - session.Values[sessionRedirectTarget] = r.URL.String() - if err = session.Save(r, w); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - var authUrl *url.URL - if authUrl, err = url.Parse(oauth2Config.Endpoint.AuthURL); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - queryValues := authUrl.Query() - queryValues.Set("client_id", k.String("oidc.client-id")) - queryValues.Set("response_type", "code") - queryValues.Set("scope", "openid offline") - queryValues.Set("state", base64.URLEncoding.EncodeToString(generateKey(8))) - authUrl.RawQuery = queryValues.Encode() - - w.Header().Set("Location", authUrl.String()) - w.WriteHeader(http.StatusFound) - }) - } -} - -type indexHandler struct { - logoutUrl string -} - -func (i indexHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - if request.Method != http.MethodGet { - http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - return - } - if request.URL.Path != "/" { - http.NotFound(writer, request) - return - } - writer.WriteHeader(http.StatusOK) - - page, err := template.New("").Parse(` - - -Auth test - -

Hello {{ .User }}

-

This is an authorization protected resource

-Logout - - -`) - if err != nil { - http.Error(writer, err.Error(), http.StatusInternalServerError) - return - } - session, err := sessionStore.Get(request, "resource_session") - if err != nil { - http.Error(writer, err.Error(), http.StatusInternalServerError) - return - } - - logoutUrl, err := url.Parse(i.logoutUrl) - if err != nil { - http.Error(writer, err.Error(), http.StatusInternalServerError) - return - } - var user string - var ok bool - if user, ok = session.Values[sessionKeyUsername].(string); ok { - - } - if idToken, ok := session.Values[sessionKeyIdToken].(string); ok { - logoutUrl.RawQuery = url.Values{ - "id_token_hint": []string{idToken}, - "post_logout_redirect_uri": []string{"/logged_out"}, - }.Encode() - } - - writer.Header().Add("Content-Type", "text/html") - err = page.Execute(writer, map[string]interface{}{ - "User": user, - "LogoutURL": logoutUrl.String(), - }) - if err != nil { - http.Error(writer, err.Error(), http.StatusInternalServerError) - return - } -} - -func NewIndexPage(logoutUrl string) *indexHandler { - return &indexHandler{logoutUrl: logoutUrl} + + <-done + logger.Infoln("Server stopped") } diff --git a/cmd/idp/main.go b/cmd/idp/main.go index 16ff6f2..2a3bf05 100644 --- a/cmd/idp/main.go +++ b/cmd/idp/main.go @@ -9,15 +9,16 @@ import ( "net/url" "os" "os/signal" + "strings" "sync/atomic" "time" "github.com/go-openapi/runtime/client" "github.com/gorilla/csrf" "github.com/knadh/koanf" - "github.com/knadh/koanf/parsers/json" "github.com/knadh/koanf/parsers/toml" "github.com/knadh/koanf/providers/confmap" + "github.com/knadh/koanf/providers/env" "github.com/knadh/koanf/providers/file" "github.com/knadh/koanf/providers/posflag" hydra "github.com/ory/hydra-client-go/client" @@ -45,12 +46,6 @@ func main() { config := koanf.New(".") - cFiles, _ := f.GetStringSlice("conf") - for _, c := range cFiles { - if err := config.Load(file.Provider(c), toml.Parser()); err != nil { - logger.Fatalf("error loading config file: %s", err) - } - } _ = config.Load(confmap.Provider(map[string]interface{}{ "server.port": 3000, "server.name": "login.cacert.localhost", @@ -58,10 +53,22 @@ func main() { "server.certificate": "certs/idp.cacert.localhost.crt.pem", "admin.url": "https://hydra.cacert.localhost:4445/", }, "."), nil) - _ = config.Load(file.Provider("idp.json"), json.Parser()) + cFiles, _ := f.GetStringSlice("conf") + for _, c := range cFiles { + if err := config.Load(file.Provider(c), toml.Parser()); err != nil { + logger.Fatalf("error loading config file: %s", err) + } + } if err := config.Load(posflag.Provider(f, ".", config), nil); err != nil { logger.Fatalf("error loading configuration: %s", err) } + const prefix = "IDP_" + if err := config.Load(env.Provider(prefix, ".", func(s string) string { + return strings.Replace(strings.ToLower( + strings.TrimPrefix(s, prefix)), "_", ".", -1) + }), nil); err != nil { + log.Fatalf("error loading config: %v", err) + } logger.Infoln("Server is starting") ctx := context.Background() @@ -75,15 +82,21 @@ func main() { adminClient := hydra.New(clientTransport, nil) handlerContext := context.WithValue(ctx, handlers.CtxAdminClient, adminClient.Admin) - loginHandler, err := handlers.NewLoginHandler(handlerContext) + loginHandler, err := handlers.NewLoginHandler(logger, handlerContext) if err != nil { logger.Fatalf("error initializing login handler: %v", err) } - consentHandler := handlers.NewConsentHandler(handlerContext) + consentHandler := handlers.NewConsentHandler(logger, handlerContext) + logoutHandler := handlers.NewLogoutHandler(logger, handlerContext) + logoutSuccessHandler := handlers.NewLogoutSuccessHandler() + errorHandler := handlers.NewErrorHandler() router := http.NewServeMux() router.Handle("/login", loginHandler) router.Handle("/consent", consentHandler) + router.Handle("/logout", logoutHandler) + router.Handle("/error", errorHandler) + router.Handle("/logout-successful", logoutSuccessHandler) router.Handle("/health", commonHandlers.NewHealthHandler()) if err != nil { @@ -94,23 +107,27 @@ func main() { if err != nil { logger.Fatalf("could not parse CSRF key bytes: %v", err) } - handler := csrf.Protect(csrfKey, csrf.Secure(true))(router) nextRequestId := func() string { return fmt.Sprintf("%d", time.Now().UnixNano()) } + tracing := commonHandlers.Tracing(nextRequestId) + logging := commonHandlers.Logging(logger) + hsts := commonHandlers.EnableHSTS() + csrfProtect := csrf.Protect( + csrfKey, + csrf.Secure(true), + csrf.SameSite(csrf.SameSiteStrictMode), + csrf.MaxAge(600)) + tlsConfig := &tls.Config{ ServerName: config.String("server.name"), MinVersion: tls.VersionTLS12, } server := &http.Server{ - Addr: fmt.Sprintf("%s:%d", config.String("server.name"), config.Int("server.port")), - Handler: commonHandlers.Tracing(nextRequestId)( - commonHandlers.Logging(logger)( - commonHandlers.EnableHSTS()(handler), - ), - ), + Addr: fmt.Sprintf("%s:%d", config.String("server.name"), config.Int("server.port")), + Handler: tracing(logging(hsts(csrfProtect(router)))), ReadTimeout: 5 * time.Second, WriteTimeout: 10 * time.Second, IdleTimeout: 15 * time.Second, diff --git a/common/services/oidc.go b/common/services/oidc.go new file mode 100644 index 0000000..e2db15d --- /dev/null +++ b/common/services/oidc.go @@ -0,0 +1,51 @@ +package services + +import ( + "bytes" + "encoding/json" + "net/http" + "net/url" + + log "github.com/sirupsen/logrus" +) + +type OpenIDConfiguration struct { + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + JwksUri string `json:"jwks_uri"` + EndSessionEndpoint string `json:"end_session_endpoint"` +} + +func DiscoverOIDC(logger *log.Logger, oidcServer string, apiClient *http.Client) (o *OpenIDConfiguration, err error) { + var discoveryUrl *url.URL + + if discoveryUrl, err = url.Parse(oidcServer); err != nil { + logger.Fatalf("could not parse oidc.server parameter value %s: %s", oidcServer, err) + } else { + discoveryUrl.Path = "/.well-known/openid-configuration" + } + + var body []byte + var req *http.Request + req, err = http.NewRequest(http.MethodGet, discoveryUrl.String(), bytes.NewBuffer(body)) + if err != nil { + return + } + req.Header = map[string][]string{ + "Accept": {"application/json"}, + } + + resp, err := apiClient.Do(req) + if err != nil { + return + } + + dec := json.NewDecoder(resp.Body) + o = &OpenIDConfiguration{} + err = dec.Decode(o) + if err != nil { + return + } + + return +} diff --git a/common/services/security.go b/common/services/security.go new file mode 100644 index 0000000..b4e454e --- /dev/null +++ b/common/services/security.go @@ -0,0 +1,19 @@ +package services + +import ( + "crypto/rand" + + log "github.com/sirupsen/logrus" +) + +func GenerateKey(length int) []byte { + key := make([]byte, length) + read, err := rand.Read(key) + if err != nil { + log.Fatalf("could not generate key: %s", err) + } + if read != length { + log.Fatalf("read %d bytes, expected %d bytes", read, length) + } + return key +} diff --git a/idp/handlers/consent.go b/idp/handlers/consent.go index 874884a..3e78189 100644 --- a/idp/handlers/consent.go +++ b/idp/handlers/consent.go @@ -12,11 +12,12 @@ import ( type consentHandler struct { adminClient *admin.Client + logger *log.Logger } -func (c *consentHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - consentChallenge := request.URL.Query().Get("consent_challenge") - consentRequest, err := c.adminClient.AcceptConsentRequest( +func (h *consentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + consentChallenge := r.URL.Query().Get("consent_challenge") + consentRequest, err := h.adminClient.AcceptConsentRequest( admin.NewAcceptConsentRequestParams().WithConsentChallenge(consentChallenge).WithBody( &models.AcceptConsentRequest{ GrantAccessTokenAudience: nil, @@ -26,14 +27,15 @@ func (c *consentHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req RememberFor: 86400, }).WithTimeout(time.Second * 10)) if err != nil { - log.Panic(err) + h.logger.Panic(err) } - writer.Header().Add("Location", *consentRequest.GetPayload().RedirectTo) - writer.WriteHeader(http.StatusFound) + w.Header().Add("Location", *consentRequest.GetPayload().RedirectTo) + w.WriteHeader(http.StatusFound) } -func NewConsentHandler(ctx context.Context) *consentHandler { +func NewConsentHandler(logger *log.Logger, ctx context.Context) *consentHandler { return &consentHandler{ + logger: logger, adminClient: ctx.Value(CtxAdminClient).(*admin.Client), } } diff --git a/idp/handlers/error.go b/idp/handlers/error.go new file mode 100644 index 0000000..1c318e7 --- /dev/null +++ b/idp/handlers/error.go @@ -0,0 +1,20 @@ +package handlers + +import ( + "fmt" + "net/http" +) + +type errorHandler struct { +} + +func (e *errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + _, _ = fmt.Fprintf(w, ` +didumm %#v +`, r.URL.Query()) +} + +func NewErrorHandler() *errorHandler { + return &errorHandler{} +} diff --git a/idp/handlers/login.go b/idp/handlers/login.go index e4f65fa..15eb76b 100644 --- a/idp/handlers/login.go +++ b/idp/handlers/login.go @@ -23,6 +23,7 @@ type loginHandler struct { bundle *i18n.Bundle messageCatalog map[string]*i18n.Message adminClient *admin.Client + logger *log.Logger } type LoginInformation struct { @@ -30,27 +31,27 @@ type LoginInformation struct { Password string `form:"password" validate:"required"` } -func (h *loginHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { +func (h *loginHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { var err error - challenge := request.URL.Query().Get("login_challenge") - log.Debugf("received challenge %s\n", challenge) + challenge := r.URL.Query().Get("login_challenge") + h.logger.Debugf("received challenge %s\n", challenge) validate := validator.New() - switch request.Method { + switch r.Method { case http.MethodGet: // GET should render login form - err = h.loginTemplate.Lookup("base").Execute(writer, map[string]interface{}{ + err = h.loginTemplate.Lookup("base").Execute(w, map[string]interface{}{ "Title": "Title", - csrf.TemplateTag: csrf.TemplateField(request), + csrf.TemplateTag: csrf.TemplateField(r), "LabelEmail": "Email", "LabelPassword": "Password", "LabelLogin": "Login", "errors": map[string]string{}, }) if err != nil { - log.Error(err) - http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + h.logger.Error(err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } break @@ -60,23 +61,23 @@ func (h *loginHandler) ServeHTTP(writer http.ResponseWriter, request *http.Reque // validate input decoder := form.NewDecoder() - err = decoder.Decode(&loginInfo, request.Form) + err = decoder.Decode(&loginInfo, r.Form) if err != nil { - log.Error(err) - http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + h.logger.Error(err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } err := validate.Struct(&loginInfo) if err != nil { errors := make(map[string]string) for _, err := range err.(validator.ValidationErrors) { - accept := request.Header.Get("Accept-Language") + accept := r.Header.Get("Accept-Language") errors[err.Field()] = h.lookupErrorMessage(err.Tag(), err.Field(), err.Value(), i18n.NewLocalizer(h.bundle, accept)) } - err = h.loginTemplate.Lookup("base").Execute(writer, map[string]interface{}{ + err = h.loginTemplate.Lookup("base").Execute(w, map[string]interface{}{ "Title": "Title", - csrf.TemplateTag: csrf.TemplateField(request), + csrf.TemplateTag: csrf.TemplateField(r), "LabelEmail": "Email", "LabelPassword": "Password", "LabelLogin": "Login", @@ -84,8 +85,8 @@ func (h *loginHandler) ServeHTTP(writer http.ResponseWriter, request *http.Reque "errors": errors, }) if err != nil { - log.Error(err) - http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + h.logger.Error(err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } return @@ -103,13 +104,14 @@ func (h *loginHandler) ServeHTTP(writer http.ResponseWriter, request *http.Reque Subject: &subject, }).WithTimeout(time.Second * 10)) if err != nil { - log.Panic(err) + h.logger.Errorf("error getting logout requests: %v", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } - writer.Header().Add("Location", *loginRequest.GetPayload().RedirectTo) - writer.WriteHeader(http.StatusFound) + w.Header().Add("Location", *loginRequest.GetPayload().RedirectTo) + w.WriteHeader(http.StatusFound) break default: - http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) return } } @@ -118,13 +120,13 @@ func (h *loginHandler) lookupErrorMessage(tag string, field string, value interf var message *i18n.Message message, ok := h.messageCatalog[fmt.Sprintf("%s-%s", field, tag)] if !ok { - log.Infof("no specific error message %s-%s", field, tag) + h.logger.Infof("no specific error message %s-%s", field, tag) message, ok = h.messageCatalog[tag] if !ok { - log.Infof("no specific error message %s", tag) + h.logger.Infof("no specific error message %s", tag) message, ok = h.messageCatalog["unknown"] if !ok { - log.Error("no default translation found") + h.logger.Error("no default translation found") return tag } } @@ -137,18 +139,19 @@ func (h *loginHandler) lookupErrorMessage(tag string, field string, value interf }, }) if err != nil { - log.Error(err) + h.logger.Error(err) return tag } return translation } -func NewLoginHandler(ctx context.Context) (*loginHandler, error) { +func NewLoginHandler(logger *log.Logger, ctx context.Context) (*loginHandler, error) { loginTemplate, err := template.ParseFiles("templates/base.html", "templates/login.html") if err != nil { return nil, err } return &loginHandler{ + logger: logger, loginTemplate: loginTemplate, bundle: ctx.Value(services.CtxI18nBundle).(*i18n.Bundle), messageCatalog: ctx.Value(services.CtxI18nCatalog).(map[string]*i18n.Message), diff --git a/idp/handlers/logout.go b/idp/handlers/logout.go new file mode 100644 index 0000000..445a569 --- /dev/null +++ b/idp/handlers/logout.go @@ -0,0 +1,58 @@ +package handlers + +import ( + "context" + "net/http" + "time" + + "github.com/ory/hydra-client-go/client/admin" + log "github.com/sirupsen/logrus" +) + +type logoutHandler struct { + adminClient *admin.Client + logger *log.Logger +} + +func (h *logoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + challenge := r.URL.Query().Get("logout_challenge") + h.logger.Debugf("received challenge %s\n", challenge) + + logoutRequest, err := h.adminClient.GetLogoutRequest( + admin.NewGetLogoutRequestParams().WithLogoutChallenge(challenge).WithTimeout(time.Second * 10)) + if err != nil { + h.logger.Errorf("error getting logout requests: %v", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + + h.logger.Debugf("received logout request: %#v", logoutRequest.Payload) + + acceptLogoutRequest, err := h.adminClient.AcceptLogoutRequest( + admin.NewAcceptLogoutRequestParams().WithLogoutChallenge(challenge)) + if err != nil { + h.logger.Errorf("error accepting logout: %v", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } + + w.Header().Set("Location", *acceptLogoutRequest.GetPayload().RedirectTo) + w.WriteHeader(http.StatusFound) +} + +func NewLogoutHandler(logger *log.Logger, ctx context.Context) *logoutHandler { + return &logoutHandler{ + logger: logger, + adminClient: ctx.Value(CtxAdminClient).(*admin.Client), + } +} + +type logoutSuccessHandler struct { +} + +func (l *logoutSuccessHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + panic("implement me") +} + +func NewLogoutSuccessHandler() *logoutSuccessHandler { + return &logoutSuccessHandler{} +}