Refactor client I/O into protocol package

This commit is contained in:
Jan Dittberner 2020-04-19 22:29:58 +02:00
parent 337e974a26
commit 9924771531
7 changed files with 339 additions and 200 deletions

View file

@ -1,98 +0,0 @@
package main
import (
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"go.bug.st/serial"
"git.cacert.org/cacert-gosigner/datastructures"
"git.cacert.org/cacert-gosigner/shared"
)
func sendHandShake(port serial.Port) error {
log.Debug("Shaking hands ...")
if length, err := port.Write([]byte{shared.HandshakeByte}); err != nil {
return fmt.Errorf("could not write handshake byte: %v", err)
} else {
log.Tracef("wrote %d handshake bytes", length)
}
handShakeResponse := make([]byte, 1)
if length, err := port.Read(handShakeResponse); err != nil {
return fmt.Errorf("failed to read handshake response: %v", err)
} else {
log.Tracef("read %d bytes", length)
}
if handShakeResponse[0] != shared.AckByte {
return fmt.Errorf("invalid handshake response expected 0x10 received %x", handShakeResponse[0])
}
log.Debug("Handshake successful")
return nil
}
func receiveResponse(port *serial.Port, responseChan *chan datastructures.SignerResponse, errorChan *chan error) {
header, err := shared.ReceiveBytes(port, 1, 20)
if err != nil {
*errorChan <- err
return
}
if header[0] != shared.HandshakeByte {
*errorChan <- fmt.Errorf("unexpected byte 0x%x expected 0x%x", header, shared.HandshakeByte)
}
log.Tracef("received handshake byte")
if _, err := (*port).Write([]byte{shared.AckByte}); err != nil {
*errorChan <- errors.New("could not write ACK")
return
}
log.Tracef("sent ACK byte")
lengthBytes, err := shared.ReceiveBytes(port, 3, 2)
if err != nil {
*errorChan <- err
return
}
blockLength := datastructures.Decode24BitLength(lengthBytes)
log.Tracef("received block length %d", blockLength)
blockData, err := shared.ReceiveBytes(port, blockLength, 5)
if err != nil {
*errorChan <- err
return
}
log.Tracef("received bytes %v", blockData)
checkSum, err := shared.ReceiveBytes(port, 1, 2)
if err != nil {
*errorChan <- err
return
}
log.Tracef("received checksum 0x%x", checkSum[0])
trailer, err := shared.ReceiveBytes(port, len(shared.MagicTrailer), 2)
if err != nil {
*errorChan <- err
return
}
if string(trailer) != shared.MagicTrailer {
*errorChan <- errors.New("expected trailer bytes not found")
return
}
log.Tracef("received valid trailer bytes")
if _, err := (*port).Write([]byte{shared.AckByte}); err != nil {
*errorChan <- fmt.Errorf("could not write ACK byte: %v", err)
return
}
log.Tracef("sent ACK byte")
signerResponse, err := datastructures.SignerResponseFromData(lengthBytes, blockData, checkSum[0])
if err != nil {
*errorChan <- err
return
}
log.Infof("received response of type %s", signerResponse.Action)
*responseChan <- *signerResponse
}

View file

@ -2,14 +2,19 @@ package main
import ( import (
"flag" "flag"
"fmt" "io"
"os"
"os/signal"
"sync"
"syscall"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"go.bug.st/serial" "go.bug.st/serial"
"git.cacert.org/cacert-gosigner/client/processing"
"git.cacert.org/cacert-gosigner/client/protocol"
"git.cacert.org/cacert-gosigner/datastructures" "git.cacert.org/cacert-gosigner/datastructures"
"git.cacert.org/cacert-gosigner/shared"
) )
func main() { func main() {
@ -27,7 +32,7 @@ func main() {
} }
serialConfig = fillSerialMode(clientConfig) serialConfig = fillSerialMode(clientConfig)
if clientConfig.Debug { if clientConfig.Debug {
log.SetLevel(log.DebugLevel) log.SetLevel(log.TraceLevel)
} }
log.Infof("connecting to %s using %+v", clientConfig.SerialAddress, serialConfig) log.Infof("connecting to %s using %+v", clientConfig.SerialAddress, serialConfig)
@ -36,25 +41,64 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
log.Debug("serial port connected") log.Debug("serial port connected")
defer func() {
err := port.Close() requestChannel := protocol.NewSignerProtocolRequestChannel()
if err != nil { responseChannel := make(chan *datastructures.SignerResponse, 1)
log.Fatal(err)
readWriteCloser := (io.ReadWriteCloser)(port)
protocolHandler := protocol.NewProtocolHandler(requestChannel, &responseChannel, &readWriteCloser)
cancelChannel := make(chan os.Signal, 1)
signal.Notify(cancelChannel, syscall.SIGTERM, syscall.SIGINT)
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
if err := protocolHandler.HandleSignerProtocol(); err != nil {
log.Errorf("terminating because of %v", err)
close(cancelChannel)
} }
log.Debug("serial port closed") wg.Done()
}() }()
errorChannel := make(chan error, 1) go func() {
responseChannel := make(chan datastructures.SignerResponse, 1) runMainLoop(requestChannel, &responseChannel)
wg.Done()
}()
sig := <-cancelChannel
if sig != nil {
log.Infof("caught %+v", sig)
}
if err := protocolHandler.Close(); err != nil {
log.Error(err)
} else {
log.Infof("protocol handler closed")
}
if err := port.Close(); err != nil {
log.Error(err)
} else {
log.Infof("serial port closed")
}
wg.Wait()
}
func runMainLoop(requestChannel *protocol.SignerProtocolRequestChannel, responseChannel *chan *datastructures.SignerResponse) {
crlCheck := 0 crlCheck := 0
log.Debug("starting main loop") log.Debug("starting main loop")
go func() {
for response := range *responseChannel {
if err := processing.Process(response); err != nil {
log.Error(err)
}
}
log.Trace("processing goroutine terminated")
}()
for { for {
requestChannel := make(chan datastructures.SignerRequest, 1)
go HandleRequests(&port, &responseChannel, &errorChannel, &requestChannel)
log.Debug("handling GPG database ...") log.Debug("handling GPG database ...")
// HandleGPG(&requestChannel) // HandleGPG(&requestChannel)
log.Debug("issuing certificates ...") log.Debug("issuing certificates ...")
@ -68,83 +112,13 @@ func main() {
// RefreshCRLs(&requestChannel) // RefreshCRLs(&requestChannel)
} }
if requestChannel.IsClosed() {
return
}
log.Debug("send NUL request to keep connection open") log.Debug("send NUL request to keep connection open")
requestChannel <- *datastructures.NewNulRequest() requestChannel.C <- datastructures.NewNulRequest()
select {
case response := <-responseChannel:
if err := Process(response); err != nil {
log.Error(err)
}
case err := <-errorChannel:
log.Error(err)
}
log.Debug("sleep for 2.7 seconds") log.Debug("sleep for 2.7 seconds")
time.Sleep(2700 * time.Millisecond) time.Sleep(2700 * time.Millisecond)
} }
} }
func Process(response datastructures.SignerResponse) (err error) {
log.Infof("process response of type %s", response.Action)
log.Tracef("process response %v", response)
switch response.Action {
case datastructures.ActionNul:
log.Trace("received response for NUL request")
return
default:
return fmt.Errorf("unsupported action in response 0x%x", response.Action)
}
}
func HandleRequests(port *serial.Port, responseChan *chan datastructures.SignerResponse, errorChan *chan error, requestChan *chan datastructures.SignerRequest) {
for {
select {
case request := <-*requestChan:
SendRequest(port, responseChan, errorChan, &request)
}
}
}
func SendRequest(port *serial.Port, responseChan *chan datastructures.SignerResponse, errorChan *chan error, request *datastructures.SignerRequest) {
log.Tracef("send request %v to serial port %v", *request, *port)
if err := sendHandShake(*port); err != nil {
*errorChan <- err
return
}
requestBytes := request.Serialize()
if length, err := (*port).Write(requestBytes); err != nil {
*errorChan <- err
return
} else {
log.Tracef("wrote %d request bytes", length)
}
if length, err := (*port).Write([]byte{datastructures.CalculateXorCheckSum([][]byte{requestBytes})}); err != nil {
*errorChan <- err
return
} else {
log.Tracef("wrote %d checksum bytes", length)
}
if length, err := (*port).Write([]byte(shared.MagicTrailer)); err != nil {
*errorChan <- err
return
} else {
log.Tracef("wrote %d trailer bytes", length)
}
header, err := shared.ReceiveBytes(port, 1, 20)
if err != nil {
*errorChan <- err
return
}
if header[0] != shared.AckByte {
*errorChan <- fmt.Errorf("unexpected byte 0x%x expected 0x%x", header, shared.AckByte)
return
}
receiveResponse(port, responseChan, errorChan)
}

View file

@ -0,0 +1,20 @@
package processing
import (
"fmt"
"git.cacert.org/cacert-gosigner/datastructures"
"github.com/sirupsen/logrus"
)
func Process(response *datastructures.SignerResponse) (err error) {
logrus.Infof("process response of type %s", response.Action)
logrus.Tracef("process response %v", response)
switch response.Action {
case datastructures.ActionNul:
logrus.Trace("received response for NUL request")
return
default:
return fmt.Errorf("unsupported action in response 0x%x", response.Action)
}
}

229
client/protocol/protocol.go Normal file
View file

@ -0,0 +1,229 @@
package protocol
import (
"bytes"
"fmt"
"git.cacert.org/cacert-gosigner/datastructures"
"git.cacert.org/cacert-gosigner/shared"
log "github.com/sirupsen/logrus"
"io"
"sync"
"time"
)
type SignerProtocolHandler interface {
io.Closer
HandleSignerProtocol() error
}
type SignerProtocolRequestChannel struct {
C chan *datastructures.SignerRequest
closed bool
mutex sync.Mutex
}
func NewSignerProtocolRequestChannel() *SignerProtocolRequestChannel {
return &SignerProtocolRequestChannel{C: make(chan *datastructures.SignerRequest, 1)}
}
func (rc *SignerProtocolRequestChannel) SafeClose() {
rc.mutex.Lock()
defer rc.mutex.Unlock()
if !rc.closed {
close(rc.C)
rc.closed = true
}
}
func (rc *SignerProtocolRequestChannel) IsClosed() bool {
rc.mutex.Lock()
defer rc.mutex.Unlock()
return rc.closed
}
type protocolHandler struct {
requestChannel *SignerProtocolRequestChannel
responseChannel *chan *datastructures.SignerResponse
serialConnection *io.ReadWriteCloser
}
type UnExpectedAcknowledgeByte struct {
ResponseByte byte
}
type UnExpectedHandshakeByte struct {
ResponseByte byte
}
func (e UnExpectedHandshakeByte) Error() string {
return fmt.Sprintf("unexpected handshake byte 0x%x instead of 0x%x", e.ResponseByte, shared.HandshakeByte)
}
func (e UnExpectedAcknowledgeByte) Error() string {
return fmt.Sprintf("unexpected acknowledge byte 0x%x instead of 0x%x", e.ResponseByte, shared.AckByte)
}
func (ph *protocolHandler) Close() error {
close(*ph.responseChannel)
ph.requestChannel.SafeClose()
return nil
}
func NewProtocolHandler(requests *SignerProtocolRequestChannel, response *chan *datastructures.SignerResponse, serialConnection *io.ReadWriteCloser) SignerProtocolHandler {
return &protocolHandler{
requestChannel: requests,
responseChannel: response,
serialConnection: serialConnection,
}
}
func (ph *protocolHandler) HandleSignerProtocol() error {
for {
select {
case request := <-ph.requestChannel.C:
log.Debugf("handle request %+v", request)
var err error
var lengthBytes, responseBytes *[]byte
var checksum byte
if err = ph.sendHandshake(); err != nil {
switch err.(type) {
case UnExpectedAcknowledgeByte:
log.Errorf("unexpected handshake byte: 0x%x", err.(UnExpectedAcknowledgeByte).ResponseByte)
// TODO drain input
}
return err
}
requestBytes := request.Serialize()
if err = ph.sendRequest(&requestBytes); err != nil {
return err
}
if err = ph.waitForResponseHandshake(); err != nil {
return err
}
if lengthBytes, responseBytes, checksum, err = ph.readResponse(); err != nil {
return err
}
response, err := datastructures.SignerResponseFromData(*lengthBytes, *responseBytes, checksum)
if err != nil {
return err
}
*ph.responseChannel <- response
}
}
}
func (ph *protocolHandler) sendHandshake() (err error) {
var bytesWritten, bytesRead int
if bytesWritten, err = (*ph.serialConnection).Write([]byte{shared.HandshakeByte}); err != nil {
return
} else {
log.Tracef("wrote %d bytes of handshake info", bytesWritten)
}
data := make([]byte, 1)
if bytesRead, err = (*ph.serialConnection).Read(data); err != nil {
return
}
log.Tracef("%d bytes read", bytesRead)
if bytesRead != 1 || data[0] != shared.AckByte {
log.Warnf("received invalid handshake byte 0x%x", data[0])
return UnExpectedAcknowledgeByte{data[0]}
}
return
}
func (ph *protocolHandler) sendRequest(requestBytes *[]byte) error {
for {
if length, err := (*ph.serialConnection).Write(*requestBytes); err != nil {
return err
} else {
log.Tracef("wrote %d request bytes", length)
}
if length, err := (*ph.serialConnection).Write([]byte{
datastructures.CalculateXorCheckSum([][]byte{*requestBytes}),
}); err != nil {
return err
} else {
log.Tracef("wrote %d checksum bytes", length)
}
if length, err := (*ph.serialConnection).Write([]byte(shared.MagicTrailer)); err != nil {
return err
} else {
log.Tracef("wrote %d trailer bytes", length)
}
header, err := shared.ReceiveBytes(ph.serialConnection, 1, 20*time.Second)
if err != nil {
return err
}
switch header[0] {
case shared.AckByte:
return nil
case shared.ResendByte:
default:
return UnExpectedAcknowledgeByte{header[0]}
}
}
}
func (ph *protocolHandler) waitForResponseHandshake() (err error) {
data, err := shared.ReceiveBytes(ph.serialConnection, 1, 120*time.Second)
if err != nil {
return err
}
if len(data) != 1 || data[0] != shared.HandshakeByte {
log.Warnf("received invalid handshake byte 0x%x", data[0])
return UnExpectedHandshakeByte{data[0]}
}
if err = shared.SendByte(ph.serialConnection, shared.AckByte); err != nil {
return
}
return
}
func (ph *protocolHandler) readResponse() (*[]byte, *[]byte, byte, error) {
dataLength := -1
var lengthBuffer = bytes.NewBuffer(make([]byte, 0))
var byteBuffer = bytes.NewBuffer(make([]byte, 0))
for {
readBuffer, err := shared.ReceiveBytes(ph.serialConnection, 100, 5*time.Second)
if err != nil {
return nil, nil, 0, err
}
bytesRead := len(readBuffer)
if bytesRead > 0 {
byteBuffer.Write(readBuffer[0:bytesRead])
for _, b := range readBuffer {
if lengthBuffer.Len() < 3 {
lengthBuffer.WriteByte(b)
} else {
break
}
}
}
if dataLength < 0 && lengthBuffer.Len() == 3 {
dataLength = datastructures.Decode24BitLength(lengthBuffer.Bytes())
log.Tracef("expect to read %d data bytes", dataLength)
}
if dataLength == byteBuffer.Len()-4-len(shared.MagicTrailer) {
allBytes := byteBuffer.Bytes()
trailer := string(allBytes[4+dataLength:])
if trailer != shared.MagicTrailer {
return nil, nil, 0, fmt.Errorf("invalid trailer bytes: %v", trailer)
}
lengthBytes := allBytes[0:3]
dataBytes := allBytes[3 : 3+dataLength]
checkSum := allBytes[3+dataLength]
calculatedChecksum := datastructures.CalculateXorCheckSum([][]byte{lengthBytes, dataBytes})
if calculatedChecksum != checkSum {
return nil, nil, 0, fmt.Errorf("calculated checksum mismatch 0x%x vs 0x%x", calculatedChecksum, checkSum)
}
if err := shared.SendByte(ph.serialConnection, shared.AckByte); err != nil {
return nil, nil, 0, err
}
return &lengthBytes, &dataBytes, checkSum, nil
}
}
}

View file

@ -1,31 +1,43 @@
package shared package shared
import ( import (
"errors" "fmt"
"go.bug.st/serial" log "github.com/sirupsen/logrus"
"io" "io"
"time" "time"
) )
// receive the requested number of bytes from serial port and stop after the given timeout in seconds // receive at maximum the requested number of bytes from serial port and stop after the given timeout
func ReceiveBytes(port *serial.Port, count int, timeout time.Duration) ([]byte, error) { func ReceiveBytes(port *io.ReadWriteCloser, count int, timeout time.Duration) ([]byte, error) {
timeoutCh := time.After(timeout * time.Second)
readCh := make(chan []byte, 1) readCh := make(chan []byte, 1)
errCh := make(chan error, 1) errCh := make(chan error, 1)
go func() { go func() {
data := make([]byte, count) data := make([]byte, count)
if _, err := io.ReadAtLeast(*port, data, count); err != nil { if readBytes, err := (*port).Read(data); err != nil {
errCh <- err errCh <- err
} else if readBytes > 0 {
log.Tracef("%d bytes read", readBytes)
readCh <- data[0:readBytes]
} else { } else {
readCh <- data readCh <- make([]byte, 0)
} }
return
}() }()
select { select {
case <-timeoutCh: case <-time.After(timeout):
return nil, errors.New("timeout") return nil, fmt.Errorf("timeout passed %v: %v", timeout)
case err := <-errCh: case err := <-errCh:
return nil, err return nil, err
case data := <-readCh: case data := <-readCh:
return data, nil return data, nil
} }
} }
func SendByte(port *io.ReadWriteCloser, data byte) error {
if bytesWritten, err := (*port).Write([]byte{data}); err != nil {
return err
} else {
log.Tracef("wrote %d bytes of handshake info", bytesWritten)
}
return nil
}

View file

@ -4,4 +4,4 @@ const MagicTrailer = "rie4Ech7"
const HandshakeByte = 0x02 const HandshakeByte = 0x02
const AckByte = 0x10 const AckByte = 0x10
const NackByte = 0x11 const ResendByte = 0x11

View file

@ -4,6 +4,7 @@ import (
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"io"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -53,7 +54,8 @@ readLoop:
commandChan := make(chan datastructures.SignerRequest, 1) commandChan := make(chan datastructures.SignerRequest, 1)
errChan := make(chan error, 1) errChan := make(chan error, 1)
go Receive(&port, &commandChan, &errChan) readWriteCloser := (io.ReadWriteCloser)(port)
go Receive(&readWriteCloser, &commandChan, &errChan)
select { select {
case command := <-commandChan: case command := <-commandChan:
@ -61,7 +63,7 @@ readLoop:
if err != nil { if err != nil {
log.Printf("ERROR %v\n", err) log.Printf("ERROR %v\n", err)
} else { } else {
_ = SendResponse(&port, response) _ = SendResponse(&readWriteCloser, response)
} }
case <-timeout: case <-timeout:
log.Println("timeout in main loop") log.Println("timeout in main loop")
@ -76,7 +78,7 @@ readLoop:
} }
// Send a response to the client // Send a response to the client
func SendResponse(port *serial.Port, response *datastructures.SignerResponse) error { func SendResponse(port *io.ReadWriteCloser, response *datastructures.SignerResponse) error {
if _, err := (*port).Write([]byte{0x02}); err != nil { if _, err := (*port).Write([]byte{0x02}); err != nil {
return err return err
} }
@ -136,7 +138,7 @@ func handleNulAction(command datastructures.SignerRequest) (*datastructures.Sign
} }
// Receive a request and generate a request data structure // Receive a request and generate a request data structure
func Receive(port *serial.Port, commandChan *chan datastructures.SignerRequest, errorChan *chan error) { func Receive(port *io.ReadWriteCloser, commandChan *chan datastructures.SignerRequest, errorChan *chan error) {
header, err := shared.ReceiveBytes(port, 1, 20) header, err := shared.ReceiveBytes(port, 1, 20)
if err != nil { if err != nil {
*errorChan <- err *errorChan <- err