From 9924771531ef7538c1aa366632c3ecb02be982da Mon Sep 17 00:00:00 2001 From: Jan Dittberner Date: Sun, 19 Apr 2020 22:29:58 +0200 Subject: [PATCH] Refactor client I/O into protocol package --- client/io.go | 98 --------------- client/main.go | 150 ++++++++++------------- client/processing/process.go | 20 +++ client/protocol/protocol.go | 229 +++++++++++++++++++++++++++++++++++ shared/io.go | 30 +++-- shared/shared.go | 2 +- signer/main.go | 10 +- 7 files changed, 339 insertions(+), 200 deletions(-) delete mode 100644 client/io.go create mode 100644 client/processing/process.go create mode 100644 client/protocol/protocol.go diff --git a/client/io.go b/client/io.go deleted file mode 100644 index 3c3a2b2..0000000 --- a/client/io.go +++ /dev/null @@ -1,98 +0,0 @@ -package main - -import ( - "errors" - "fmt" - - log "github.com/sirupsen/logrus" - "go.bug.st/serial" - - "git.cacert.org/cacert-gosigner/datastructures" - "git.cacert.org/cacert-gosigner/shared" -) - -func sendHandShake(port serial.Port) error { - log.Debug("Shaking hands ...") - if length, err := port.Write([]byte{shared.HandshakeByte}); err != nil { - return fmt.Errorf("could not write handshake byte: %v", err) - } else { - log.Tracef("wrote %d handshake bytes", length) - } - handShakeResponse := make([]byte, 1) - if length, err := port.Read(handShakeResponse); err != nil { - return fmt.Errorf("failed to read handshake response: %v", err) - } else { - log.Tracef("read %d bytes", length) - } - if handShakeResponse[0] != shared.AckByte { - return fmt.Errorf("invalid handshake response expected 0x10 received %x", handShakeResponse[0]) - } - log.Debug("Handshake successful") - return nil -} - -func receiveResponse(port *serial.Port, responseChan *chan datastructures.SignerResponse, errorChan *chan error) { - header, err := shared.ReceiveBytes(port, 1, 20) - if err != nil { - *errorChan <- err - return - } - if header[0] != shared.HandshakeByte { - *errorChan <- fmt.Errorf("unexpected byte 0x%x expected 0x%x", header, shared.HandshakeByte) - } - log.Tracef("received handshake byte") - - if _, err := (*port).Write([]byte{shared.AckByte}); err != nil { - *errorChan <- errors.New("could not write ACK") - return - } - log.Tracef("sent ACK byte") - - lengthBytes, err := shared.ReceiveBytes(port, 3, 2) - if err != nil { - *errorChan <- err - return - } - blockLength := datastructures.Decode24BitLength(lengthBytes) - log.Tracef("received block length %d", blockLength) - - blockData, err := shared.ReceiveBytes(port, blockLength, 5) - if err != nil { - *errorChan <- err - return - } - log.Tracef("received bytes %v", blockData) - - checkSum, err := shared.ReceiveBytes(port, 1, 2) - if err != nil { - *errorChan <- err - return - } - log.Tracef("received checksum 0x%x", checkSum[0]) - - trailer, err := shared.ReceiveBytes(port, len(shared.MagicTrailer), 2) - if err != nil { - *errorChan <- err - return - } - if string(trailer) != shared.MagicTrailer { - *errorChan <- errors.New("expected trailer bytes not found") - return - } - log.Tracef("received valid trailer bytes") - - if _, err := (*port).Write([]byte{shared.AckByte}); err != nil { - *errorChan <- fmt.Errorf("could not write ACK byte: %v", err) - return - } - log.Tracef("sent ACK byte") - - signerResponse, err := datastructures.SignerResponseFromData(lengthBytes, blockData, checkSum[0]) - if err != nil { - *errorChan <- err - return - } - log.Infof("received response of type %s", signerResponse.Action) - - *responseChan <- *signerResponse -} diff --git a/client/main.go b/client/main.go index 08ac5e2..bb3c61d 100644 --- a/client/main.go +++ b/client/main.go @@ -2,14 +2,19 @@ package main import ( "flag" - "fmt" + "io" + "os" + "os/signal" + "sync" + "syscall" "time" log "github.com/sirupsen/logrus" "go.bug.st/serial" + "git.cacert.org/cacert-gosigner/client/processing" + "git.cacert.org/cacert-gosigner/client/protocol" "git.cacert.org/cacert-gosigner/datastructures" - "git.cacert.org/cacert-gosigner/shared" ) func main() { @@ -27,7 +32,7 @@ func main() { } serialConfig = fillSerialMode(clientConfig) if clientConfig.Debug { - log.SetLevel(log.DebugLevel) + log.SetLevel(log.TraceLevel) } log.Infof("connecting to %s using %+v", clientConfig.SerialAddress, serialConfig) @@ -36,25 +41,64 @@ func main() { log.Fatal(err) } log.Debug("serial port connected") - defer func() { - err := port.Close() - if err != nil { - log.Fatal(err) + + requestChannel := protocol.NewSignerProtocolRequestChannel() + responseChannel := make(chan *datastructures.SignerResponse, 1) + + readWriteCloser := (io.ReadWriteCloser)(port) + protocolHandler := protocol.NewProtocolHandler(requestChannel, &responseChannel, &readWriteCloser) + + cancelChannel := make(chan os.Signal, 1) + signal.Notify(cancelChannel, syscall.SIGTERM, syscall.SIGINT) + + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + if err := protocolHandler.HandleSignerProtocol(); err != nil { + log.Errorf("terminating because of %v", err) + close(cancelChannel) } - log.Debug("serial port closed") + wg.Done() }() - errorChannel := make(chan error, 1) - responseChannel := make(chan datastructures.SignerResponse, 1) + go func() { + runMainLoop(requestChannel, &responseChannel) + wg.Done() + }() + + sig := <-cancelChannel + if sig != nil { + log.Infof("caught %+v", sig) + } + if err := protocolHandler.Close(); err != nil { + log.Error(err) + } else { + log.Infof("protocol handler closed") + } + if err := port.Close(); err != nil { + log.Error(err) + } else { + log.Infof("serial port closed") + } + wg.Wait() +} + +func runMainLoop(requestChannel *protocol.SignerProtocolRequestChannel, responseChannel *chan *datastructures.SignerResponse) { crlCheck := 0 log.Debug("starting main loop") + go func() { + for response := range *responseChannel { + if err := processing.Process(response); err != nil { + log.Error(err) + } + } + log.Trace("processing goroutine terminated") + }() + for { - requestChannel := make(chan datastructures.SignerRequest, 1) - - go HandleRequests(&port, &responseChannel, &errorChannel, &requestChannel) - log.Debug("handling GPG database ...") // HandleGPG(&requestChannel) log.Debug("issuing certificates ...") @@ -68,83 +112,13 @@ func main() { // RefreshCRLs(&requestChannel) } - log.Debug("send NUL request to keep connection open") - requestChannel <- *datastructures.NewNulRequest() - - select { - case response := <-responseChannel: - if err := Process(response); err != nil { - log.Error(err) - } - case err := <-errorChannel: - log.Error(err) + if requestChannel.IsClosed() { + return } + log.Debug("send NUL request to keep connection open") + requestChannel.C <- datastructures.NewNulRequest() log.Debug("sleep for 2.7 seconds") time.Sleep(2700 * time.Millisecond) } } - -func Process(response datastructures.SignerResponse) (err error) { - log.Infof("process response of type %s", response.Action) - log.Tracef("process response %v", response) - - switch response.Action { - case datastructures.ActionNul: - log.Trace("received response for NUL request") - return - default: - return fmt.Errorf("unsupported action in response 0x%x", response.Action) - } -} - -func HandleRequests(port *serial.Port, responseChan *chan datastructures.SignerResponse, errorChan *chan error, requestChan *chan datastructures.SignerRequest) { - for { - select { - case request := <-*requestChan: - SendRequest(port, responseChan, errorChan, &request) - } - } -} - -func SendRequest(port *serial.Port, responseChan *chan datastructures.SignerResponse, errorChan *chan error, request *datastructures.SignerRequest) { - log.Tracef("send request %v to serial port %v", *request, *port) - if err := sendHandShake(*port); err != nil { - *errorChan <- err - return - } - - requestBytes := request.Serialize() - if length, err := (*port).Write(requestBytes); err != nil { - *errorChan <- err - return - } else { - log.Tracef("wrote %d request bytes", length) - } - - if length, err := (*port).Write([]byte{datastructures.CalculateXorCheckSum([][]byte{requestBytes})}); err != nil { - *errorChan <- err - return - } else { - log.Tracef("wrote %d checksum bytes", length) - } - - if length, err := (*port).Write([]byte(shared.MagicTrailer)); err != nil { - *errorChan <- err - return - } else { - log.Tracef("wrote %d trailer bytes", length) - } - - header, err := shared.ReceiveBytes(port, 1, 20) - if err != nil { - *errorChan <- err - return - } - if header[0] != shared.AckByte { - *errorChan <- fmt.Errorf("unexpected byte 0x%x expected 0x%x", header, shared.AckByte) - return - } - - receiveResponse(port, responseChan, errorChan) -} diff --git a/client/processing/process.go b/client/processing/process.go new file mode 100644 index 0000000..6263920 --- /dev/null +++ b/client/processing/process.go @@ -0,0 +1,20 @@ +package processing + +import ( + "fmt" + "git.cacert.org/cacert-gosigner/datastructures" + "github.com/sirupsen/logrus" +) + +func Process(response *datastructures.SignerResponse) (err error) { + logrus.Infof("process response of type %s", response.Action) + logrus.Tracef("process response %v", response) + + switch response.Action { + case datastructures.ActionNul: + logrus.Trace("received response for NUL request") + return + default: + return fmt.Errorf("unsupported action in response 0x%x", response.Action) + } +} diff --git a/client/protocol/protocol.go b/client/protocol/protocol.go new file mode 100644 index 0000000..8be88d3 --- /dev/null +++ b/client/protocol/protocol.go @@ -0,0 +1,229 @@ +package protocol + +import ( + "bytes" + "fmt" + "git.cacert.org/cacert-gosigner/datastructures" + "git.cacert.org/cacert-gosigner/shared" + log "github.com/sirupsen/logrus" + "io" + "sync" + "time" +) + +type SignerProtocolHandler interface { + io.Closer + HandleSignerProtocol() error +} + +type SignerProtocolRequestChannel struct { + C chan *datastructures.SignerRequest + closed bool + mutex sync.Mutex +} + +func NewSignerProtocolRequestChannel() *SignerProtocolRequestChannel { + return &SignerProtocolRequestChannel{C: make(chan *datastructures.SignerRequest, 1)} +} + +func (rc *SignerProtocolRequestChannel) SafeClose() { + rc.mutex.Lock() + defer rc.mutex.Unlock() + if !rc.closed { + close(rc.C) + rc.closed = true + } +} + +func (rc *SignerProtocolRequestChannel) IsClosed() bool { + rc.mutex.Lock() + defer rc.mutex.Unlock() + return rc.closed +} + +type protocolHandler struct { + requestChannel *SignerProtocolRequestChannel + responseChannel *chan *datastructures.SignerResponse + serialConnection *io.ReadWriteCloser +} + +type UnExpectedAcknowledgeByte struct { + ResponseByte byte +} +type UnExpectedHandshakeByte struct { + ResponseByte byte +} + +func (e UnExpectedHandshakeByte) Error() string { + return fmt.Sprintf("unexpected handshake byte 0x%x instead of 0x%x", e.ResponseByte, shared.HandshakeByte) +} +func (e UnExpectedAcknowledgeByte) Error() string { + return fmt.Sprintf("unexpected acknowledge byte 0x%x instead of 0x%x", e.ResponseByte, shared.AckByte) +} + +func (ph *protocolHandler) Close() error { + close(*ph.responseChannel) + ph.requestChannel.SafeClose() + return nil +} + +func NewProtocolHandler(requests *SignerProtocolRequestChannel, response *chan *datastructures.SignerResponse, serialConnection *io.ReadWriteCloser) SignerProtocolHandler { + return &protocolHandler{ + requestChannel: requests, + responseChannel: response, + serialConnection: serialConnection, + } +} + +func (ph *protocolHandler) HandleSignerProtocol() error { + for { + select { + case request := <-ph.requestChannel.C: + log.Debugf("handle request %+v", request) + var err error + var lengthBytes, responseBytes *[]byte + var checksum byte + + if err = ph.sendHandshake(); err != nil { + switch err.(type) { + case UnExpectedAcknowledgeByte: + log.Errorf("unexpected handshake byte: 0x%x", err.(UnExpectedAcknowledgeByte).ResponseByte) + // TODO drain input + } + return err + } + requestBytes := request.Serialize() + if err = ph.sendRequest(&requestBytes); err != nil { + return err + } + if err = ph.waitForResponseHandshake(); err != nil { + return err + } + if lengthBytes, responseBytes, checksum, err = ph.readResponse(); err != nil { + return err + } + response, err := datastructures.SignerResponseFromData(*lengthBytes, *responseBytes, checksum) + if err != nil { + return err + } + *ph.responseChannel <- response + } + } +} + +func (ph *protocolHandler) sendHandshake() (err error) { + var bytesWritten, bytesRead int + if bytesWritten, err = (*ph.serialConnection).Write([]byte{shared.HandshakeByte}); err != nil { + return + } else { + log.Tracef("wrote %d bytes of handshake info", bytesWritten) + } + data := make([]byte, 1) + if bytesRead, err = (*ph.serialConnection).Read(data); err != nil { + return + } + log.Tracef("%d bytes read", bytesRead) + if bytesRead != 1 || data[0] != shared.AckByte { + log.Warnf("received invalid handshake byte 0x%x", data[0]) + return UnExpectedAcknowledgeByte{data[0]} + } + return +} + +func (ph *protocolHandler) sendRequest(requestBytes *[]byte) error { + for { + if length, err := (*ph.serialConnection).Write(*requestBytes); err != nil { + return err + } else { + log.Tracef("wrote %d request bytes", length) + } + + if length, err := (*ph.serialConnection).Write([]byte{ + datastructures.CalculateXorCheckSum([][]byte{*requestBytes}), + }); err != nil { + return err + } else { + log.Tracef("wrote %d checksum bytes", length) + } + + if length, err := (*ph.serialConnection).Write([]byte(shared.MagicTrailer)); err != nil { + return err + } else { + log.Tracef("wrote %d trailer bytes", length) + } + header, err := shared.ReceiveBytes(ph.serialConnection, 1, 20*time.Second) + if err != nil { + return err + } + switch header[0] { + case shared.AckByte: + return nil + case shared.ResendByte: + default: + return UnExpectedAcknowledgeByte{header[0]} + } + } +} + +func (ph *protocolHandler) waitForResponseHandshake() (err error) { + data, err := shared.ReceiveBytes(ph.serialConnection, 1, 120*time.Second) + if err != nil { + return err + } + if len(data) != 1 || data[0] != shared.HandshakeByte { + log.Warnf("received invalid handshake byte 0x%x", data[0]) + return UnExpectedHandshakeByte{data[0]} + } + if err = shared.SendByte(ph.serialConnection, shared.AckByte); err != nil { + return + } + + return +} + +func (ph *protocolHandler) readResponse() (*[]byte, *[]byte, byte, error) { + dataLength := -1 + var lengthBuffer = bytes.NewBuffer(make([]byte, 0)) + var byteBuffer = bytes.NewBuffer(make([]byte, 0)) + + for { + readBuffer, err := shared.ReceiveBytes(ph.serialConnection, 100, 5*time.Second) + if err != nil { + return nil, nil, 0, err + } + bytesRead := len(readBuffer) + if bytesRead > 0 { + byteBuffer.Write(readBuffer[0:bytesRead]) + for _, b := range readBuffer { + if lengthBuffer.Len() < 3 { + lengthBuffer.WriteByte(b) + } else { + break + } + } + } + if dataLength < 0 && lengthBuffer.Len() == 3 { + dataLength = datastructures.Decode24BitLength(lengthBuffer.Bytes()) + log.Tracef("expect to read %d data bytes", dataLength) + } + if dataLength == byteBuffer.Len()-4-len(shared.MagicTrailer) { + allBytes := byteBuffer.Bytes() + trailer := string(allBytes[4+dataLength:]) + if trailer != shared.MagicTrailer { + return nil, nil, 0, fmt.Errorf("invalid trailer bytes: %v", trailer) + } + lengthBytes := allBytes[0:3] + dataBytes := allBytes[3 : 3+dataLength] + checkSum := allBytes[3+dataLength] + calculatedChecksum := datastructures.CalculateXorCheckSum([][]byte{lengthBytes, dataBytes}) + if calculatedChecksum != checkSum { + return nil, nil, 0, fmt.Errorf("calculated checksum mismatch 0x%x vs 0x%x", calculatedChecksum, checkSum) + } + if err := shared.SendByte(ph.serialConnection, shared.AckByte); err != nil { + return nil, nil, 0, err + } + + return &lengthBytes, &dataBytes, checkSum, nil + } + } +} diff --git a/shared/io.go b/shared/io.go index 6d463fe..59e6200 100644 --- a/shared/io.go +++ b/shared/io.go @@ -1,31 +1,43 @@ package shared import ( - "errors" - "go.bug.st/serial" + "fmt" + log "github.com/sirupsen/logrus" "io" "time" ) -// receive the requested number of bytes from serial port and stop after the given timeout in seconds -func ReceiveBytes(port *serial.Port, count int, timeout time.Duration) ([]byte, error) { - timeoutCh := time.After(timeout * time.Second) +// receive at maximum the requested number of bytes from serial port and stop after the given timeout +func ReceiveBytes(port *io.ReadWriteCloser, count int, timeout time.Duration) ([]byte, error) { readCh := make(chan []byte, 1) errCh := make(chan error, 1) go func() { data := make([]byte, count) - if _, err := io.ReadAtLeast(*port, data, count); err != nil { + if readBytes, err := (*port).Read(data); err != nil { errCh <- err + } else if readBytes > 0 { + log.Tracef("%d bytes read", readBytes) + readCh <- data[0:readBytes] } else { - readCh <- data + readCh <- make([]byte, 0) } + return }() select { - case <-timeoutCh: - return nil, errors.New("timeout") + case <-time.After(timeout): + return nil, fmt.Errorf("timeout passed %v: %v", timeout) case err := <-errCh: return nil, err case data := <-readCh: return data, nil } } + +func SendByte(port *io.ReadWriteCloser, data byte) error { + if bytesWritten, err := (*port).Write([]byte{data}); err != nil { + return err + } else { + log.Tracef("wrote %d bytes of handshake info", bytesWritten) + } + return nil +} diff --git a/shared/shared.go b/shared/shared.go index 1677f1c..93d22b7 100644 --- a/shared/shared.go +++ b/shared/shared.go @@ -4,4 +4,4 @@ const MagicTrailer = "rie4Ech7" const HandshakeByte = 0x02 const AckByte = 0x10 -const NackByte = 0x11 +const ResendByte = 0x11 diff --git a/signer/main.go b/signer/main.go index 1285c71..8491f69 100644 --- a/signer/main.go +++ b/signer/main.go @@ -4,6 +4,7 @@ import ( "errors" "flag" "fmt" + "io" "time" log "github.com/sirupsen/logrus" @@ -53,7 +54,8 @@ readLoop: commandChan := make(chan datastructures.SignerRequest, 1) errChan := make(chan error, 1) - go Receive(&port, &commandChan, &errChan) + readWriteCloser := (io.ReadWriteCloser)(port) + go Receive(&readWriteCloser, &commandChan, &errChan) select { case command := <-commandChan: @@ -61,7 +63,7 @@ readLoop: if err != nil { log.Printf("ERROR %v\n", err) } else { - _ = SendResponse(&port, response) + _ = SendResponse(&readWriteCloser, response) } case <-timeout: log.Println("timeout in main loop") @@ -76,7 +78,7 @@ readLoop: } // Send a response to the client -func SendResponse(port *serial.Port, response *datastructures.SignerResponse) error { +func SendResponse(port *io.ReadWriteCloser, response *datastructures.SignerResponse) error { if _, err := (*port).Write([]byte{0x02}); err != nil { return err } @@ -136,7 +138,7 @@ func handleNulAction(command datastructures.SignerRequest) (*datastructures.Sign } // Receive a request and generate a request data structure -func Receive(port *serial.Port, commandChan *chan datastructures.SignerRequest, errorChan *chan error) { +func Receive(port *io.ReadWriteCloser, commandChan *chan datastructures.SignerRequest, errorChan *chan error) { header, err := shared.ReceiveBytes(port, 1, 20) if err != nil { *errorChan <- err