package handlers import ( "context" "net/http" "github.com/go-openapi/runtime/client" "github.com/lestrrat-go/jwx/jwk" log "github.com/sirupsen/logrus" "golang.org/x/oauth2" "git.cacert.org/oidc_login/app/services" commonServices "git.cacert.org/oidc_login/common/services" ) const ( sessionKeyAccessToken = iota sessionKeyRefreshToken sessionKeyIdToken sessionRedirectTarget ) type oidcCallbackHandler struct { keySet *jwk.Set logger *log.Logger oauth2Config *oauth2.Config } func (c *oidcCallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) return } if r.URL.Path != "/callback" { http.NotFound(w, r) return } errorText := r.URL.Query().Get("error") errorDescription := r.URL.Query().Get("error_description") if errorText != "" { c.RenderErrorTemplate(w, errorText, errorDescription, http.StatusForbidden) return } code := r.URL.Query().Get("code") ctx := context.Background() httpClient, err := client.TLSClient(client.TLSClientOptions{InsecureSkipVerify: true}) ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) tok, err := c.oauth2Config.Exchange(ctx, code) if err != nil { c.logger.Error(err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } session, err := services.GetSessionStore().Get(r, "resource_session") if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } session.Values[sessionKeyAccessToken] = tok.AccessToken session.Values[sessionKeyRefreshToken] = tok.RefreshToken idToken := tok.Extra("id_token").(string) session.Values[sessionKeyIdToken] = idToken if oidcToken, err := ParseIdToken(idToken, c.keySet); err != nil { c.logger.Error(err) http.Error(w, err.Error(), http.StatusInternalServerError) return } else { c.logger.Debugf(` ID Token ======== Subject: %s Audience: %s Issued at: %s Issued by: %s Not valid before: %s Not valid after: %s `, oidcToken.Subject(), oidcToken.Audience(), oidcToken.IssuedAt(), oidcToken.Issuer(), oidcToken.NotBefore(), oidcToken.Expiration(), ) } if err = session.Save(r, w); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } if redirectTarget, ok := session.Values[sessionRedirectTarget]; ok { w.Header().Set("Location", redirectTarget.(string)) } else { w.Header().Set("Location", "/") } w.WriteHeader(http.StatusFound) } func (c *oidcCallbackHandler) RenderErrorTemplate(w http.ResponseWriter, errorText string, errorDescription string, status int) { if errorDescription != "" { http.Error(w, errorDescription, status) } else { http.Error(w, errorText, status) } } func NewCallbackHandler(ctx context.Context, logger *log.Logger) *oidcCallbackHandler { return &oidcCallbackHandler{ keySet: commonServices.GetJwkSet(ctx), logger: logger, oauth2Config: commonServices.GetOAuth2Config(ctx), } }