Implement database access for user information
This commit is contained in:
parent
161ea7fe0c
commit
82918fb782
12 changed files with 298 additions and 118 deletions
|
@ -2,12 +2,15 @@ package handlers
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/form/v4"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/gorilla/csrf"
|
||||
"github.com/lestrrat-go/jwx/jwt/openid"
|
||||
"github.com/nicksnyder/go-i18n/v2/i18n"
|
||||
|
@ -16,12 +19,14 @@ import (
|
|||
log "github.com/sirupsen/logrus"
|
||||
|
||||
commonServices "git.cacert.org/oidc_login/common/services"
|
||||
"git.cacert.org/oidc_login/idp/services"
|
||||
)
|
||||
|
||||
type consentHandler struct {
|
||||
adminClient *admin.Client
|
||||
bundle *i18n.Bundle
|
||||
consentTemplate *template.Template
|
||||
context context.Context
|
||||
logger *log.Logger
|
||||
messageCatalog *commonServices.MessageCatalog
|
||||
}
|
||||
|
@ -30,6 +35,31 @@ type ConsentInformation struct {
|
|||
ConsentChecked bool `form:"consent"`
|
||||
}
|
||||
|
||||
type UserInfo struct {
|
||||
Email string `db:"email"`
|
||||
EmailVerified bool `db:"verified"`
|
||||
GivenName string `db:"fname"`
|
||||
MiddleName string `db:"mname"`
|
||||
FamilyName string `db:"lname"`
|
||||
BirthDate mysql.NullTime `db:"dob"`
|
||||
Language string `db:"language"`
|
||||
Modified mysql.NullTime `db:"modified"`
|
||||
}
|
||||
|
||||
func (i *UserInfo) GetFullName() string {
|
||||
nameParts := make([]string, 0)
|
||||
if len(i.GivenName) > 0 {
|
||||
nameParts = append(nameParts, i.GivenName)
|
||||
}
|
||||
if len(i.MiddleName) > 0 {
|
||||
nameParts = append(nameParts, i.MiddleName)
|
||||
}
|
||||
if len(i.FamilyName) > 0 {
|
||||
nameParts = append(nameParts, i.FamilyName)
|
||||
}
|
||||
return strings.Join(nameParts, " ")
|
||||
}
|
||||
|
||||
func (h *consentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
challenge := r.URL.Query().Get("consent_challenge")
|
||||
h.logger.Debugf("received consent challenge %s", challenge)
|
||||
|
@ -47,26 +77,7 @@ func (h *consentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
trans := h.messageCatalog.LookupMessage
|
||||
|
||||
// render consent form
|
||||
client := consentData.GetPayload().Client
|
||||
err = h.consentTemplate.Lookup("base").Execute(w, map[string]interface{}{
|
||||
"Title": trans("TitleRequestConsent", nil, localizer),
|
||||
csrf.TemplateTag: csrf.TemplateField(r),
|
||||
"errors": map[string]string{},
|
||||
"client": client,
|
||||
"requestedScope": h.mapRequestedScope(consentData.GetPayload().RequestedScope, localizer),
|
||||
"LabelSubmit": trans("LabelSubmit", nil, localizer),
|
||||
"LabelConsent": trans("LabelConsent", nil, localizer),
|
||||
"IntroMoreInformation": template.HTML(trans("IntroConsentMoreInformation", map[string]interface{}{
|
||||
"client": client.ClientName,
|
||||
"clientLink": client.ClientURI,
|
||||
}, localizer)),
|
||||
"IntroConsentRequested": template.HTML(trans("IntroConsentRequested", map[string]interface{}{
|
||||
"client": client.ClientName,
|
||||
}, localizer)),
|
||||
})
|
||||
h.renderConsentForm(w, r, consentData, err, localizer)
|
||||
break
|
||||
case http.MethodPost:
|
||||
var consentInfo ConsentInformation
|
||||
|
@ -82,46 +93,79 @@ func (h *consentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
if consentInfo.ConsentChecked {
|
||||
idTokenData := make(map[string]interface{}, 0)
|
||||
|
||||
for _, scope := range consentData.GetPayload().RequestedScope {
|
||||
switch scope {
|
||||
case "email":
|
||||
idTokenData[openid.EmailKey] = "john@theripper.mil"
|
||||
idTokenData[openid.EmailVerifiedKey] = true
|
||||
break
|
||||
case "profile":
|
||||
idTokenData[openid.GivenNameKey] = "John"
|
||||
idTokenData[openid.FamilyNameKey] = "The ripper"
|
||||
idTokenData[openid.MiddleNameKey] = ""
|
||||
idTokenData[openid.NameKey] = "John the Ripper"
|
||||
idTokenData[openid.BirthdateKey] = "1970-01-01"
|
||||
idTokenData[openid.ZoneinfoKey] = "Europe/London"
|
||||
idTokenData[openid.LocaleKey] = "en_UK"
|
||||
idTokenData["https://cacert.localhost/groups"] = []string{"admin", "user"}
|
||||
break
|
||||
}
|
||||
}
|
||||
db := services.GetDb(h.context)
|
||||
|
||||
sessionData := &models.ConsentRequestSession{
|
||||
AccessToken: nil,
|
||||
IDToken: idTokenData,
|
||||
}
|
||||
consentRequest, err := h.adminClient.AcceptConsentRequest(
|
||||
admin.NewAcceptConsentRequestParams().WithConsentChallenge(challenge).WithBody(
|
||||
&models.AcceptConsentRequest{
|
||||
GrantAccessTokenAudience: nil,
|
||||
GrantScope: consentData.GetPayload().RequestedScope,
|
||||
HandledAt: models.NullTime(time.Now()),
|
||||
Remember: true,
|
||||
RememberFor: 86400,
|
||||
Session: sessionData,
|
||||
}).WithTimeout(time.Second * 10))
|
||||
stmt, err := db.PreparexContext(
|
||||
r.Context(),
|
||||
`SELECT email, verified, fname, mname, lname, dob, language, modified
|
||||
FROM users
|
||||
WHERE id = ?
|
||||
AND LOCKED = 0`,
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
h.logger.Errorf("error preparing user information SQL: %v", err)
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Add("Location", *consentRequest.GetPayload().RedirectTo)
|
||||
w.WriteHeader(http.StatusFound)
|
||||
defer func() { _ = stmt.Close() }()
|
||||
|
||||
userInfo := &UserInfo{}
|
||||
|
||||
err = stmt.QueryRowxContext(r.Context(), consentData.GetPayload().Subject).StructScan(userInfo)
|
||||
switch {
|
||||
case err == sql.ErrNoRows:
|
||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
return
|
||||
case err != nil:
|
||||
h.logger.Errorf("error performing user information SQL: %v", err)
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
default:
|
||||
for _, scope := range consentData.GetPayload().RequestedScope {
|
||||
switch scope {
|
||||
case "email":
|
||||
idTokenData[openid.EmailKey] = userInfo.Email
|
||||
idTokenData[openid.EmailVerifiedKey] = userInfo.EmailVerified
|
||||
break
|
||||
case "profile":
|
||||
idTokenData[openid.GivenNameKey] = userInfo.GivenName
|
||||
idTokenData[openid.FamilyNameKey] = userInfo.FamilyName
|
||||
idTokenData[openid.MiddleNameKey] = userInfo.MiddleName
|
||||
idTokenData[openid.NameKey] = userInfo.GetFullName()
|
||||
if userInfo.BirthDate.Valid {
|
||||
idTokenData[openid.BirthdateKey] = userInfo.BirthDate.Time.Format("2006-01-02")
|
||||
}
|
||||
idTokenData[openid.LocaleKey] = userInfo.Language
|
||||
idTokenData["https://cacert.localhost/groups"] = []string{"admin", "user"}
|
||||
if userInfo.Modified.Valid {
|
||||
idTokenData[openid.UpdatedAtKey] = userInfo.Modified.Time.Unix()
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
sessionData := &models.ConsentRequestSession{
|
||||
AccessToken: nil,
|
||||
IDToken: idTokenData,
|
||||
}
|
||||
consentRequest, err := h.adminClient.AcceptConsentRequest(
|
||||
admin.NewAcceptConsentRequestParams().WithConsentChallenge(challenge).WithBody(
|
||||
&models.AcceptConsentRequest{
|
||||
GrantAccessTokenAudience: nil,
|
||||
GrantScope: consentData.GetPayload().RequestedScope,
|
||||
HandledAt: models.NullTime(time.Now()),
|
||||
Remember: true,
|
||||
RememberFor: 86400,
|
||||
Session: sessionData,
|
||||
}).WithTimeout(time.Second * 10))
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Add("Location", *consentRequest.GetPayload().RedirectTo)
|
||||
w.WriteHeader(http.StatusFound)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
consentRequest, err := h.adminClient.RejectConsentRequest(
|
||||
admin.NewRejectConsentRequestParams().WithConsentChallenge(challenge).WithBody(
|
||||
|
@ -137,6 +181,34 @@ func (h *consentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *consentHandler) renderConsentForm(w http.ResponseWriter, r *http.Request, consentData *admin.GetConsentRequestOK, err error, localizer *i18n.Localizer) {
|
||||
trans := func(id string, values ...map[string]interface{}) string {
|
||||
if len(values) > 0 {
|
||||
return h.messageCatalog.LookupMessage(id, values[0], localizer)
|
||||
}
|
||||
return h.messageCatalog.LookupMessage(id, nil, localizer)
|
||||
}
|
||||
|
||||
// render consent form
|
||||
client := consentData.GetPayload().Client
|
||||
err = h.consentTemplate.Lookup("base").Execute(w, map[string]interface{}{
|
||||
"Title": trans("TitleRequestConsent"),
|
||||
csrf.TemplateTag: csrf.TemplateField(r),
|
||||
"errors": map[string]string{},
|
||||
"client": client,
|
||||
"requestedScope": h.mapRequestedScope(consentData.GetPayload().RequestedScope, localizer),
|
||||
"LabelSubmit": trans("LabelSubmit"),
|
||||
"LabelConsent": trans("LabelConsent"),
|
||||
"IntroMoreInformation": template.HTML(trans("IntroConsentMoreInformation", map[string]interface{}{
|
||||
"client": client.ClientName,
|
||||
"clientLink": client.ClientURI,
|
||||
})),
|
||||
"IntroConsentRequested": template.HTML(trans("IntroConsentRequested", map[string]interface{}{
|
||||
"client": client.ClientName,
|
||||
})),
|
||||
})
|
||||
}
|
||||
|
||||
type scopeWithLabel struct {
|
||||
Name string
|
||||
Label string
|
||||
|
@ -162,6 +234,7 @@ func NewConsentHandler(ctx context.Context, logger *log.Logger) (*consentHandler
|
|||
adminClient: ctx.Value(CtxAdminClient).(*admin.Client),
|
||||
bundle: commonServices.GetI18nBundle(ctx),
|
||||
consentTemplate: consentTemplate,
|
||||
context: ctx,
|
||||
logger: logger,
|
||||
messageCatalog: commonServices.GetMessageCatalog(ctx),
|
||||
}, nil
|
||||
|
|
|
@ -2,6 +2,9 @@ package handlers
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"time"
|
||||
|
@ -16,11 +19,13 @@ import (
|
|||
log "github.com/sirupsen/logrus"
|
||||
|
||||
commonServices "git.cacert.org/oidc_login/common/services"
|
||||
"git.cacert.org/oidc_login/idp/services"
|
||||
)
|
||||
|
||||
type loginHandler struct {
|
||||
adminClient *admin.Client
|
||||
bundle *i18n.Bundle
|
||||
context context.Context
|
||||
logger *log.Logger
|
||||
loginTemplate *template.Template
|
||||
messageCatalog *commonServices.MessageCatalog
|
||||
|
@ -47,24 +52,15 @@ func (h *loginHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
var err error
|
||||
challenge := r.URL.Query().Get("login_challenge")
|
||||
h.logger.Debugf("received login challenge %s\n", challenge)
|
||||
accept := r.Header.Get("Accept-Language")
|
||||
localizer := i18n.NewLocalizer(h.bundle, accept)
|
||||
|
||||
validate := validator.New()
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
// render login form
|
||||
err = h.loginTemplate.Lookup("base").Execute(w, map[string]interface{}{
|
||||
"Title": "Title",
|
||||
csrf.TemplateTag: csrf.TemplateField(r),
|
||||
"LabelEmail": "Email",
|
||||
"LabelPassword": "Password",
|
||||
"LabelLogin": "Login",
|
||||
"errors": map[string]string{},
|
||||
})
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.renderLoginForm(w, r, map[string]string{}, &LoginInformation{}, localizer)
|
||||
break
|
||||
case http.MethodPost:
|
||||
var loginInfo LoginInformation
|
||||
|
@ -84,42 +80,66 @@ func (h *loginHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
accept := r.Header.Get("Accept-Language")
|
||||
errors[err.Field()] = h.messageCatalog.LookupErrorMessage(err.Tag(), err.Field(), err.Value(), i18n.NewLocalizer(h.bundle, accept))
|
||||
}
|
||||
h.renderLoginForm(w, r, errors, &loginInfo, localizer)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.loginTemplate.Lookup("base").Execute(w, map[string]interface{}{
|
||||
"Title": "Title",
|
||||
csrf.TemplateTag: csrf.TemplateField(r),
|
||||
"LabelEmail": "Email",
|
||||
"LabelPassword": "Password",
|
||||
"LabelLogin": "Login",
|
||||
"Email": loginInfo.Email,
|
||||
"errors": errors,
|
||||
})
|
||||
db := services.GetDb(h.context)
|
||||
|
||||
stmt, err := db.PrepareContext(
|
||||
r.Context(),
|
||||
`SELECT id
|
||||
FROM users
|
||||
WHERE email = ?
|
||||
AND password = ?
|
||||
AND locked = 0`,
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Errorf("error preparing login SQL: %v", err)
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer func() { _ = stmt.Close() }()
|
||||
|
||||
// FIXME: replace with a real password hash algorithm
|
||||
passwordHash := sha1.Sum([]byte(loginInfo.Password))
|
||||
password := hex.EncodeToString(passwordHash[:])
|
||||
// FIXME: introduce a real opaque identifier (i.e. a UUID)
|
||||
var userId string
|
||||
// GET user data
|
||||
err = stmt.QueryRowContext(r.Context(), loginInfo.Email, password).Scan(&userId)
|
||||
switch {
|
||||
case err == sql.ErrNoRows:
|
||||
errors := map[string]string{
|
||||
"Form": h.messageCatalog.LookupMessage(
|
||||
"WrongOrLockedUserOrInvalidPassword",
|
||||
nil,
|
||||
localizer,
|
||||
),
|
||||
}
|
||||
h.renderLoginForm(w, r, errors, &loginInfo, localizer)
|
||||
return
|
||||
case err != nil:
|
||||
h.logger.Errorf("error performing login SQL: %v", err)
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
default:
|
||||
// finish login and redirect to target
|
||||
loginRequest, err := h.adminClient.AcceptLoginRequest(
|
||||
admin.NewAcceptLoginRequestParams().WithLoginChallenge(challenge).WithBody(&models.AcceptLoginRequest{
|
||||
Acr: string(Password),
|
||||
Remember: true,
|
||||
RememberFor: 0,
|
||||
Subject: &userId,
|
||||
}).WithTimeout(time.Second * 10))
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
h.logger.Errorf("error getting login request: %#v", err)
|
||||
http.Error(w, err.Error(), err.(*runtime.APIError).Code)
|
||||
return
|
||||
}
|
||||
return
|
||||
w.Header().Add("Location", *loginRequest.GetPayload().RedirectTo)
|
||||
w.WriteHeader(http.StatusFound)
|
||||
}
|
||||
|
||||
// GET user data
|
||||
// finish login and redirect to target
|
||||
// TODO: get or generate a user id
|
||||
subject := "a-user-with-an-id"
|
||||
loginRequest, err := h.adminClient.AcceptLoginRequest(
|
||||
admin.NewAcceptLoginRequestParams().WithLoginChallenge(challenge).WithBody(&models.AcceptLoginRequest{
|
||||
Acr: string(NoCredentials),
|
||||
Remember: true,
|
||||
RememberFor: 0,
|
||||
Subject: &subject,
|
||||
}).WithTimeout(time.Second * 10))
|
||||
if err != nil {
|
||||
h.logger.Errorf("error getting login request: %#v", err)
|
||||
http.Error(w, err.Error(), err.(*runtime.APIError).Code)
|
||||
return
|
||||
}
|
||||
w.Header().Add("Location", *loginRequest.GetPayload().RedirectTo)
|
||||
w.WriteHeader(http.StatusFound)
|
||||
break
|
||||
default:
|
||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||
|
@ -127,6 +147,27 @@ func (h *loginHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *loginHandler) renderLoginForm(w http.ResponseWriter, r *http.Request, errors map[string]string, info *LoginInformation, localizer *i18n.Localizer) {
|
||||
trans := func(label string) string {
|
||||
return h.messageCatalog.LookupMessage(label, nil, localizer)
|
||||
}
|
||||
|
||||
err := h.loginTemplate.Lookup("base").Execute(w, map[string]interface{}{
|
||||
"Title": trans("LoginTitle"),
|
||||
csrf.TemplateTag: csrf.TemplateField(r),
|
||||
"LabelEmail": trans("LabelEmail"),
|
||||
"LabelPassword": trans("LabelPassword"),
|
||||
"LabelLogin": trans("LabelLogin"),
|
||||
"Email": info.Email,
|
||||
"errors": errors,
|
||||
})
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func NewLoginHandler(ctx context.Context, logger *log.Logger) (*loginHandler, error) {
|
||||
loginTemplate, err := template.ParseFiles(
|
||||
"templates/idp/base.gohtml", "templates/idp/login.gohtml")
|
||||
|
@ -136,6 +177,7 @@ func NewLoginHandler(ctx context.Context, logger *log.Logger) (*loginHandler, er
|
|||
return &loginHandler{
|
||||
adminClient: ctx.Value(CtxAdminClient).(*admin.Client),
|
||||
bundle: commonServices.GetI18nBundle(ctx),
|
||||
context: ctx,
|
||||
logger: logger,
|
||||
loginTemplate: loginTemplate,
|
||||
messageCatalog: commonServices.GetMessageCatalog(ctx),
|
||||
|
|
|
@ -39,7 +39,7 @@ func (h *logoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
w.WriteHeader(http.StatusFound)
|
||||
}
|
||||
|
||||
func NewLogoutHandler(logger *log.Logger, ctx context.Context) *logoutHandler {
|
||||
func NewLogoutHandler(ctx context.Context, logger *log.Logger) *logoutHandler {
|
||||
return &logoutHandler{
|
||||
logger: logger,
|
||||
adminClient: ctx.Value(CtxAdminClient).(*admin.Client),
|
||||
|
|
46
idp/services/database.go
Normal file
46
idp/services/database.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
type dbContextKey int
|
||||
|
||||
const (
|
||||
ctxDbConnection dbContextKey = iota
|
||||
)
|
||||
|
||||
type DatabaseParams struct {
|
||||
ConnMaxLifeTime time.Duration
|
||||
DSN string
|
||||
MaxOpenConnections int
|
||||
MaxIdleConnections int
|
||||
}
|
||||
|
||||
func NewDatabaseParams(dsn string) *DatabaseParams {
|
||||
return &DatabaseParams{
|
||||
DSN: dsn,
|
||||
ConnMaxLifeTime: time.Minute * 3,
|
||||
MaxOpenConnections: 10,
|
||||
MaxIdleConnections: 10,
|
||||
}
|
||||
}
|
||||
|
||||
func InitDatabase(ctx context.Context, params *DatabaseParams) (context.Context, error) {
|
||||
db, err := sqlx.Connect("mysql", params.DSN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db.SetConnMaxLifetime(params.ConnMaxLifeTime)
|
||||
db.SetMaxOpenConns(params.MaxOpenConnections)
|
||||
db.SetMaxIdleConns(params.MaxIdleConnections)
|
||||
return context.WithValue(ctx, ctxDbConnection, db), nil
|
||||
}
|
||||
|
||||
func GetDb(ctx context.Context) *sqlx.DB {
|
||||
return ctx.Value(ctxDbConnection).(*sqlx.DB)
|
||||
}
|
|
@ -66,5 +66,9 @@ func AddMessages(ctx context.Context) {
|
|||
ID: "Scope-email-Description",
|
||||
Other: "Access your primary email address.",
|
||||
}
|
||||
messages["WrongOrLockedUserOrInvalidPassword"] = &i18n.Message{
|
||||
ID: "WrongOrLockedUserOrInvalidPassword",
|
||||
Other: "You entered an invalid username or password or your account has been locked.",
|
||||
}
|
||||
services.GetMessageCatalog(ctx).AddMessages(messages)
|
||||
}
|
||||
|
|
Reference in a new issue