Refactor client I/O into protocol package

This commit is contained in:
Jan Dittberner 2020-04-19 22:29:58 +02:00
parent 337e974a26
commit 9924771531
7 changed files with 339 additions and 200 deletions

View file

@ -2,14 +2,19 @@ package main
import (
"flag"
"fmt"
"io"
"os"
"os/signal"
"sync"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"go.bug.st/serial"
"git.cacert.org/cacert-gosigner/client/processing"
"git.cacert.org/cacert-gosigner/client/protocol"
"git.cacert.org/cacert-gosigner/datastructures"
"git.cacert.org/cacert-gosigner/shared"
)
func main() {
@ -27,7 +32,7 @@ func main() {
}
serialConfig = fillSerialMode(clientConfig)
if clientConfig.Debug {
log.SetLevel(log.DebugLevel)
log.SetLevel(log.TraceLevel)
}
log.Infof("connecting to %s using %+v", clientConfig.SerialAddress, serialConfig)
@ -36,25 +41,64 @@ func main() {
log.Fatal(err)
}
log.Debug("serial port connected")
defer func() {
err := port.Close()
if err != nil {
log.Fatal(err)
requestChannel := protocol.NewSignerProtocolRequestChannel()
responseChannel := make(chan *datastructures.SignerResponse, 1)
readWriteCloser := (io.ReadWriteCloser)(port)
protocolHandler := protocol.NewProtocolHandler(requestChannel, &responseChannel, &readWriteCloser)
cancelChannel := make(chan os.Signal, 1)
signal.Notify(cancelChannel, syscall.SIGTERM, syscall.SIGINT)
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
if err := protocolHandler.HandleSignerProtocol(); err != nil {
log.Errorf("terminating because of %v", err)
close(cancelChannel)
}
log.Debug("serial port closed")
wg.Done()
}()
errorChannel := make(chan error, 1)
responseChannel := make(chan datastructures.SignerResponse, 1)
go func() {
runMainLoop(requestChannel, &responseChannel)
wg.Done()
}()
sig := <-cancelChannel
if sig != nil {
log.Infof("caught %+v", sig)
}
if err := protocolHandler.Close(); err != nil {
log.Error(err)
} else {
log.Infof("protocol handler closed")
}
if err := port.Close(); err != nil {
log.Error(err)
} else {
log.Infof("serial port closed")
}
wg.Wait()
}
func runMainLoop(requestChannel *protocol.SignerProtocolRequestChannel, responseChannel *chan *datastructures.SignerResponse) {
crlCheck := 0
log.Debug("starting main loop")
go func() {
for response := range *responseChannel {
if err := processing.Process(response); err != nil {
log.Error(err)
}
}
log.Trace("processing goroutine terminated")
}()
for {
requestChannel := make(chan datastructures.SignerRequest, 1)
go HandleRequests(&port, &responseChannel, &errorChannel, &requestChannel)
log.Debug("handling GPG database ...")
// HandleGPG(&requestChannel)
log.Debug("issuing certificates ...")
@ -68,83 +112,13 @@ func main() {
// RefreshCRLs(&requestChannel)
}
log.Debug("send NUL request to keep connection open")
requestChannel <- *datastructures.NewNulRequest()
select {
case response := <-responseChannel:
if err := Process(response); err != nil {
log.Error(err)
}
case err := <-errorChannel:
log.Error(err)
if requestChannel.IsClosed() {
return
}
log.Debug("send NUL request to keep connection open")
requestChannel.C <- datastructures.NewNulRequest()
log.Debug("sleep for 2.7 seconds")
time.Sleep(2700 * time.Millisecond)
}
}
func Process(response datastructures.SignerResponse) (err error) {
log.Infof("process response of type %s", response.Action)
log.Tracef("process response %v", response)
switch response.Action {
case datastructures.ActionNul:
log.Trace("received response for NUL request")
return
default:
return fmt.Errorf("unsupported action in response 0x%x", response.Action)
}
}
func HandleRequests(port *serial.Port, responseChan *chan datastructures.SignerResponse, errorChan *chan error, requestChan *chan datastructures.SignerRequest) {
for {
select {
case request := <-*requestChan:
SendRequest(port, responseChan, errorChan, &request)
}
}
}
func SendRequest(port *serial.Port, responseChan *chan datastructures.SignerResponse, errorChan *chan error, request *datastructures.SignerRequest) {
log.Tracef("send request %v to serial port %v", *request, *port)
if err := sendHandShake(*port); err != nil {
*errorChan <- err
return
}
requestBytes := request.Serialize()
if length, err := (*port).Write(requestBytes); err != nil {
*errorChan <- err
return
} else {
log.Tracef("wrote %d request bytes", length)
}
if length, err := (*port).Write([]byte{datastructures.CalculateXorCheckSum([][]byte{requestBytes})}); err != nil {
*errorChan <- err
return
} else {
log.Tracef("wrote %d checksum bytes", length)
}
if length, err := (*port).Write([]byte(shared.MagicTrailer)); err != nil {
*errorChan <- err
return
} else {
log.Tracef("wrote %d trailer bytes", length)
}
header, err := shared.ReceiveBytes(port, 1, 20)
if err != nil {
*errorChan <- err
return
}
if header[0] != shared.AckByte {
*errorChan <- fmt.Errorf("unexpected byte 0x%x expected 0x%x", header, shared.AckByte)
return
}
receiveResponse(port, responseChan, errorChan)
}