// Package protocol implements the request side of the serial signer protocol:
// it sends signer requests over a serial connection and reads the matching
// responses.
package protocol

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"sync"
	"time"

	log "github.com/sirupsen/logrus"

	"git.cacert.org/cacert-gosigner/datastructures"
	"git.cacert.org/cacert-gosigner/shared"
)

// Protocol timeouts are given in seconds and are multiplied by time.Second
// where they are used. bufferSize is the default read size for response data.
const (
	waitForHeader    = 20  // wait for ACK/RESEND after sending a request
	waitForData      = 5   // wait for each chunk of response data
	waitForHandShake = 120 // wait for handshake related single bytes
	bufferSize       = 2048
)

// SignerProtocolHandler processes signer requests and can be shut down via
// Close.
type SignerProtocolHandler interface {
	io.Closer
	HandleSignerProtocol() error
}

// SignerProtocolConfig holds tunables for a protocol handler.
type SignerProtocolConfig struct {
	// BufferSize is the maximum chunk size requested from
	// shared.ReceiveBytes while reading response data.
	BufferSize int
}

// NewSignerProtocolConfig returns a configuration using the default buffer
// size.
func NewSignerProtocolConfig() *SignerProtocolConfig {
	return &SignerProtocolConfig{BufferSize: bufferSize}
}

// SignerProtocolRequestChannel wraps the channel that feeds requests to a
// protocol handler together with a mutex-guarded "closed" flag so that the
// channel can be closed safely more than once.
type SignerProtocolRequestChannel struct {
	C      chan *datastructures.SignerRequest
	closed bool
	mutex  sync.Mutex
}

// NewSignerProtocolRequestChannel creates a request channel with a buffer of
// one request.
func NewSignerProtocolRequestChannel() *SignerProtocolRequestChannel {
	return &SignerProtocolRequestChannel{C: make(chan *datastructures.SignerRequest, 1)}
}

// SafeClose closes C the first time it is called; later calls are no-ops.
func (rc *SignerProtocolRequestChannel) SafeClose() {
	rc.mutex.Lock()
	defer rc.mutex.Unlock()

	if !rc.closed {
		close(rc.C)
		rc.closed = true
	}
}

// IsClosed reports whether SafeClose has already closed C.
func (rc *SignerProtocolRequestChannel) IsClosed() bool {
	rc.mutex.Lock()
	defer rc.mutex.Unlock()

	return rc.closed
}

// protocolHandler is the serial-connection based SignerProtocolHandler.
type protocolHandler struct {
	requestChannel   *SignerProtocolRequestChannel
	responseChannel  *chan *datastructures.SignerResponse
	serialConnection io.ReadWriteCloser
	config           *SignerProtocolConfig
	// closeOnce guards closing responseChannel so that Close is idempotent.
	closeOnce sync.Once
}

var _ SignerProtocolHandler = (*protocolHandler)(nil)

// UnExpectedAcknowledgeByte is returned when the peer answers with a byte
// other than shared.AckByte (after a request: other than AckByte or
// ResendByte).
type UnExpectedAcknowledgeByte struct {
	ResponseByte byte
}

// UnExpectedHandshakeByte is returned when the peer starts its response with
// a byte other than shared.HandshakeByte.
type UnExpectedHandshakeByte struct {
	ResponseByte byte
}

func (e UnExpectedHandshakeByte) Error() string {
	return fmt.Sprintf("unexpected handshake byte 0x%x instead of 0x%x", e.ResponseByte, shared.HandshakeByte)
}

func (e UnExpectedAcknowledgeByte) Error() string {
	return fmt.Sprintf("unexpected acknowledge byte 0x%x instead of 0x%x", e.ResponseByte, shared.AckByte)
}

// Close closes the response channel (only once, so repeated calls do not
// panic) and the request channel. It always returns nil.
//
// NOTE(review): closing the response channel while HandleSignerProtocol is
// still sending on it would panic; callers presumably stop the handler first —
// confirm.
func (ph *protocolHandler) Close() error {
	ph.closeOnce.Do(func() {
		close(*ph.responseChannel)
	})
	ph.requestChannel.SafeClose()

	return nil
}

// NewProtocolHandler creates a handler that reads requests from requests,
// talks to the signer over serialConnection and publishes decoded responses
// on *response.
func NewProtocolHandler(
	requests *SignerProtocolRequestChannel,
	response *chan *datastructures.SignerResponse,
	serialConnection io.ReadWriteCloser,
	config *SignerProtocolConfig,
) SignerProtocolHandler {
	return &protocolHandler{
		requestChannel:   requests,
		responseChannel:  response,
		serialConnection: serialConnection,
		config:           config,
	}
}

// HandleSignerProtocol processes requests one at a time until the request
// channel is closed. For each request it performs the handshake, sends the
// request (resending when asked to), waits for the response handshake, reads
// and verifies the response and publishes it on the response channel.
//
// It returns the first error encountered, or nil once the request channel has
// been closed.
func (ph *protocolHandler) HandleSignerProtocol() error {
	for request := range ph.requestChannel.C {
		log.Tracef("handle request %+v", request)

		if err := ph.sendHandshake(); err != nil {
			// sendHandshake returns UnExpectedAcknowledgeByte by value, so the
			// target must be the value type for errors.As to match.
			var e UnExpectedAcknowledgeByte
			if errors.As(err, &e) {
				log.Errorf("unexpected handshake byte: 0x%x", e.ResponseByte)
			}

			return err
		}

		requestBytes := request.Serialize()
		if err := ph.sendRequest(&requestBytes); err != nil {
			return err
		}

		if err := ph.waitForResponseHandshake(); err != nil {
			return err
		}

		lengthBytes, responseBytes, checksum, err := ph.readResponse()
		if err != nil {
			return err
		}

		response, err := datastructures.SignerResponseFromData(*lengthBytes, *responseBytes, checksum)
		if err != nil {
			return fmt.Errorf("could not create response: %w", err)
		}

		*ph.responseChannel <- response
	}

	return nil
}

// receiveByte reads a single byte from the serial connection, waiting at most
// timeout. It fails if shared.ReceiveBytes fails or does not deliver exactly
// one byte.
func (ph *protocolHandler) receiveByte(timeout time.Duration) (byte, error) {
	data, err := shared.ReceiveBytes(ph.serialConnection, 1, timeout)
	if err != nil {
		return 0, err
	}

	if len(data) != 1 {
		return 0, fmt.Errorf("expected 1 byte, received %d", len(data))
	}

	return data[0], nil
}

// sendHandshake writes shared.HandshakeByte and expects shared.AckByte in
// reply. Any other reply byte yields an UnExpectedAcknowledgeByte error.
func (ph *protocolHandler) sendHandshake() error {
	bytesWritten, err := ph.serialConnection.Write([]byte{shared.HandshakeByte})
	if err != nil {
		return fmt.Errorf("could not send handshake byte: %w", err)
	}

	log.Tracef("wrote %d bytes of handshake info", bytesWritten)

	// Read exactly one reply byte with a bounded wait; reading into an empty
	// slice would never yield data.
	reply, err := ph.receiveByte(waitForHandShake * time.Second)
	if err != nil {
		return fmt.Errorf("could not receive ACK byte: %w", err)
	}

	log.Tracef("received handshake reply 0x%x", reply)

	if reply != shared.AckByte {
		log.Warnf("received invalid handshake byte 0x%x", reply)

		return UnExpectedAcknowledgeByte{reply}
	}

	return nil
}

// sendRequest writes the request bytes, their XOR checksum and
// shared.MagicTrailer, then waits for the peer's verdict: AckByte ends the
// exchange, ResendByte transmits the whole frame again, anything else yields
// an UnExpectedAcknowledgeByte error.
//
// NOTE(review): resends are not bounded; a peer that keeps answering
// ResendByte keeps this loop going.
func (ph *protocolHandler) sendRequest(requestBytes *[]byte) error {
	// The checksum depends only on the request, so compute it once.
	checksum := datastructures.CalculateXorCheckSum([][]byte{*requestBytes})

	for {
		n, err := ph.serialConnection.Write(*requestBytes)
		if err != nil {
			return fmt.Errorf("could not send request bytes: %w", err)
		}

		log.Tracef("wrote %d request bytes", n)

		n, err = ph.serialConnection.Write([]byte{checksum})
		if err != nil {
			return fmt.Errorf("could not send checksum byte: %w", err)
		}

		log.Tracef("wrote %d checksum bytes", n)

		n, err = ph.serialConnection.Write([]byte(shared.MagicTrailer))
		if err != nil {
			return fmt.Errorf("could not send trailer bytes: %w", err)
		}

		log.Tracef("wrote %d trailer bytes", n)

		header, err := ph.receiveByte(waitForHeader * time.Second)
		if err != nil {
			return fmt.Errorf("could not read header bytes: %w", err)
		}

		switch header {
		case shared.AckByte:
			return nil
		case shared.ResendByte:
			log.Tracef("peer requested resend")

			continue
		default:
			return UnExpectedAcknowledgeByte{header}
		}
	}
}

// waitForResponseHandshake waits for the peer's shared.HandshakeByte and
// answers it with shared.AckByte. Any other byte yields an
// UnExpectedHandshakeByte error.
func (ph *protocolHandler) waitForResponseHandshake() error {
	handshake, err := ph.receiveByte(waitForHandShake * time.Second)
	if err != nil {
		return fmt.Errorf("could not receive handshake byte: %w", err)
	}

	if handshake != shared.HandshakeByte {
		log.Warnf("received invalid handshake byte 0x%x", handshake)

		return UnExpectedHandshakeByte{handshake}
	}

	if err = shared.SendBytes(ph.serialConnection, []byte{shared.AckByte}); err != nil {
		return fmt.Errorf("could not send ACK: %w", err)
	}

	return nil
}

// readResponse accumulates a response frame of the form
//
//	length (shared.LengthFieldSize) | data (length) | checksum | shared.MagicTrailer
//
// verifies trailer and XOR checksum (computed over length and data bytes),
// acknowledges the frame with shared.AckByte and returns the length bytes,
// the data bytes and the received checksum.
//
// Receiving more bytes than the frame announces is reported as an error
// instead of waiting indefinitely for a size that can no longer match.
func (ph *protocolHandler) readResponse() (*[]byte, *[]byte, byte, error) {
	var received bytes.Buffer

	dataLength := -1

	for {
		chunk, err := shared.ReceiveBytes(ph.serialConnection, ph.config.BufferSize, waitForData*time.Second)
		if err != nil {
			return nil, nil, 0, fmt.Errorf("could not receive response: %w", err)
		}

		received.Write(chunk)

		if dataLength < 0 {
			if received.Len() < shared.LengthFieldSize {
				continue
			}

			dataLength = datastructures.Decode24BitLength(received.Bytes()[:shared.LengthFieldSize])
			log.Tracef("expect to read %d data bytes", dataLength)
		}

		checksumOffset := shared.LengthFieldSize + dataLength
		trailerOffset := checksumOffset + shared.CheckSumFieldSize
		expected := trailerOffset + len(shared.MagicTrailer)

		if received.Len() < expected {
			continue
		}

		if received.Len() > expected {
			return nil, nil, 0, fmt.Errorf("received %d response bytes, expected %d", received.Len(), expected)
		}

		allBytes := received.Bytes()

		trailer := string(allBytes[trailerOffset:])
		if trailer != shared.MagicTrailer {
			return nil, nil, 0, fmt.Errorf("invalid trailer bytes: %v", trailer)
		}

		lengthBytes := allBytes[:shared.LengthFieldSize]
		dataBytes := allBytes[shared.LengthFieldSize:checksumOffset]
		checkSum := allBytes[checksumOffset]

		calculatedChecksum := datastructures.CalculateXorCheckSum([][]byte{lengthBytes, dataBytes})
		if calculatedChecksum != checkSum {
			return nil, nil, 0, fmt.Errorf("calculated checksum mismatch 0x%x vs 0x%x", calculatedChecksum, checkSum)
		}

		if err := shared.SendBytes(ph.serialConnection, []byte{shared.AckByte}); err != nil {
			return nil, nil, 0, fmt.Errorf("could not send ACK byte: %w", err)
		}

		return &lengthBytes, &dataBytes, checkSum, nil
	}
}