Refactor app, implement logout

This commit is contained in:
Jan Dittberner 2020-12-31 13:19:21 +01:00
parent ce1fac0e68
commit 27e225795c
14 changed files with 647 additions and 349 deletions

3
.gitignore vendored
View file

@ -1,3 +1,4 @@
/.idea/
/*.toml /*.toml
/.idea/
/certs/ /certs/
/sessions/

View file

@ -0,0 +1,33 @@
package handlers
import (
"net/http"
"github.com/sirupsen/logrus"
"git.cacert.org/oidc_login/app/services"
)
type afterLogoutHandler struct {
logger *logrus.Logger
}
func (h *afterLogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
session, err := services.GetSessionStore().Get(r, sessionName)
if err != nil {
h.logger.Errorf("could not get session: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
session.Options.MaxAge = -1
if err = session.Save(r, w); err != nil {
h.logger.Errorf("could not save session: %v", err)
}
w.Header().Set("Location", "/")
w.WriteHeader(http.StatusFound)
}
func NewAfterLogoutHandler(logger *logrus.Logger) *afterLogoutHandler {
return &afterLogoutHandler{logger: logger}
}

49
app/handlers/common.go Normal file
View file

@ -0,0 +1,49 @@
package handlers
import (
"encoding/base64"
"net/http"
"net/url"
"golang.org/x/oauth2"
"git.cacert.org/oidc_login/app/services"
commonServices "git.cacert.org/oidc_login/common/services"
)
const sessionName = "resource_session"
func Authenticate(oauth2Config *oauth2.Config, clientId string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session, err := services.GetSessionStore().Get(r, sessionName)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if _, ok := session.Values[sessionKeyUserId]; ok {
next.ServeHTTP(w, r)
return
}
session.Values[sessionRedirectTarget] = r.URL.String()
if err = session.Save(r, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var authUrl *url.URL
if authUrl, err = url.Parse(oauth2Config.Endpoint.AuthURL); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
queryValues := authUrl.Query()
queryValues.Set("client_id", clientId)
queryValues.Set("response_type", "code")
queryValues.Set("scope", "openid offline_access profile email")
queryValues.Set("state", base64.URLEncoding.EncodeToString(commonServices.GenerateKey(8)))
authUrl.RawQuery = queryValues.Encode()
w.Header().Set("Location", authUrl.String())
w.WriteHeader(http.StatusFound)
})
}
}

79
app/handlers/index.go Normal file
View file

@ -0,0 +1,79 @@
package handlers
import (
"fmt"
"html/template"
"net/http"
"net/url"
"git.cacert.org/oidc_login/app/services"
)
type indexHandler struct {
logoutUrl string
serverAddr string
}
func (h *indexHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
if request.Method != http.MethodGet {
http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
if request.URL.Path != "/" {
http.NotFound(writer, request)
return
}
writer.WriteHeader(http.StatusOK)
page, err := template.New("").Parse(`
<!DOCTYPE html>
<html lang="en">
<head><title>Auth test</title></head>
<body>
<h1>Hello {{ .User }}</h1>
<p>This is an authorization protected resource</p>
<a href="{{ .LogoutURL }}">Logout</a>
</body>
</html>
`)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
session, err := services.GetSessionStore().Get(request, sessionName)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
logoutUrl, err := url.Parse(h.logoutUrl)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
var user string
var ok bool
if user, ok = session.Values[sessionKeyUsername].(string); ok {
}
if idToken, ok := session.Values[sessionKeyIdToken].(string); ok {
logoutUrl.RawQuery = url.Values{
"id_token_hint": []string{idToken},
"post_logout_redirect_uri": []string{fmt.Sprintf("https://%s/after-logout", h.serverAddr)},
}.Encode()
}
writer.Header().Add("Content-Type", "text/html")
err = page.Execute(writer, map[string]interface{}{
"User": user,
"LogoutURL": logoutUrl.String(),
})
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
}
func NewIndexHandler(logoutUrl string, serverAddr string) *indexHandler {
return &indexHandler{logoutUrl: logoutUrl, serverAddr: serverAddr}
}

View file

@ -0,0 +1,117 @@
package handlers
import (
"context"
"net/http"
"github.com/go-openapi/runtime/client"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
"github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"git.cacert.org/oidc_login/app/services"
)
const (
sessionKeyAccessToken = iota
sessionKeyRefreshToken
sessionKeyIdToken
sessionKeyUserId
sessionKeyRoles
sessionKeyEmail
sessionKeyUsername
sessionRedirectTarget
)
type oidcCallbackHandler struct {
keySet *jwk.Set
oauth2Config *oauth2.Config
}
func (c *oidcCallbackHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
if request.Method != http.MethodGet {
http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
if request.URL.Path != "/callback" {
http.NotFound(writer, request)
return
}
code := request.URL.Query().Get("code")
ctx := context.Background()
httpClient, err := client.TLSClient(client.TLSClientOptions{InsecureSkipVerify: true})
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
tok, err := c.oauth2Config.Exchange(ctx, code)
if err != nil {
logrus.Error(err)
http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
session, err := services.GetSessionStore().Get(request, "resource_session")
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
session.Values[sessionKeyAccessToken] = tok.AccessToken
session.Values[sessionKeyRefreshToken] = tok.RefreshToken
session.Values[sessionKeyIdToken] = tok.Extra("id_token").(string)
idToken := tok.Extra("id_token")
if parsedIdToken, err := jwt.ParseString(idToken.(string), jwt.WithKeySet(c.keySet), jwt.WithOpenIDClaims()); err != nil {
logrus.Error(err)
http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
} else {
logrus.Infof(`
ID Token
========
Subject: %s
Audience: %s
Issued at: %s
Issued by: %s
Not valid before: %s
Not valid after: %s
`,
parsedIdToken.Subject(),
parsedIdToken.Audience(),
parsedIdToken.IssuedAt(),
parsedIdToken.Issuer(),
parsedIdToken.NotBefore(),
parsedIdToken.Expiration(),
)
session.Values[sessionKeyUserId] = parsedIdToken.Subject()
if roles, ok := parsedIdToken.Get("Groups"); ok {
session.Values[sessionKeyRoles] = roles
}
if username, ok := parsedIdToken.Get("Username"); ok {
session.Values[sessionKeyUsername] = username
}
if email, ok := parsedIdToken.Get("Email"); ok {
session.Values[sessionKeyEmail] = email
}
}
if err = session.Save(request, writer); err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
}
if redirectTarget, ok := session.Values[sessionRedirectTarget]; ok {
writer.Header().Set("Location", redirectTarget.(string))
} else {
writer.Header().Set("Location", "/")
}
writer.WriteHeader(http.StatusFound)
}
func NewCallbackHandler(keySet *jwk.Set, oauth2Config *oauth2.Config) *oidcCallbackHandler {
return &oidcCallbackHandler{keySet: keySet, oauth2Config: oauth2Config}
}

25
app/services/session.go Normal file
View file

@ -0,0 +1,25 @@
package services
import (
"os"
"github.com/gorilla/sessions"
log "github.com/sirupsen/logrus"
)
var store *sessions.FilesystemStore
func InitSessionStore(logger *log.Logger, sessionPath string, keys ...[]byte) {
store = sessions.NewFilesystemStore(sessionPath, keys...)
if _, err := os.Stat(sessionPath); err != nil {
if os.IsNotExist(err) {
if err = os.MkdirAll(sessionPath, 0700); err != nil {
logger.Fatalf("could not create session store directory: %s", err)
}
}
}
}
func GetSessionStore() *sessions.FilesystemStore {
return store
}

View file

@ -1,111 +1,117 @@
package main package main
import ( import (
"bytes"
"context" "context"
"crypto/rand" "crypto/tls"
"encoding/base64" "encoding/base64"
"encoding/json" "fmt"
"html/template"
"net/http" "net/http"
"net/url"
"os" "os"
"os/signal"
"strings" "strings"
"sync/atomic"
"time"
"github.com/go-openapi/runtime/client"
"github.com/gorilla/sessions"
"github.com/knadh/koanf" "github.com/knadh/koanf"
jsonParser "github.com/knadh/koanf/parsers/json" "github.com/knadh/koanf/parsers/toml"
"github.com/knadh/koanf/providers/confmap" "github.com/knadh/koanf/providers/confmap"
"github.com/knadh/koanf/providers/env" "github.com/knadh/koanf/providers/env"
"github.com/knadh/koanf/providers/file" "github.com/knadh/koanf/providers/file"
"github.com/knadh/koanf/providers/posflag"
"github.com/lestrrat-go/jwx/jwk" "github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
flag "github.com/spf13/pflag"
"golang.org/x/oauth2" "golang.org/x/oauth2"
)
type OpenIDConfiguration struct { "git.cacert.org/oidc_login/app/handlers"
AuthorizationEndpoint string `json:"authorization_endpoint"` "git.cacert.org/oidc_login/app/services"
TokenEndpoint string `json:"token_endpoint"` commonHandlers "git.cacert.org/oidc_login/common/handlers"
JwksUri string `json:"jwks_uri"` commonServices "git.cacert.org/oidc_login/common/services"
EndSessionEndpoint string `json:"end_session_endpoint"`
}
var (
sessionStore *sessions.FilesystemStore
k = koanf.New(".")
)
const (
sessionKeyAccessToken = iota
sessionKeyRefreshToken
sessionKeyIdToken
sessionKeyUserId
sessionKeyRoles
sessionKeyEmail
sessionKeyUsername
sessionRedirectTarget
) )
func main() { func main() {
if err := k.Load(file.Provider("resourceapp.json"), jsonParser.Parser()); err != nil && !os.IsNotExist(err) { f := flag.NewFlagSet("config", flag.ContinueOnError)
f.Usage = func() {
fmt.Println(f.FlagUsages())
os.Exit(0)
}
f.StringSlice("conf", []string{"resource_app.toml"}, "path to one or more .toml files")
logger := log.New()
var err error
if err = f.Parse(os.Args[1:]); err != nil {
logger.Fatal(err)
}
config := koanf.New(".")
_ = config.Load(confmap.Provider(map[string]interface{}{
"server.port": 4000,
"server.name": "app.cacert.localhost",
"server.key": "certs/app.cacert.localhost.key",
"server.certificate": "certs/app.cacert.localhost.crt.pem",
"oidc.server": "https://auth.cacert.localhost:4444/",
"session.path": "sessions/app",
}, "."), nil)
cFiles, _ := f.GetStringSlice("conf")
for _, c := range cFiles {
if err := config.Load(file.Provider(c), toml.Parser()); err != nil {
logger.Fatalf("error loading config file: %s", err)
}
}
if err := config.Load(posflag.Provider(f, ".", config), nil); err != nil {
logger.Fatalf("error loading configuration: %s", err)
}
if err := config.Load(file.Provider("resource_app.toml"), toml.Parser()); err != nil && !os.IsNotExist(err) {
log.Fatalf("error loading config: %v", err) log.Fatalf("error loading config: %v", err)
} }
const prefix = "RESOURCEAPP_" const prefix = "RESOURCE_APP_"
if err := k.Load(env.Provider(prefix, ".", func(s string) string { if err := config.Load(env.Provider(prefix, ".", func(s string) string {
return strings.Replace(strings.ToLower( return strings.Replace(strings.ToLower(
strings.TrimPrefix(s, prefix)), "_", ".", -1) strings.TrimPrefix(s, prefix)), "_", ".", -1)
}), nil); err != nil { }), nil); err != nil {
log.Fatalf("error loading config: %v", err) log.Fatalf("error loading config: %v", err)
} }
oidcServer := k.MustString("oidc.server") oidcServer := config.MustString("oidc.server")
oidcClientId := k.MustString("oidc.client-id") oidcClientId := config.MustString("oidc.client-id")
oidcClientSecret := k.MustString("oidc.client-secret") oidcClientSecret := config.MustString("oidc.client-secret")
sessionPath := k.MustString("session.path") sessionPath := config.MustString("session.path")
sessionAuthKey, err := base64.StdEncoding.DecodeString(k.String("session.auth-key")) sessionAuthKey, err := base64.StdEncoding.DecodeString(config.String("session.auth-key"))
if err != nil { if err != nil {
log.Fatalf("could not decode session auth key: %s", err) log.Fatalf("could not decode session auth key: %s", err)
} }
sessionEncKey, err := base64.StdEncoding.DecodeString(k.String("session.enc-key")) sessionEncKey, err := base64.StdEncoding.DecodeString(config.String("session.enc-key"))
if err != nil { if err != nil {
log.Fatalf("could not decode session encryption key: %s", err) log.Fatalf("could not decode session encryption key: %s", err)
} }
generated := false generated := false
if len(sessionAuthKey) != 64 { if len(sessionAuthKey) != 64 {
sessionAuthKey = generateKey(64) sessionAuthKey = commonServices.GenerateKey(64)
generated = true generated = true
} }
if len(sessionEncKey) != 32 { if len(sessionEncKey) != 32 {
sessionEncKey = generateKey(32) sessionEncKey = commonServices.GenerateKey(32)
generated = true generated = true
} }
if generated { if generated {
_ = k.Load(confmap.Provider(map[string]interface{}{ _ = config.Load(confmap.Provider(map[string]interface{}{
"session.auth-key": sessionAuthKey, "session.auth-key": sessionAuthKey,
"session.enc-key": sessionEncKey, "session.enc-key": sessionEncKey,
}, "."), nil) }, "."), nil)
jsonData, err := k.Marshal(jsonParser.Parser()) tomlData, err := config.Marshal(toml.Parser())
if err != nil { if err != nil {
log.Fatalf("could not encode session config") log.Fatalf("could not encode session config")
} }
log.Infof("put the following in your resourceapp.json:\n%s", string(jsonData)) log.Infof("put the following in your resource_app.toml:\n%s", string(tomlData))
} }
var discoveryResponse OpenIDConfiguration var discoveryResponse *commonServices.OpenIDConfiguration
var discoveryUrl *url.URL apiClient := &http.Client{}
if discoveryResponse, err = commonServices.DiscoverOIDC(logger, oidcServer, apiClient); err != nil {
if discoveryUrl, err = url.Parse(oidcServer); err != nil {
log.Fatalf("could not parse oidc.server parameter value %s: %s", oidcServer, err)
} else {
discoveryUrl.Path = "/.well-known/openid-configuration"
}
apiClient, err := client.TLSClient(client.TLSClientOptions{InsecureSkipVerify: true})
if err := discoverOidc(discoveryUrl, apiClient, &discoveryResponse); err != nil {
log.Fatalf("OpenID Connect discovery failed: %s", err) log.Fatalf("OpenID Connect discovery failed: %s", err)
} }
oauth2Config := &oauth2.Config{ oauth2Config := &oauth2.Config{
@ -122,251 +128,69 @@ func main() {
log.Fatalf("could not fetch JWKs: %s", err) log.Fatalf("could not fetch JWKs: %s", err)
} }
if _, err = os.Stat(sessionPath); err != nil { services.InitSessionStore(logger, sessionPath, sessionAuthKey, sessionEncKey)
if os.IsNotExist(err) {
if err = os.MkdirAll(sessionPath, 0700); err != nil {
log.Fatalf("could not create session store directory: %s", err)
}
}
}
sessionStore = sessions.NewFilesystemStore(sessionPath, sessionAuthKey, sessionEncKey)
http.Handle("/", authenticate(oauth2Config)(NewIndexPage(discoveryResponse.EndSessionEndpoint))) authMiddleware := handlers.Authenticate(oauth2Config, config.MustString("oidc.client-id"))
http.Handle("/callback", NewCallbackHandler(keySet, oauth2Config))
err = http.ListenAndServe(":4000", http.DefaultServeMux) serverAddr := fmt.Sprintf("%s:%d", config.String("server.name"), config.Int("server.port"))
if err != nil { indexHandler := handlers.NewIndexHandler(discoveryResponse.EndSessionEndpoint, serverAddr)
log.Fatal(err) callbackHandler := handlers.NewCallbackHandler(keySet, oauth2Config)
} afterLogoutHandler := handlers.NewAfterLogoutHandler(logger)
}
func generateKey(length int) []byte { router := http.NewServeMux()
key := make([]byte, length) router.Handle("/", authMiddleware(indexHandler))
read, err := rand.Read(key) router.Handle("/callback", callbackHandler)
if err != nil { router.Handle("/after-logout", afterLogoutHandler)
log.Fatalf("could not generate key: %s", err)
}
if read != length {
log.Fatalf("read %d bytes, expected %d bytes", read, length)
}
return key
}
func discoverOidc(discoveryUrl *url.URL, apiClient *http.Client, o *OpenIDConfiguration) error { nextRequestId := func() string {
var body []byte return fmt.Sprintf("%d", time.Now().UnixNano())
req, err := http.NewRequest(http.MethodGet, discoveryUrl.String(), bytes.NewBuffer(body))
if err != nil {
return err
}
req.Header = map[string][]string{
"Accept": {"application/json"},
} }
resp, err := apiClient.Do(req) tracing := commonHandlers.Tracing(nextRequestId)
if err != nil { logging := commonHandlers.Logging(logger)
return err hsts := commonHandlers.EnableHSTS()
tlsConfig := &tls.Config{
ServerName: config.String("server.name"),
MinVersion: tls.VersionTLS12,
}
server := &http.Server{
Addr: serverAddr,
Handler: tracing(logging(hsts(router))),
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 15 * time.Second,
TLSConfig: tlsConfig,
} }
dec := json.NewDecoder(resp.Body) done := make(chan bool)
err = dec.Decode(o) quit := make(chan os.Signal, 1)
if err != nil { signal.Notify(quit, os.Interrupt)
return err
}
return nil go func() {
} <-quit
logger.Infoln("Server is shutting down...")
type callbackHandler struct { atomic.StoreInt32(&commonHandlers.Healthy, 0)
keySet *jwk.Set
oauth2Config *oauth2.Config
}
func (c *callbackHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
if request.Method != http.MethodGet {
http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
if request.URL.Path != "/callback" {
http.NotFound(writer, request)
return
}
code := request.URL.Query().Get("code")
ctx := context.Background() ctx := context.Background()
httpClient, err := client.TLSClient(client.TLSClientOptions{InsecureSkipVerify: true}) ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) defer cancel()
tok, err := c.oauth2Config.Exchange(ctx, code) server.SetKeepAlivesEnabled(false)
if err != nil { if err := server.Shutdown(ctx); err != nil {
log.Error(err) logger.Fatalf("Could not gracefully shutdown the server: %v\n", err)
http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) }
return close(done)
}()
logger.Infof("Server is ready to handle requests at https://%s/", server.Addr)
atomic.StoreInt32(&commonHandlers.Healthy, 1)
if err := server.ListenAndServeTLS(
config.String("server.certificate"), config.String("server.key"),
); err != nil && err != http.ErrServerClosed {
logger.Fatalf("Could not listen on %s: %v\n", server.Addr, err)
} }
session, err := sessionStore.Get(request, "resource_session") <-done
if err != nil { logger.Infoln("Server stopped")
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
session.Values[sessionKeyAccessToken] = tok.AccessToken
session.Values[sessionKeyRefreshToken] = tok.RefreshToken
session.Values[sessionKeyIdToken] = tok.Extra("id_token").(string)
idToken := tok.Extra("id_token")
if parsedIdToken, err := jwt.ParseString(idToken.(string), jwt.WithKeySet(c.keySet), jwt.WithOpenIDClaims()); err != nil {
log.Error(err)
http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
} else {
log.Infof(`
ID Token
========
Subject: %s
Audience: %s
Issued at: %s
Issued by: %s
Not valid before: %s
Not valid after: %s
`,
parsedIdToken.Subject(),
parsedIdToken.Audience(),
parsedIdToken.IssuedAt(),
parsedIdToken.Issuer(),
parsedIdToken.NotBefore(),
parsedIdToken.Expiration(),
)
session.Values[sessionKeyUserId] = parsedIdToken.Subject()
if roles, ok := parsedIdToken.Get("Groups"); ok {
session.Values[sessionKeyRoles] = roles
}
if username, ok := parsedIdToken.Get("Username"); ok {
session.Values[sessionKeyUsername] = username
}
if email, ok := parsedIdToken.Get("Email"); ok {
session.Values[sessionKeyEmail] = email
}
}
if err = session.Save(request, writer); err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
}
if redirectTarget, ok := session.Values[sessionRedirectTarget]; ok {
writer.Header().Set("Location", redirectTarget.(string))
} else {
writer.Header().Set("Location", "/")
}
writer.WriteHeader(http.StatusFound)
}
func NewCallbackHandler(keySet *jwk.Set, oauth2Config *oauth2.Config) *callbackHandler {
return &callbackHandler{keySet: keySet, oauth2Config: oauth2Config}
}
func authenticate(oauth2Config *oauth2.Config) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session, err := sessionStore.Get(r, "resource_session")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if _, ok := session.Values[sessionKeyUserId]; ok {
next.ServeHTTP(w, r)
return
}
session.Values[sessionRedirectTarget] = r.URL.String()
if err = session.Save(r, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var authUrl *url.URL
if authUrl, err = url.Parse(oauth2Config.Endpoint.AuthURL); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
queryValues := authUrl.Query()
queryValues.Set("client_id", k.String("oidc.client-id"))
queryValues.Set("response_type", "code")
queryValues.Set("scope", "openid offline")
queryValues.Set("state", base64.URLEncoding.EncodeToString(generateKey(8)))
authUrl.RawQuery = queryValues.Encode()
w.Header().Set("Location", authUrl.String())
w.WriteHeader(http.StatusFound)
})
}
}
type indexHandler struct {
logoutUrl string
}
func (i indexHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
if request.Method != http.MethodGet {
http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
if request.URL.Path != "/" {
http.NotFound(writer, request)
return
}
writer.WriteHeader(http.StatusOK)
page, err := template.New("").Parse(`
<!DOCTYPE html>
<html lang="en">
<head><title>Auth test</title></head>
<body>
<h1>Hello {{ .User }}</h1>
<p>This is an authorization protected resource</p>
<a href="{{ .LogoutURL }}">Logout</a>
</body>
</html>
`)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
session, err := sessionStore.Get(request, "resource_session")
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
logoutUrl, err := url.Parse(i.logoutUrl)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
var user string
var ok bool
if user, ok = session.Values[sessionKeyUsername].(string); ok {
}
if idToken, ok := session.Values[sessionKeyIdToken].(string); ok {
logoutUrl.RawQuery = url.Values{
"id_token_hint": []string{idToken},
"post_logout_redirect_uri": []string{"/logged_out"},
}.Encode()
}
writer.Header().Add("Content-Type", "text/html")
err = page.Execute(writer, map[string]interface{}{
"User": user,
"LogoutURL": logoutUrl.String(),
})
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
}
func NewIndexPage(logoutUrl string) *indexHandler {
return &indexHandler{logoutUrl: logoutUrl}
} }

View file

@ -9,15 +9,16 @@ import (
"net/url" "net/url"
"os" "os"
"os/signal" "os/signal"
"strings"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/go-openapi/runtime/client" "github.com/go-openapi/runtime/client"
"github.com/gorilla/csrf" "github.com/gorilla/csrf"
"github.com/knadh/koanf" "github.com/knadh/koanf"
"github.com/knadh/koanf/parsers/json"
"github.com/knadh/koanf/parsers/toml" "github.com/knadh/koanf/parsers/toml"
"github.com/knadh/koanf/providers/confmap" "github.com/knadh/koanf/providers/confmap"
"github.com/knadh/koanf/providers/env"
"github.com/knadh/koanf/providers/file" "github.com/knadh/koanf/providers/file"
"github.com/knadh/koanf/providers/posflag" "github.com/knadh/koanf/providers/posflag"
hydra "github.com/ory/hydra-client-go/client" hydra "github.com/ory/hydra-client-go/client"
@ -45,12 +46,6 @@ func main() {
config := koanf.New(".") config := koanf.New(".")
cFiles, _ := f.GetStringSlice("conf")
for _, c := range cFiles {
if err := config.Load(file.Provider(c), toml.Parser()); err != nil {
logger.Fatalf("error loading config file: %s", err)
}
}
_ = config.Load(confmap.Provider(map[string]interface{}{ _ = config.Load(confmap.Provider(map[string]interface{}{
"server.port": 3000, "server.port": 3000,
"server.name": "login.cacert.localhost", "server.name": "login.cacert.localhost",
@ -58,10 +53,22 @@ func main() {
"server.certificate": "certs/idp.cacert.localhost.crt.pem", "server.certificate": "certs/idp.cacert.localhost.crt.pem",
"admin.url": "https://hydra.cacert.localhost:4445/", "admin.url": "https://hydra.cacert.localhost:4445/",
}, "."), nil) }, "."), nil)
_ = config.Load(file.Provider("idp.json"), json.Parser()) cFiles, _ := f.GetStringSlice("conf")
for _, c := range cFiles {
if err := config.Load(file.Provider(c), toml.Parser()); err != nil {
logger.Fatalf("error loading config file: %s", err)
}
}
if err := config.Load(posflag.Provider(f, ".", config), nil); err != nil { if err := config.Load(posflag.Provider(f, ".", config), nil); err != nil {
logger.Fatalf("error loading configuration: %s", err) logger.Fatalf("error loading configuration: %s", err)
} }
const prefix = "IDP_"
if err := config.Load(env.Provider(prefix, ".", func(s string) string {
return strings.Replace(strings.ToLower(
strings.TrimPrefix(s, prefix)), "_", ".", -1)
}), nil); err != nil {
log.Fatalf("error loading config: %v", err)
}
logger.Infoln("Server is starting") logger.Infoln("Server is starting")
ctx := context.Background() ctx := context.Background()
@ -75,15 +82,21 @@ func main() {
adminClient := hydra.New(clientTransport, nil) adminClient := hydra.New(clientTransport, nil)
handlerContext := context.WithValue(ctx, handlers.CtxAdminClient, adminClient.Admin) handlerContext := context.WithValue(ctx, handlers.CtxAdminClient, adminClient.Admin)
loginHandler, err := handlers.NewLoginHandler(handlerContext) loginHandler, err := handlers.NewLoginHandler(logger, handlerContext)
if err != nil { if err != nil {
logger.Fatalf("error initializing login handler: %v", err) logger.Fatalf("error initializing login handler: %v", err)
} }
consentHandler := handlers.NewConsentHandler(handlerContext) consentHandler := handlers.NewConsentHandler(logger, handlerContext)
logoutHandler := handlers.NewLogoutHandler(logger, handlerContext)
logoutSuccessHandler := handlers.NewLogoutSuccessHandler()
errorHandler := handlers.NewErrorHandler()
router := http.NewServeMux() router := http.NewServeMux()
router.Handle("/login", loginHandler) router.Handle("/login", loginHandler)
router.Handle("/consent", consentHandler) router.Handle("/consent", consentHandler)
router.Handle("/logout", logoutHandler)
router.Handle("/error", errorHandler)
router.Handle("/logout-successful", logoutSuccessHandler)
router.Handle("/health", commonHandlers.NewHealthHandler()) router.Handle("/health", commonHandlers.NewHealthHandler())
if err != nil { if err != nil {
@ -94,23 +107,27 @@ func main() {
if err != nil { if err != nil {
logger.Fatalf("could not parse CSRF key bytes: %v", err) logger.Fatalf("could not parse CSRF key bytes: %v", err)
} }
handler := csrf.Protect(csrfKey, csrf.Secure(true))(router)
nextRequestId := func() string { nextRequestId := func() string {
return fmt.Sprintf("%d", time.Now().UnixNano()) return fmt.Sprintf("%d", time.Now().UnixNano())
} }
tracing := commonHandlers.Tracing(nextRequestId)
logging := commonHandlers.Logging(logger)
hsts := commonHandlers.EnableHSTS()
csrfProtect := csrf.Protect(
csrfKey,
csrf.Secure(true),
csrf.SameSite(csrf.SameSiteStrictMode),
csrf.MaxAge(600))
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
ServerName: config.String("server.name"), ServerName: config.String("server.name"),
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
} }
server := &http.Server{ server := &http.Server{
Addr: fmt.Sprintf("%s:%d", config.String("server.name"), config.Int("server.port")), Addr: fmt.Sprintf("%s:%d", config.String("server.name"), config.Int("server.port")),
Handler: commonHandlers.Tracing(nextRequestId)( Handler: tracing(logging(hsts(csrfProtect(router)))),
commonHandlers.Logging(logger)(
commonHandlers.EnableHSTS()(handler),
),
),
ReadTimeout: 5 * time.Second, ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second,
IdleTimeout: 15 * time.Second, IdleTimeout: 15 * time.Second,

51
common/services/oidc.go Normal file
View file

@ -0,0 +1,51 @@
package services
import (
"bytes"
"encoding/json"
"net/http"
"net/url"
log "github.com/sirupsen/logrus"
)
type OpenIDConfiguration struct {
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
JwksUri string `json:"jwks_uri"`
EndSessionEndpoint string `json:"end_session_endpoint"`
}
func DiscoverOIDC(logger *log.Logger, oidcServer string, apiClient *http.Client) (o *OpenIDConfiguration, err error) {
var discoveryUrl *url.URL
if discoveryUrl, err = url.Parse(oidcServer); err != nil {
logger.Fatalf("could not parse oidc.server parameter value %s: %s", oidcServer, err)
} else {
discoveryUrl.Path = "/.well-known/openid-configuration"
}
var body []byte
var req *http.Request
req, err = http.NewRequest(http.MethodGet, discoveryUrl.String(), bytes.NewBuffer(body))
if err != nil {
return
}
req.Header = map[string][]string{
"Accept": {"application/json"},
}
resp, err := apiClient.Do(req)
if err != nil {
return
}
dec := json.NewDecoder(resp.Body)
o = &OpenIDConfiguration{}
err = dec.Decode(o)
if err != nil {
return
}
return
}

View file

@ -0,0 +1,19 @@
package services
import (
"crypto/rand"
log "github.com/sirupsen/logrus"
)
func GenerateKey(length int) []byte {
key := make([]byte, length)
read, err := rand.Read(key)
if err != nil {
log.Fatalf("could not generate key: %s", err)
}
if read != length {
log.Fatalf("read %d bytes, expected %d bytes", read, length)
}
return key
}

View file

@ -12,11 +12,12 @@ import (
type consentHandler struct { type consentHandler struct {
adminClient *admin.Client adminClient *admin.Client
logger *log.Logger
} }
func (c *consentHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { func (h *consentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
consentChallenge := request.URL.Query().Get("consent_challenge") consentChallenge := r.URL.Query().Get("consent_challenge")
consentRequest, err := c.adminClient.AcceptConsentRequest( consentRequest, err := h.adminClient.AcceptConsentRequest(
admin.NewAcceptConsentRequestParams().WithConsentChallenge(consentChallenge).WithBody( admin.NewAcceptConsentRequestParams().WithConsentChallenge(consentChallenge).WithBody(
&models.AcceptConsentRequest{ &models.AcceptConsentRequest{
GrantAccessTokenAudience: nil, GrantAccessTokenAudience: nil,
@ -26,14 +27,15 @@ func (c *consentHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
RememberFor: 86400, RememberFor: 86400,
}).WithTimeout(time.Second * 10)) }).WithTimeout(time.Second * 10))
if err != nil { if err != nil {
log.Panic(err) h.logger.Panic(err)
} }
writer.Header().Add("Location", *consentRequest.GetPayload().RedirectTo) w.Header().Add("Location", *consentRequest.GetPayload().RedirectTo)
writer.WriteHeader(http.StatusFound) w.WriteHeader(http.StatusFound)
} }
func NewConsentHandler(ctx context.Context) *consentHandler { func NewConsentHandler(logger *log.Logger, ctx context.Context) *consentHandler {
return &consentHandler{ return &consentHandler{
logger: logger,
adminClient: ctx.Value(CtxAdminClient).(*admin.Client), adminClient: ctx.Value(CtxAdminClient).(*admin.Client),
} }
} }

20
idp/handlers/error.go Normal file
View file

@ -0,0 +1,20 @@
package handlers
import (
"fmt"
"net/http"
)
type errorHandler struct {
}
func (e *errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
_, _ = fmt.Fprintf(w, `
didumm %#v
`, r.URL.Query())
}
func NewErrorHandler() *errorHandler {
return &errorHandler{}
}

View file

@ -23,6 +23,7 @@ type loginHandler struct {
bundle *i18n.Bundle bundle *i18n.Bundle
messageCatalog map[string]*i18n.Message messageCatalog map[string]*i18n.Message
adminClient *admin.Client adminClient *admin.Client
logger *log.Logger
} }
type LoginInformation struct { type LoginInformation struct {
@ -30,27 +31,27 @@ type LoginInformation struct {
Password string `form:"password" validate:"required"` Password string `form:"password" validate:"required"`
} }
func (h *loginHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { func (h *loginHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var err error var err error
challenge := request.URL.Query().Get("login_challenge") challenge := r.URL.Query().Get("login_challenge")
log.Debugf("received challenge %s\n", challenge) h.logger.Debugf("received challenge %s\n", challenge)
validate := validator.New() validate := validator.New()
switch request.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:
// GET should render login form // GET should render login form
err = h.loginTemplate.Lookup("base").Execute(writer, map[string]interface{}{ err = h.loginTemplate.Lookup("base").Execute(w, map[string]interface{}{
"Title": "Title", "Title": "Title",
csrf.TemplateTag: csrf.TemplateField(request), csrf.TemplateTag: csrf.TemplateField(r),
"LabelEmail": "Email", "LabelEmail": "Email",
"LabelPassword": "Password", "LabelPassword": "Password",
"LabelLogin": "Login", "LabelLogin": "Login",
"errors": map[string]string{}, "errors": map[string]string{},
}) })
if err != nil { if err != nil {
log.Error(err) h.logger.Error(err)
http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return return
} }
break break
@ -60,23 +61,23 @@ func (h *loginHandler) ServeHTTP(writer http.ResponseWriter, request *http.Reque
// validate input // validate input
decoder := form.NewDecoder() decoder := form.NewDecoder()
err = decoder.Decode(&loginInfo, request.Form) err = decoder.Decode(&loginInfo, r.Form)
if err != nil { if err != nil {
log.Error(err) h.logger.Error(err)
http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return return
} }
err := validate.Struct(&loginInfo) err := validate.Struct(&loginInfo)
if err != nil { if err != nil {
errors := make(map[string]string) errors := make(map[string]string)
for _, err := range err.(validator.ValidationErrors) { for _, err := range err.(validator.ValidationErrors) {
accept := request.Header.Get("Accept-Language") accept := r.Header.Get("Accept-Language")
errors[err.Field()] = h.lookupErrorMessage(err.Tag(), err.Field(), err.Value(), i18n.NewLocalizer(h.bundle, accept)) errors[err.Field()] = h.lookupErrorMessage(err.Tag(), err.Field(), err.Value(), i18n.NewLocalizer(h.bundle, accept))
} }
err = h.loginTemplate.Lookup("base").Execute(writer, map[string]interface{}{ err = h.loginTemplate.Lookup("base").Execute(w, map[string]interface{}{
"Title": "Title", "Title": "Title",
csrf.TemplateTag: csrf.TemplateField(request), csrf.TemplateTag: csrf.TemplateField(r),
"LabelEmail": "Email", "LabelEmail": "Email",
"LabelPassword": "Password", "LabelPassword": "Password",
"LabelLogin": "Login", "LabelLogin": "Login",
@ -84,8 +85,8 @@ func (h *loginHandler) ServeHTTP(writer http.ResponseWriter, request *http.Reque
"errors": errors, "errors": errors,
}) })
if err != nil { if err != nil {
log.Error(err) h.logger.Error(err)
http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return return
} }
return return
@ -103,13 +104,14 @@ func (h *loginHandler) ServeHTTP(writer http.ResponseWriter, request *http.Reque
Subject: &subject, Subject: &subject,
}).WithTimeout(time.Second * 10)) }).WithTimeout(time.Second * 10))
if err != nil { if err != nil {
log.Panic(err) h.logger.Errorf("error getting logout requests: %v", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
} }
writer.Header().Add("Location", *loginRequest.GetPayload().RedirectTo) w.Header().Add("Location", *loginRequest.GetPayload().RedirectTo)
writer.WriteHeader(http.StatusFound) w.WriteHeader(http.StatusFound)
break break
default: default:
http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return return
} }
} }
@ -118,13 +120,13 @@ func (h *loginHandler) lookupErrorMessage(tag string, field string, value interf
var message *i18n.Message var message *i18n.Message
message, ok := h.messageCatalog[fmt.Sprintf("%s-%s", field, tag)] message, ok := h.messageCatalog[fmt.Sprintf("%s-%s", field, tag)]
if !ok { if !ok {
log.Infof("no specific error message %s-%s", field, tag) h.logger.Infof("no specific error message %s-%s", field, tag)
message, ok = h.messageCatalog[tag] message, ok = h.messageCatalog[tag]
if !ok { if !ok {
log.Infof("no specific error message %s", tag) h.logger.Infof("no specific error message %s", tag)
message, ok = h.messageCatalog["unknown"] message, ok = h.messageCatalog["unknown"]
if !ok { if !ok {
log.Error("no default translation found") h.logger.Error("no default translation found")
return tag return tag
} }
} }
@ -137,18 +139,19 @@ func (h *loginHandler) lookupErrorMessage(tag string, field string, value interf
}, },
}) })
if err != nil { if err != nil {
log.Error(err) h.logger.Error(err)
return tag return tag
} }
return translation return translation
} }
func NewLoginHandler(ctx context.Context) (*loginHandler, error) { func NewLoginHandler(logger *log.Logger, ctx context.Context) (*loginHandler, error) {
loginTemplate, err := template.ParseFiles("templates/base.html", "templates/login.html") loginTemplate, err := template.ParseFiles("templates/base.html", "templates/login.html")
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &loginHandler{ return &loginHandler{
logger: logger,
loginTemplate: loginTemplate, loginTemplate: loginTemplate,
bundle: ctx.Value(services.CtxI18nBundle).(*i18n.Bundle), bundle: ctx.Value(services.CtxI18nBundle).(*i18n.Bundle),
messageCatalog: ctx.Value(services.CtxI18nCatalog).(map[string]*i18n.Message), messageCatalog: ctx.Value(services.CtxI18nCatalog).(map[string]*i18n.Message),

58
idp/handlers/logout.go Normal file
View file

@ -0,0 +1,58 @@
package handlers
import (
"context"
"net/http"
"time"
"github.com/ory/hydra-client-go/client/admin"
log "github.com/sirupsen/logrus"
)
type logoutHandler struct {
adminClient *admin.Client
logger *log.Logger
}
func (h *logoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
challenge := r.URL.Query().Get("logout_challenge")
h.logger.Debugf("received challenge %s\n", challenge)
logoutRequest, err := h.adminClient.GetLogoutRequest(
admin.NewGetLogoutRequestParams().WithLogoutChallenge(challenge).WithTimeout(time.Second * 10))
if err != nil {
h.logger.Errorf("error getting logout requests: %v", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
h.logger.Debugf("received logout request: %#v", logoutRequest.Payload)
acceptLogoutRequest, err := h.adminClient.AcceptLogoutRequest(
admin.NewAcceptLogoutRequestParams().WithLogoutChallenge(challenge))
if err != nil {
h.logger.Errorf("error accepting logout: %v", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
w.Header().Set("Location", *acceptLogoutRequest.GetPayload().RedirectTo)
w.WriteHeader(http.StatusFound)
}
func NewLogoutHandler(logger *log.Logger, ctx context.Context) *logoutHandler {
return &logoutHandler{
logger: logger,
adminClient: ctx.Value(CtxAdminClient).(*admin.Client),
}
}
type logoutSuccessHandler struct {
}
func (l *logoutSuccessHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
panic("implement me")
}
func NewLogoutSuccessHandler() *logoutSuccessHandler {
return &logoutSuccessHandler{}
}