Refactor app, implement logout
This commit is contained in:
parent
ce1fac0e68
commit
27e225795c
14 changed files with 647 additions and 349 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -1,3 +1,4 @@
|
||||||
/.idea/
|
|
||||||
/*.toml
|
/*.toml
|
||||||
|
/.idea/
|
||||||
/certs/
|
/certs/
|
||||||
|
/sessions/
|
||||||
|
|
33
app/handlers/after_logout.go
Normal file
33
app/handlers/after_logout.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"git.cacert.org/oidc_login/app/services"
|
||||||
|
)
|
||||||
|
|
||||||
|
type afterLogoutHandler struct {
|
||||||
|
logger *logrus.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *afterLogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
session, err := services.GetSessionStore().Get(r, sessionName)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Errorf("could not get session: %v", err)
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
session.Options.MaxAge = -1
|
||||||
|
if err = session.Save(r, w); err != nil {
|
||||||
|
h.logger.Errorf("could not save session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Location", "/")
|
||||||
|
w.WriteHeader(http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAfterLogoutHandler(logger *logrus.Logger) *afterLogoutHandler {
|
||||||
|
return &afterLogoutHandler{logger: logger}
|
||||||
|
}
|
49
app/handlers/common.go
Normal file
49
app/handlers/common.go
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
"git.cacert.org/oidc_login/app/services"
|
||||||
|
commonServices "git.cacert.org/oidc_login/common/services"
|
||||||
|
)
|
||||||
|
|
||||||
|
const sessionName = "resource_session"
|
||||||
|
|
||||||
|
func Authenticate(oauth2Config *oauth2.Config, clientId string) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
session, err := services.GetSessionStore().Get(r, sessionName)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := session.Values[sessionKeyUserId]; ok {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
session.Values[sessionRedirectTarget] = r.URL.String()
|
||||||
|
if err = session.Save(r, w); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var authUrl *url.URL
|
||||||
|
if authUrl, err = url.Parse(oauth2Config.Endpoint.AuthURL); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
queryValues := authUrl.Query()
|
||||||
|
queryValues.Set("client_id", clientId)
|
||||||
|
queryValues.Set("response_type", "code")
|
||||||
|
queryValues.Set("scope", "openid offline_access profile email")
|
||||||
|
queryValues.Set("state", base64.URLEncoding.EncodeToString(commonServices.GenerateKey(8)))
|
||||||
|
authUrl.RawQuery = queryValues.Encode()
|
||||||
|
|
||||||
|
w.Header().Set("Location", authUrl.String())
|
||||||
|
w.WriteHeader(http.StatusFound)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
79
app/handlers/index.go
Normal file
79
app/handlers/index.go
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"git.cacert.org/oidc_login/app/services"
|
||||||
|
)
|
||||||
|
|
||||||
|
type indexHandler struct {
|
||||||
|
logoutUrl string
|
||||||
|
serverAddr string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *indexHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||||
|
if request.Method != http.MethodGet {
|
||||||
|
http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if request.URL.Path != "/" {
|
||||||
|
http.NotFound(writer, request)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writer.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
page, err := template.New("").Parse(`
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head><title>Auth test</title></head>
|
||||||
|
<body>
|
||||||
|
<h1>Hello {{ .User }}</h1>
|
||||||
|
<p>This is an authorization protected resource</p>
|
||||||
|
<a href="{{ .LogoutURL }}">Logout</a>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
session, err := services.GetSessionStore().Get(request, sessionName)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logoutUrl, err := url.Parse(h.logoutUrl)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var user string
|
||||||
|
var ok bool
|
||||||
|
if user, ok = session.Values[sessionKeyUsername].(string); ok {
|
||||||
|
|
||||||
|
}
|
||||||
|
if idToken, ok := session.Values[sessionKeyIdToken].(string); ok {
|
||||||
|
logoutUrl.RawQuery = url.Values{
|
||||||
|
"id_token_hint": []string{idToken},
|
||||||
|
"post_logout_redirect_uri": []string{fmt.Sprintf("https://%s/after-logout", h.serverAddr)},
|
||||||
|
}.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
writer.Header().Add("Content-Type", "text/html")
|
||||||
|
err = page.Execute(writer, map[string]interface{}{
|
||||||
|
"User": user,
|
||||||
|
"LogoutURL": logoutUrl.String(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewIndexHandler(logoutUrl string, serverAddr string) *indexHandler {
|
||||||
|
return &indexHandler{logoutUrl: logoutUrl, serverAddr: serverAddr}
|
||||||
|
}
|
117
app/handlers/oidc_callback.go
Normal file
117
app/handlers/oidc_callback.go
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/go-openapi/runtime/client"
|
||||||
|
"github.com/lestrrat-go/jwx/jwk"
|
||||||
|
"github.com/lestrrat-go/jwx/jwt"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
"git.cacert.org/oidc_login/app/services"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
sessionKeyAccessToken = iota
|
||||||
|
sessionKeyRefreshToken
|
||||||
|
sessionKeyIdToken
|
||||||
|
sessionKeyUserId
|
||||||
|
sessionKeyRoles
|
||||||
|
sessionKeyEmail
|
||||||
|
sessionKeyUsername
|
||||||
|
sessionRedirectTarget
|
||||||
|
)
|
||||||
|
|
||||||
|
type oidcCallbackHandler struct {
|
||||||
|
keySet *jwk.Set
|
||||||
|
oauth2Config *oauth2.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *oidcCallbackHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||||
|
if request.Method != http.MethodGet {
|
||||||
|
http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if request.URL.Path != "/callback" {
|
||||||
|
http.NotFound(writer, request)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
code := request.URL.Query().Get("code")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
httpClient, err := client.TLSClient(client.TLSClientOptions{InsecureSkipVerify: true})
|
||||||
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||||
|
|
||||||
|
tok, err := c.oauth2Config.Exchange(ctx, code)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Error(err)
|
||||||
|
http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := services.GetSessionStore().Get(request, "resource_session")
|
||||||
|
if err != nil {
|
||||||
|
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session.Values[sessionKeyAccessToken] = tok.AccessToken
|
||||||
|
session.Values[sessionKeyRefreshToken] = tok.RefreshToken
|
||||||
|
session.Values[sessionKeyIdToken] = tok.Extra("id_token").(string)
|
||||||
|
|
||||||
|
idToken := tok.Extra("id_token")
|
||||||
|
if parsedIdToken, err := jwt.ParseString(idToken.(string), jwt.WithKeySet(c.keySet), jwt.WithOpenIDClaims()); err != nil {
|
||||||
|
logrus.Error(err)
|
||||||
|
http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
logrus.Infof(`
|
||||||
|
ID Token
|
||||||
|
========
|
||||||
|
|
||||||
|
Subject: %s
|
||||||
|
Audience: %s
|
||||||
|
Issued at: %s
|
||||||
|
Issued by: %s
|
||||||
|
Not valid before: %s
|
||||||
|
Not valid after: %s
|
||||||
|
|
||||||
|
`,
|
||||||
|
parsedIdToken.Subject(),
|
||||||
|
parsedIdToken.Audience(),
|
||||||
|
parsedIdToken.IssuedAt(),
|
||||||
|
parsedIdToken.Issuer(),
|
||||||
|
parsedIdToken.NotBefore(),
|
||||||
|
parsedIdToken.Expiration(),
|
||||||
|
)
|
||||||
|
|
||||||
|
session.Values[sessionKeyUserId] = parsedIdToken.Subject()
|
||||||
|
|
||||||
|
if roles, ok := parsedIdToken.Get("Groups"); ok {
|
||||||
|
session.Values[sessionKeyRoles] = roles
|
||||||
|
}
|
||||||
|
if username, ok := parsedIdToken.Get("Username"); ok {
|
||||||
|
session.Values[sessionKeyUsername] = username
|
||||||
|
}
|
||||||
|
if email, ok := parsedIdToken.Get("Email"); ok {
|
||||||
|
session.Values[sessionKeyEmail] = email
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err = session.Save(request, writer); err != nil {
|
||||||
|
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
if redirectTarget, ok := session.Values[sessionRedirectTarget]; ok {
|
||||||
|
writer.Header().Set("Location", redirectTarget.(string))
|
||||||
|
} else {
|
||||||
|
writer.Header().Set("Location", "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
writer.WriteHeader(http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCallbackHandler(keySet *jwk.Set, oauth2Config *oauth2.Config) *oidcCallbackHandler {
|
||||||
|
return &oidcCallbackHandler{keySet: keySet, oauth2Config: oauth2Config}
|
||||||
|
}
|
25
app/services/session.go
Normal file
25
app/services/session.go
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/gorilla/sessions"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var store *sessions.FilesystemStore
|
||||||
|
|
||||||
|
func InitSessionStore(logger *log.Logger, sessionPath string, keys ...[]byte) {
|
||||||
|
store = sessions.NewFilesystemStore(sessionPath, keys...)
|
||||||
|
if _, err := os.Stat(sessionPath); err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
if err = os.MkdirAll(sessionPath, 0700); err != nil {
|
||||||
|
logger.Fatalf("could not create session store directory: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSessionStore() *sessions.FilesystemStore {
|
||||||
|
return store
|
||||||
|
}
|
422
cmd/app/main.go
422
cmd/app/main.go
|
@ -1,111 +1,117 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/tls"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"fmt"
|
||||||
"html/template"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/go-openapi/runtime/client"
|
|
||||||
"github.com/gorilla/sessions"
|
|
||||||
"github.com/knadh/koanf"
|
"github.com/knadh/koanf"
|
||||||
jsonParser "github.com/knadh/koanf/parsers/json"
|
"github.com/knadh/koanf/parsers/toml"
|
||||||
"github.com/knadh/koanf/providers/confmap"
|
"github.com/knadh/koanf/providers/confmap"
|
||||||
"github.com/knadh/koanf/providers/env"
|
"github.com/knadh/koanf/providers/env"
|
||||||
"github.com/knadh/koanf/providers/file"
|
"github.com/knadh/koanf/providers/file"
|
||||||
|
"github.com/knadh/koanf/providers/posflag"
|
||||||
"github.com/lestrrat-go/jwx/jwk"
|
"github.com/lestrrat-go/jwx/jwk"
|
||||||
"github.com/lestrrat-go/jwx/jwt"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
flag "github.com/spf13/pflag"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
|
||||||
|
|
||||||
type OpenIDConfiguration struct {
|
"git.cacert.org/oidc_login/app/handlers"
|
||||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
"git.cacert.org/oidc_login/app/services"
|
||||||
TokenEndpoint string `json:"token_endpoint"`
|
commonHandlers "git.cacert.org/oidc_login/common/handlers"
|
||||||
JwksUri string `json:"jwks_uri"`
|
commonServices "git.cacert.org/oidc_login/common/services"
|
||||||
EndSessionEndpoint string `json:"end_session_endpoint"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
sessionStore *sessions.FilesystemStore
|
|
||||||
k = koanf.New(".")
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
sessionKeyAccessToken = iota
|
|
||||||
sessionKeyRefreshToken
|
|
||||||
sessionKeyIdToken
|
|
||||||
sessionKeyUserId
|
|
||||||
sessionKeyRoles
|
|
||||||
sessionKeyEmail
|
|
||||||
sessionKeyUsername
|
|
||||||
sessionRedirectTarget
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
if err := k.Load(file.Provider("resourceapp.json"), jsonParser.Parser()); err != nil && !os.IsNotExist(err) {
|
f := flag.NewFlagSet("config", flag.ContinueOnError)
|
||||||
|
f.Usage = func() {
|
||||||
|
fmt.Println(f.FlagUsages())
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
f.StringSlice("conf", []string{"resource_app.toml"}, "path to one or more .toml files")
|
||||||
|
logger := log.New()
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if err = f.Parse(os.Args[1:]); err != nil {
|
||||||
|
logger.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := koanf.New(".")
|
||||||
|
|
||||||
|
_ = config.Load(confmap.Provider(map[string]interface{}{
|
||||||
|
"server.port": 4000,
|
||||||
|
"server.name": "app.cacert.localhost",
|
||||||
|
"server.key": "certs/app.cacert.localhost.key",
|
||||||
|
"server.certificate": "certs/app.cacert.localhost.crt.pem",
|
||||||
|
"oidc.server": "https://auth.cacert.localhost:4444/",
|
||||||
|
"session.path": "sessions/app",
|
||||||
|
}, "."), nil)
|
||||||
|
cFiles, _ := f.GetStringSlice("conf")
|
||||||
|
for _, c := range cFiles {
|
||||||
|
if err := config.Load(file.Provider(c), toml.Parser()); err != nil {
|
||||||
|
logger.Fatalf("error loading config file: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := config.Load(posflag.Provider(f, ".", config), nil); err != nil {
|
||||||
|
logger.Fatalf("error loading configuration: %s", err)
|
||||||
|
}
|
||||||
|
if err := config.Load(file.Provider("resource_app.toml"), toml.Parser()); err != nil && !os.IsNotExist(err) {
|
||||||
log.Fatalf("error loading config: %v", err)
|
log.Fatalf("error loading config: %v", err)
|
||||||
}
|
}
|
||||||
const prefix = "RESOURCEAPP_"
|
const prefix = "RESOURCE_APP_"
|
||||||
if err := k.Load(env.Provider(prefix, ".", func(s string) string {
|
if err := config.Load(env.Provider(prefix, ".", func(s string) string {
|
||||||
return strings.Replace(strings.ToLower(
|
return strings.Replace(strings.ToLower(
|
||||||
strings.TrimPrefix(s, prefix)), "_", ".", -1)
|
strings.TrimPrefix(s, prefix)), "_", ".", -1)
|
||||||
}), nil); err != nil {
|
}), nil); err != nil {
|
||||||
log.Fatalf("error loading config: %v", err)
|
log.Fatalf("error loading config: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
oidcServer := k.MustString("oidc.server")
|
oidcServer := config.MustString("oidc.server")
|
||||||
oidcClientId := k.MustString("oidc.client-id")
|
oidcClientId := config.MustString("oidc.client-id")
|
||||||
oidcClientSecret := k.MustString("oidc.client-secret")
|
oidcClientSecret := config.MustString("oidc.client-secret")
|
||||||
|
|
||||||
sessionPath := k.MustString("session.path")
|
sessionPath := config.MustString("session.path")
|
||||||
sessionAuthKey, err := base64.StdEncoding.DecodeString(k.String("session.auth-key"))
|
sessionAuthKey, err := base64.StdEncoding.DecodeString(config.String("session.auth-key"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("could not decode session auth key: %s", err)
|
log.Fatalf("could not decode session auth key: %s", err)
|
||||||
}
|
}
|
||||||
sessionEncKey, err := base64.StdEncoding.DecodeString(k.String("session.enc-key"))
|
sessionEncKey, err := base64.StdEncoding.DecodeString(config.String("session.enc-key"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("could not decode session encryption key: %s", err)
|
log.Fatalf("could not decode session encryption key: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
generated := false
|
generated := false
|
||||||
if len(sessionAuthKey) != 64 {
|
if len(sessionAuthKey) != 64 {
|
||||||
sessionAuthKey = generateKey(64)
|
sessionAuthKey = commonServices.GenerateKey(64)
|
||||||
generated = true
|
generated = true
|
||||||
}
|
}
|
||||||
if len(sessionEncKey) != 32 {
|
if len(sessionEncKey) != 32 {
|
||||||
sessionEncKey = generateKey(32)
|
sessionEncKey = commonServices.GenerateKey(32)
|
||||||
generated = true
|
generated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if generated {
|
if generated {
|
||||||
_ = k.Load(confmap.Provider(map[string]interface{}{
|
_ = config.Load(confmap.Provider(map[string]interface{}{
|
||||||
"session.auth-key": sessionAuthKey,
|
"session.auth-key": sessionAuthKey,
|
||||||
"session.enc-key": sessionEncKey,
|
"session.enc-key": sessionEncKey,
|
||||||
}, "."), nil)
|
}, "."), nil)
|
||||||
jsonData, err := k.Marshal(jsonParser.Parser())
|
tomlData, err := config.Marshal(toml.Parser())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("could not encode session config")
|
log.Fatalf("could not encode session config")
|
||||||
}
|
}
|
||||||
log.Infof("put the following in your resourceapp.json:\n%s", string(jsonData))
|
log.Infof("put the following in your resource_app.toml:\n%s", string(tomlData))
|
||||||
}
|
}
|
||||||
|
|
||||||
var discoveryResponse OpenIDConfiguration
|
var discoveryResponse *commonServices.OpenIDConfiguration
|
||||||
var discoveryUrl *url.URL
|
apiClient := &http.Client{}
|
||||||
|
if discoveryResponse, err = commonServices.DiscoverOIDC(logger, oidcServer, apiClient); err != nil {
|
||||||
if discoveryUrl, err = url.Parse(oidcServer); err != nil {
|
|
||||||
log.Fatalf("could not parse oidc.server parameter value %s: %s", oidcServer, err)
|
|
||||||
} else {
|
|
||||||
discoveryUrl.Path = "/.well-known/openid-configuration"
|
|
||||||
}
|
|
||||||
apiClient, err := client.TLSClient(client.TLSClientOptions{InsecureSkipVerify: true})
|
|
||||||
if err := discoverOidc(discoveryUrl, apiClient, &discoveryResponse); err != nil {
|
|
||||||
log.Fatalf("OpenID Connect discovery failed: %s", err)
|
log.Fatalf("OpenID Connect discovery failed: %s", err)
|
||||||
}
|
}
|
||||||
oauth2Config := &oauth2.Config{
|
oauth2Config := &oauth2.Config{
|
||||||
|
@ -122,251 +128,69 @@ func main() {
|
||||||
log.Fatalf("could not fetch JWKs: %s", err)
|
log.Fatalf("could not fetch JWKs: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = os.Stat(sessionPath); err != nil {
|
services.InitSessionStore(logger, sessionPath, sessionAuthKey, sessionEncKey)
|
||||||
if os.IsNotExist(err) {
|
|
||||||
if err = os.MkdirAll(sessionPath, 0700); err != nil {
|
authMiddleware := handlers.Authenticate(oauth2Config, config.MustString("oidc.client-id"))
|
||||||
log.Fatalf("could not create session store directory: %s", err)
|
|
||||||
}
|
serverAddr := fmt.Sprintf("%s:%d", config.String("server.name"), config.Int("server.port"))
|
||||||
|
indexHandler := handlers.NewIndexHandler(discoveryResponse.EndSessionEndpoint, serverAddr)
|
||||||
|
callbackHandler := handlers.NewCallbackHandler(keySet, oauth2Config)
|
||||||
|
afterLogoutHandler := handlers.NewAfterLogoutHandler(logger)
|
||||||
|
|
||||||
|
router := http.NewServeMux()
|
||||||
|
router.Handle("/", authMiddleware(indexHandler))
|
||||||
|
router.Handle("/callback", callbackHandler)
|
||||||
|
router.Handle("/after-logout", afterLogoutHandler)
|
||||||
|
|
||||||
|
nextRequestId := func() string {
|
||||||
|
return fmt.Sprintf("%d", time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing := commonHandlers.Tracing(nextRequestId)
|
||||||
|
logging := commonHandlers.Logging(logger)
|
||||||
|
hsts := commonHandlers.EnableHSTS()
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
ServerName: config.String("server.name"),
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
server := &http.Server{
|
||||||
|
Addr: serverAddr,
|
||||||
|
Handler: tracing(logging(hsts(router))),
|
||||||
|
ReadTimeout: 5 * time.Second,
|
||||||
|
WriteTimeout: 10 * time.Second,
|
||||||
|
IdleTimeout: 15 * time.Second,
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan bool)
|
||||||
|
quit := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(quit, os.Interrupt)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-quit
|
||||||
|
logger.Infoln("Server is shutting down...")
|
||||||
|
atomic.StoreInt32(&commonHandlers.Healthy, 0)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
server.SetKeepAlivesEnabled(false)
|
||||||
|
if err := server.Shutdown(ctx); err != nil {
|
||||||
|
logger.Fatalf("Could not gracefully shutdown the server: %v\n", err)
|
||||||
}
|
}
|
||||||
}
|
close(done)
|
||||||
sessionStore = sessions.NewFilesystemStore(sessionPath, sessionAuthKey, sessionEncKey)
|
}()
|
||||||
|
|
||||||
http.Handle("/", authenticate(oauth2Config)(NewIndexPage(discoveryResponse.EndSessionEndpoint)))
|
logger.Infof("Server is ready to handle requests at https://%s/", server.Addr)
|
||||||
http.Handle("/callback", NewCallbackHandler(keySet, oauth2Config))
|
atomic.StoreInt32(&commonHandlers.Healthy, 1)
|
||||||
|
if err := server.ListenAndServeTLS(
|
||||||
err = http.ListenAndServe(":4000", http.DefaultServeMux)
|
config.String("server.certificate"), config.String("server.key"),
|
||||||
if err != nil {
|
); err != nil && err != http.ErrServerClosed {
|
||||||
log.Fatal(err)
|
logger.Fatalf("Could not listen on %s: %v\n", server.Addr, err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
<-done
|
||||||
func generateKey(length int) []byte {
|
logger.Infoln("Server stopped")
|
||||||
key := make([]byte, length)
|
|
||||||
read, err := rand.Read(key)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("could not generate key: %s", err)
|
|
||||||
}
|
|
||||||
if read != length {
|
|
||||||
log.Fatalf("read %d bytes, expected %d bytes", read, length)
|
|
||||||
}
|
|
||||||
return key
|
|
||||||
}
|
|
||||||
|
|
||||||
func discoverOidc(discoveryUrl *url.URL, apiClient *http.Client, o *OpenIDConfiguration) error {
|
|
||||||
var body []byte
|
|
||||||
req, err := http.NewRequest(http.MethodGet, discoveryUrl.String(), bytes.NewBuffer(body))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
req.Header = map[string][]string{
|
|
||||||
"Accept": {"application/json"},
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := apiClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
dec := json.NewDecoder(resp.Body)
|
|
||||||
err = dec.Decode(o)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type callbackHandler struct {
|
|
||||||
keySet *jwk.Set
|
|
||||||
oauth2Config *oauth2.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *callbackHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
|
||||||
if request.Method != http.MethodGet {
|
|
||||||
http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if request.URL.Path != "/callback" {
|
|
||||||
http.NotFound(writer, request)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
code := request.URL.Query().Get("code")
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
httpClient, err := client.TLSClient(client.TLSClientOptions{InsecureSkipVerify: true})
|
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
|
||||||
|
|
||||||
tok, err := c.oauth2Config.Exchange(ctx, code)
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
session, err := sessionStore.Get(request, "resource_session")
|
|
||||||
if err != nil {
|
|
||||||
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
session.Values[sessionKeyAccessToken] = tok.AccessToken
|
|
||||||
session.Values[sessionKeyRefreshToken] = tok.RefreshToken
|
|
||||||
session.Values[sessionKeyIdToken] = tok.Extra("id_token").(string)
|
|
||||||
|
|
||||||
idToken := tok.Extra("id_token")
|
|
||||||
if parsedIdToken, err := jwt.ParseString(idToken.(string), jwt.WithKeySet(c.keySet), jwt.WithOpenIDClaims()); err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
log.Infof(`
|
|
||||||
ID Token
|
|
||||||
========
|
|
||||||
|
|
||||||
Subject: %s
|
|
||||||
Audience: %s
|
|
||||||
Issued at: %s
|
|
||||||
Issued by: %s
|
|
||||||
Not valid before: %s
|
|
||||||
Not valid after: %s
|
|
||||||
|
|
||||||
`,
|
|
||||||
parsedIdToken.Subject(),
|
|
||||||
parsedIdToken.Audience(),
|
|
||||||
parsedIdToken.IssuedAt(),
|
|
||||||
parsedIdToken.Issuer(),
|
|
||||||
parsedIdToken.NotBefore(),
|
|
||||||
parsedIdToken.Expiration(),
|
|
||||||
)
|
|
||||||
|
|
||||||
session.Values[sessionKeyUserId] = parsedIdToken.Subject()
|
|
||||||
|
|
||||||
if roles, ok := parsedIdToken.Get("Groups"); ok {
|
|
||||||
session.Values[sessionKeyRoles] = roles
|
|
||||||
}
|
|
||||||
if username, ok := parsedIdToken.Get("Username"); ok {
|
|
||||||
session.Values[sessionKeyUsername] = username
|
|
||||||
}
|
|
||||||
if email, ok := parsedIdToken.Get("Email"); ok {
|
|
||||||
session.Values[sessionKeyEmail] = email
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err = session.Save(request, writer); err != nil {
|
|
||||||
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if redirectTarget, ok := session.Values[sessionRedirectTarget]; ok {
|
|
||||||
writer.Header().Set("Location", redirectTarget.(string))
|
|
||||||
} else {
|
|
||||||
writer.Header().Set("Location", "/")
|
|
||||||
}
|
|
||||||
|
|
||||||
writer.WriteHeader(http.StatusFound)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewCallbackHandler(keySet *jwk.Set, oauth2Config *oauth2.Config) *callbackHandler {
|
|
||||||
return &callbackHandler{keySet: keySet, oauth2Config: oauth2Config}
|
|
||||||
}
|
|
||||||
|
|
||||||
func authenticate(oauth2Config *oauth2.Config) func(http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
session, err := sessionStore.Get(r, "resource_session")
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, ok := session.Values[sessionKeyUserId]; ok {
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
session.Values[sessionRedirectTarget] = r.URL.String()
|
|
||||||
if err = session.Save(r, w); err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var authUrl *url.URL
|
|
||||||
if authUrl, err = url.Parse(oauth2Config.Endpoint.AuthURL); err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
queryValues := authUrl.Query()
|
|
||||||
queryValues.Set("client_id", k.String("oidc.client-id"))
|
|
||||||
queryValues.Set("response_type", "code")
|
|
||||||
queryValues.Set("scope", "openid offline")
|
|
||||||
queryValues.Set("state", base64.URLEncoding.EncodeToString(generateKey(8)))
|
|
||||||
authUrl.RawQuery = queryValues.Encode()
|
|
||||||
|
|
||||||
w.Header().Set("Location", authUrl.String())
|
|
||||||
w.WriteHeader(http.StatusFound)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type indexHandler struct {
|
|
||||||
logoutUrl string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i indexHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
|
||||||
if request.Method != http.MethodGet {
|
|
||||||
http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if request.URL.Path != "/" {
|
|
||||||
http.NotFound(writer, request)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writer.WriteHeader(http.StatusOK)
|
|
||||||
|
|
||||||
page, err := template.New("").Parse(`
|
|
||||||
<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
<head><title>Auth test</title></head>
|
|
||||||
<body>
|
|
||||||
<h1>Hello {{ .User }}</h1>
|
|
||||||
<p>This is an authorization protected resource</p>
|
|
||||||
<a href="{{ .LogoutURL }}">Logout</a>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
session, err := sessionStore.Get(request, "resource_session")
|
|
||||||
if err != nil {
|
|
||||||
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logoutUrl, err := url.Parse(i.logoutUrl)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var user string
|
|
||||||
var ok bool
|
|
||||||
if user, ok = session.Values[sessionKeyUsername].(string); ok {
|
|
||||||
|
|
||||||
}
|
|
||||||
if idToken, ok := session.Values[sessionKeyIdToken].(string); ok {
|
|
||||||
logoutUrl.RawQuery = url.Values{
|
|
||||||
"id_token_hint": []string{idToken},
|
|
||||||
"post_logout_redirect_uri": []string{"/logged_out"},
|
|
||||||
}.Encode()
|
|
||||||
}
|
|
||||||
|
|
||||||
writer.Header().Add("Content-Type", "text/html")
|
|
||||||
err = page.Execute(writer, map[string]interface{}{
|
|
||||||
"User": user,
|
|
||||||
"LogoutURL": logoutUrl.String(),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewIndexPage(logoutUrl string) *indexHandler {
|
|
||||||
return &indexHandler{logoutUrl: logoutUrl}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,15 +9,16 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-openapi/runtime/client"
|
"github.com/go-openapi/runtime/client"
|
||||||
"github.com/gorilla/csrf"
|
"github.com/gorilla/csrf"
|
||||||
"github.com/knadh/koanf"
|
"github.com/knadh/koanf"
|
||||||
"github.com/knadh/koanf/parsers/json"
|
|
||||||
"github.com/knadh/koanf/parsers/toml"
|
"github.com/knadh/koanf/parsers/toml"
|
||||||
"github.com/knadh/koanf/providers/confmap"
|
"github.com/knadh/koanf/providers/confmap"
|
||||||
|
"github.com/knadh/koanf/providers/env"
|
||||||
"github.com/knadh/koanf/providers/file"
|
"github.com/knadh/koanf/providers/file"
|
||||||
"github.com/knadh/koanf/providers/posflag"
|
"github.com/knadh/koanf/providers/posflag"
|
||||||
hydra "github.com/ory/hydra-client-go/client"
|
hydra "github.com/ory/hydra-client-go/client"
|
||||||
|
@ -45,12 +46,6 @@ func main() {
|
||||||
|
|
||||||
config := koanf.New(".")
|
config := koanf.New(".")
|
||||||
|
|
||||||
cFiles, _ := f.GetStringSlice("conf")
|
|
||||||
for _, c := range cFiles {
|
|
||||||
if err := config.Load(file.Provider(c), toml.Parser()); err != nil {
|
|
||||||
logger.Fatalf("error loading config file: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ = config.Load(confmap.Provider(map[string]interface{}{
|
_ = config.Load(confmap.Provider(map[string]interface{}{
|
||||||
"server.port": 3000,
|
"server.port": 3000,
|
||||||
"server.name": "login.cacert.localhost",
|
"server.name": "login.cacert.localhost",
|
||||||
|
@ -58,10 +53,22 @@ func main() {
|
||||||
"server.certificate": "certs/idp.cacert.localhost.crt.pem",
|
"server.certificate": "certs/idp.cacert.localhost.crt.pem",
|
||||||
"admin.url": "https://hydra.cacert.localhost:4445/",
|
"admin.url": "https://hydra.cacert.localhost:4445/",
|
||||||
}, "."), nil)
|
}, "."), nil)
|
||||||
_ = config.Load(file.Provider("idp.json"), json.Parser())
|
cFiles, _ := f.GetStringSlice("conf")
|
||||||
|
for _, c := range cFiles {
|
||||||
|
if err := config.Load(file.Provider(c), toml.Parser()); err != nil {
|
||||||
|
logger.Fatalf("error loading config file: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
if err := config.Load(posflag.Provider(f, ".", config), nil); err != nil {
|
if err := config.Load(posflag.Provider(f, ".", config), nil); err != nil {
|
||||||
logger.Fatalf("error loading configuration: %s", err)
|
logger.Fatalf("error loading configuration: %s", err)
|
||||||
}
|
}
|
||||||
|
const prefix = "IDP_"
|
||||||
|
if err := config.Load(env.Provider(prefix, ".", func(s string) string {
|
||||||
|
return strings.Replace(strings.ToLower(
|
||||||
|
strings.TrimPrefix(s, prefix)), "_", ".", -1)
|
||||||
|
}), nil); err != nil {
|
||||||
|
log.Fatalf("error loading config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
logger.Infoln("Server is starting")
|
logger.Infoln("Server is starting")
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -75,15 +82,21 @@ func main() {
|
||||||
adminClient := hydra.New(clientTransport, nil)
|
adminClient := hydra.New(clientTransport, nil)
|
||||||
|
|
||||||
handlerContext := context.WithValue(ctx, handlers.CtxAdminClient, adminClient.Admin)
|
handlerContext := context.WithValue(ctx, handlers.CtxAdminClient, adminClient.Admin)
|
||||||
loginHandler, err := handlers.NewLoginHandler(handlerContext)
|
loginHandler, err := handlers.NewLoginHandler(logger, handlerContext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("error initializing login handler: %v", err)
|
logger.Fatalf("error initializing login handler: %v", err)
|
||||||
}
|
}
|
||||||
consentHandler := handlers.NewConsentHandler(handlerContext)
|
consentHandler := handlers.NewConsentHandler(logger, handlerContext)
|
||||||
|
logoutHandler := handlers.NewLogoutHandler(logger, handlerContext)
|
||||||
|
logoutSuccessHandler := handlers.NewLogoutSuccessHandler()
|
||||||
|
errorHandler := handlers.NewErrorHandler()
|
||||||
|
|
||||||
router := http.NewServeMux()
|
router := http.NewServeMux()
|
||||||
router.Handle("/login", loginHandler)
|
router.Handle("/login", loginHandler)
|
||||||
router.Handle("/consent", consentHandler)
|
router.Handle("/consent", consentHandler)
|
||||||
|
router.Handle("/logout", logoutHandler)
|
||||||
|
router.Handle("/error", errorHandler)
|
||||||
|
router.Handle("/logout-successful", logoutSuccessHandler)
|
||||||
router.Handle("/health", commonHandlers.NewHealthHandler())
|
router.Handle("/health", commonHandlers.NewHealthHandler())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -94,23 +107,27 @@ func main() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("could not parse CSRF key bytes: %v", err)
|
logger.Fatalf("could not parse CSRF key bytes: %v", err)
|
||||||
}
|
}
|
||||||
handler := csrf.Protect(csrfKey, csrf.Secure(true))(router)
|
|
||||||
|
|
||||||
nextRequestId := func() string {
|
nextRequestId := func() string {
|
||||||
return fmt.Sprintf("%d", time.Now().UnixNano())
|
return fmt.Sprintf("%d", time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tracing := commonHandlers.Tracing(nextRequestId)
|
||||||
|
logging := commonHandlers.Logging(logger)
|
||||||
|
hsts := commonHandlers.EnableHSTS()
|
||||||
|
csrfProtect := csrf.Protect(
|
||||||
|
csrfKey,
|
||||||
|
csrf.Secure(true),
|
||||||
|
csrf.SameSite(csrf.SameSiteStrictMode),
|
||||||
|
csrf.MaxAge(600))
|
||||||
|
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
ServerName: config.String("server.name"),
|
ServerName: config.String("server.name"),
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
}
|
}
|
||||||
server := &http.Server{
|
server := &http.Server{
|
||||||
Addr: fmt.Sprintf("%s:%d", config.String("server.name"), config.Int("server.port")),
|
Addr: fmt.Sprintf("%s:%d", config.String("server.name"), config.Int("server.port")),
|
||||||
Handler: commonHandlers.Tracing(nextRequestId)(
|
Handler: tracing(logging(hsts(csrfProtect(router)))),
|
||||||
commonHandlers.Logging(logger)(
|
|
||||||
commonHandlers.EnableHSTS()(handler),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
ReadTimeout: 5 * time.Second,
|
ReadTimeout: 5 * time.Second,
|
||||||
WriteTimeout: 10 * time.Second,
|
WriteTimeout: 10 * time.Second,
|
||||||
IdleTimeout: 15 * time.Second,
|
IdleTimeout: 15 * time.Second,
|
||||||
|
|
51
common/services/oidc.go
Normal file
51
common/services/oidc.go
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OpenIDConfiguration struct {
|
||||||
|
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||||
|
TokenEndpoint string `json:"token_endpoint"`
|
||||||
|
JwksUri string `json:"jwks_uri"`
|
||||||
|
EndSessionEndpoint string `json:"end_session_endpoint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func DiscoverOIDC(logger *log.Logger, oidcServer string, apiClient *http.Client) (o *OpenIDConfiguration, err error) {
|
||||||
|
var discoveryUrl *url.URL
|
||||||
|
|
||||||
|
if discoveryUrl, err = url.Parse(oidcServer); err != nil {
|
||||||
|
logger.Fatalf("could not parse oidc.server parameter value %s: %s", oidcServer, err)
|
||||||
|
} else {
|
||||||
|
discoveryUrl.Path = "/.well-known/openid-configuration"
|
||||||
|
}
|
||||||
|
|
||||||
|
var body []byte
|
||||||
|
var req *http.Request
|
||||||
|
req, err = http.NewRequest(http.MethodGet, discoveryUrl.String(), bytes.NewBuffer(body))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.Header = map[string][]string{
|
||||||
|
"Accept": {"application/json"},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := apiClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dec := json.NewDecoder(resp.Body)
|
||||||
|
o = &OpenIDConfiguration{}
|
||||||
|
err = dec.Decode(o)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
19
common/services/security.go
Normal file
19
common/services/security.go
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GenerateKey(length int) []byte {
|
||||||
|
key := make([]byte, length)
|
||||||
|
read, err := rand.Read(key)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("could not generate key: %s", err)
|
||||||
|
}
|
||||||
|
if read != length {
|
||||||
|
log.Fatalf("read %d bytes, expected %d bytes", read, length)
|
||||||
|
}
|
||||||
|
return key
|
||||||
|
}
|
|
@ -12,11 +12,12 @@ import (
|
||||||
|
|
||||||
type consentHandler struct {
|
type consentHandler struct {
|
||||||
adminClient *admin.Client
|
adminClient *admin.Client
|
||||||
|
logger *log.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *consentHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
func (h *consentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
consentChallenge := request.URL.Query().Get("consent_challenge")
|
consentChallenge := r.URL.Query().Get("consent_challenge")
|
||||||
consentRequest, err := c.adminClient.AcceptConsentRequest(
|
consentRequest, err := h.adminClient.AcceptConsentRequest(
|
||||||
admin.NewAcceptConsentRequestParams().WithConsentChallenge(consentChallenge).WithBody(
|
admin.NewAcceptConsentRequestParams().WithConsentChallenge(consentChallenge).WithBody(
|
||||||
&models.AcceptConsentRequest{
|
&models.AcceptConsentRequest{
|
||||||
GrantAccessTokenAudience: nil,
|
GrantAccessTokenAudience: nil,
|
||||||
|
@ -26,14 +27,15 @@ func (c *consentHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
|
||||||
RememberFor: 86400,
|
RememberFor: 86400,
|
||||||
}).WithTimeout(time.Second * 10))
|
}).WithTimeout(time.Second * 10))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panic(err)
|
h.logger.Panic(err)
|
||||||
}
|
}
|
||||||
writer.Header().Add("Location", *consentRequest.GetPayload().RedirectTo)
|
w.Header().Add("Location", *consentRequest.GetPayload().RedirectTo)
|
||||||
writer.WriteHeader(http.StatusFound)
|
w.WriteHeader(http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConsentHandler(ctx context.Context) *consentHandler {
|
func NewConsentHandler(logger *log.Logger, ctx context.Context) *consentHandler {
|
||||||
return &consentHandler{
|
return &consentHandler{
|
||||||
|
logger: logger,
|
||||||
adminClient: ctx.Value(CtxAdminClient).(*admin.Client),
|
adminClient: ctx.Value(CtxAdminClient).(*admin.Client),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
20
idp/handlers/error.go
Normal file
20
idp/handlers/error.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type errorHandler struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
_, _ = fmt.Fprintf(w, `
|
||||||
|
didumm %#v
|
||||||
|
`, r.URL.Query())
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewErrorHandler() *errorHandler {
|
||||||
|
return &errorHandler{}
|
||||||
|
}
|
|
@ -23,6 +23,7 @@ type loginHandler struct {
|
||||||
bundle *i18n.Bundle
|
bundle *i18n.Bundle
|
||||||
messageCatalog map[string]*i18n.Message
|
messageCatalog map[string]*i18n.Message
|
||||||
adminClient *admin.Client
|
adminClient *admin.Client
|
||||||
|
logger *log.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type LoginInformation struct {
|
type LoginInformation struct {
|
||||||
|
@ -30,27 +31,27 @@ type LoginInformation struct {
|
||||||
Password string `form:"password" validate:"required"`
|
Password string `form:"password" validate:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *loginHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
func (h *loginHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
var err error
|
var err error
|
||||||
challenge := request.URL.Query().Get("login_challenge")
|
challenge := r.URL.Query().Get("login_challenge")
|
||||||
log.Debugf("received challenge %s\n", challenge)
|
h.logger.Debugf("received challenge %s\n", challenge)
|
||||||
validate := validator.New()
|
validate := validator.New()
|
||||||
|
|
||||||
switch request.Method {
|
switch r.Method {
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
// GET should render login form
|
// GET should render login form
|
||||||
|
|
||||||
err = h.loginTemplate.Lookup("base").Execute(writer, map[string]interface{}{
|
err = h.loginTemplate.Lookup("base").Execute(w, map[string]interface{}{
|
||||||
"Title": "Title",
|
"Title": "Title",
|
||||||
csrf.TemplateTag: csrf.TemplateField(request),
|
csrf.TemplateTag: csrf.TemplateField(r),
|
||||||
"LabelEmail": "Email",
|
"LabelEmail": "Email",
|
||||||
"LabelPassword": "Password",
|
"LabelPassword": "Password",
|
||||||
"LabelLogin": "Login",
|
"LabelLogin": "Login",
|
||||||
"errors": map[string]string{},
|
"errors": map[string]string{},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
h.logger.Error(err)
|
||||||
http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
|
@ -60,23 +61,23 @@ func (h *loginHandler) ServeHTTP(writer http.ResponseWriter, request *http.Reque
|
||||||
|
|
||||||
// validate input
|
// validate input
|
||||||
decoder := form.NewDecoder()
|
decoder := form.NewDecoder()
|
||||||
err = decoder.Decode(&loginInfo, request.Form)
|
err = decoder.Decode(&loginInfo, r.Form)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
h.logger.Error(err)
|
||||||
http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err := validate.Struct(&loginInfo)
|
err := validate.Struct(&loginInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors := make(map[string]string)
|
errors := make(map[string]string)
|
||||||
for _, err := range err.(validator.ValidationErrors) {
|
for _, err := range err.(validator.ValidationErrors) {
|
||||||
accept := request.Header.Get("Accept-Language")
|
accept := r.Header.Get("Accept-Language")
|
||||||
errors[err.Field()] = h.lookupErrorMessage(err.Tag(), err.Field(), err.Value(), i18n.NewLocalizer(h.bundle, accept))
|
errors[err.Field()] = h.lookupErrorMessage(err.Tag(), err.Field(), err.Value(), i18n.NewLocalizer(h.bundle, accept))
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.loginTemplate.Lookup("base").Execute(writer, map[string]interface{}{
|
err = h.loginTemplate.Lookup("base").Execute(w, map[string]interface{}{
|
||||||
"Title": "Title",
|
"Title": "Title",
|
||||||
csrf.TemplateTag: csrf.TemplateField(request),
|
csrf.TemplateTag: csrf.TemplateField(r),
|
||||||
"LabelEmail": "Email",
|
"LabelEmail": "Email",
|
||||||
"LabelPassword": "Password",
|
"LabelPassword": "Password",
|
||||||
"LabelLogin": "Login",
|
"LabelLogin": "Login",
|
||||||
|
@ -84,8 +85,8 @@ func (h *loginHandler) ServeHTTP(writer http.ResponseWriter, request *http.Reque
|
||||||
"errors": errors,
|
"errors": errors,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
h.logger.Error(err)
|
||||||
http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -103,13 +104,14 @@ func (h *loginHandler) ServeHTTP(writer http.ResponseWriter, request *http.Reque
|
||||||
Subject: &subject,
|
Subject: &subject,
|
||||||
}).WithTimeout(time.Second * 10))
|
}).WithTimeout(time.Second * 10))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panic(err)
|
h.logger.Errorf("error getting logout requests: %v", err)
|
||||||
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
writer.Header().Add("Location", *loginRequest.GetPayload().RedirectTo)
|
w.Header().Add("Location", *loginRequest.GetPayload().RedirectTo)
|
||||||
writer.WriteHeader(http.StatusFound)
|
w.WriteHeader(http.StatusFound)
|
||||||
break
|
break
|
||||||
default:
|
default:
|
||||||
http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -118,13 +120,13 @@ func (h *loginHandler) lookupErrorMessage(tag string, field string, value interf
|
||||||
var message *i18n.Message
|
var message *i18n.Message
|
||||||
message, ok := h.messageCatalog[fmt.Sprintf("%s-%s", field, tag)]
|
message, ok := h.messageCatalog[fmt.Sprintf("%s-%s", field, tag)]
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Infof("no specific error message %s-%s", field, tag)
|
h.logger.Infof("no specific error message %s-%s", field, tag)
|
||||||
message, ok = h.messageCatalog[tag]
|
message, ok = h.messageCatalog[tag]
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Infof("no specific error message %s", tag)
|
h.logger.Infof("no specific error message %s", tag)
|
||||||
message, ok = h.messageCatalog["unknown"]
|
message, ok = h.messageCatalog["unknown"]
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Error("no default translation found")
|
h.logger.Error("no default translation found")
|
||||||
return tag
|
return tag
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -137,18 +139,19 @@ func (h *loginHandler) lookupErrorMessage(tag string, field string, value interf
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
h.logger.Error(err)
|
||||||
return tag
|
return tag
|
||||||
}
|
}
|
||||||
return translation
|
return translation
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLoginHandler(ctx context.Context) (*loginHandler, error) {
|
func NewLoginHandler(logger *log.Logger, ctx context.Context) (*loginHandler, error) {
|
||||||
loginTemplate, err := template.ParseFiles("templates/base.html", "templates/login.html")
|
loginTemplate, err := template.ParseFiles("templates/base.html", "templates/login.html")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &loginHandler{
|
return &loginHandler{
|
||||||
|
logger: logger,
|
||||||
loginTemplate: loginTemplate,
|
loginTemplate: loginTemplate,
|
||||||
bundle: ctx.Value(services.CtxI18nBundle).(*i18n.Bundle),
|
bundle: ctx.Value(services.CtxI18nBundle).(*i18n.Bundle),
|
||||||
messageCatalog: ctx.Value(services.CtxI18nCatalog).(map[string]*i18n.Message),
|
messageCatalog: ctx.Value(services.CtxI18nCatalog).(map[string]*i18n.Message),
|
||||||
|
|
58
idp/handlers/logout.go
Normal file
58
idp/handlers/logout.go
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ory/hydra-client-go/client/admin"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type logoutHandler struct {
|
||||||
|
adminClient *admin.Client
|
||||||
|
logger *log.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *logoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
challenge := r.URL.Query().Get("logout_challenge")
|
||||||
|
h.logger.Debugf("received challenge %s\n", challenge)
|
||||||
|
|
||||||
|
logoutRequest, err := h.adminClient.GetLogoutRequest(
|
||||||
|
admin.NewGetLogoutRequestParams().WithLogoutChallenge(challenge).WithTimeout(time.Second * 10))
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Errorf("error getting logout requests: %v", err)
|
||||||
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Debugf("received logout request: %#v", logoutRequest.Payload)
|
||||||
|
|
||||||
|
acceptLogoutRequest, err := h.adminClient.AcceptLogoutRequest(
|
||||||
|
admin.NewAcceptLogoutRequestParams().WithLogoutChallenge(challenge))
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Errorf("error accepting logout: %v", err)
|
||||||
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Location", *acceptLogoutRequest.GetPayload().RedirectTo)
|
||||||
|
w.WriteHeader(http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLogoutHandler(logger *log.Logger, ctx context.Context) *logoutHandler {
|
||||||
|
return &logoutHandler{
|
||||||
|
logger: logger,
|
||||||
|
adminClient: ctx.Value(CtxAdminClient).(*admin.Client),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type logoutSuccessHandler struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *logoutSuccessHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLogoutSuccessHandler() *logoutSuccessHandler {
|
||||||
|
return &logoutSuccessHandler{}
|
||||||
|
}
|
Reference in a new issue