// Package protocol implements the requesting side of the signer protocol that
// is spoken over a serial connection: requests are taken from a request
// channel, transmitted together with a handshake, checksum and trailer, and
// the decoded responses are published on a response channel.
package protocol

import (
	"bytes"
	"fmt"
	"io"
	"sync"
	"time"

	"git.cacert.org/cacert-gosigner/datastructures"
	"git.cacert.org/cacert-gosigner/shared"
	log "github.com/sirupsen/logrus"
)

// SignerProtocolHandler processes signer requests until an error occurs or
// its request channel is closed.
type SignerProtocolHandler interface {
	io.Closer
	HandleSignerProtocol() error
}

type signerProtocolConfig struct {
	// BufferSize is passed to shared.ReceiveBytes as the read size while a
	// response is being received.
	BufferSize int
}

// NewSignerProtocolConfig returns a configuration with a BufferSize of 2048.
func NewSignerProtocolConfig() *signerProtocolConfig {
	return &signerProtocolConfig{BufferSize: 2048}
}

// SignerProtocolRequestChannel wraps the channel of pending signer requests
// so that it can be closed safely more than once.
type SignerProtocolRequestChannel struct {
	C      chan *datastructures.SignerRequest
	closed bool
	mutex  sync.Mutex
}

// NewSignerProtocolRequestChannel creates a request channel with a buffer
// capacity of one request.
func NewSignerProtocolRequestChannel() *SignerProtocolRequestChannel {
	return &SignerProtocolRequestChannel{C: make(chan *datastructures.SignerRequest, 1)}
}

// SafeClose closes C on the first call; later calls are no-ops.
func (rc *SignerProtocolRequestChannel) SafeClose() {
	rc.mutex.Lock()
	defer rc.mutex.Unlock()
	if !rc.closed {
		close(rc.C)
		rc.closed = true
	}
}

// IsClosed reports whether SafeClose has already closed C.
func (rc *SignerProtocolRequestChannel) IsClosed() bool {
	rc.mutex.Lock()
	defer rc.mutex.Unlock()
	return rc.closed
}

// protocolHandler implements SignerProtocolHandler on top of a serial
// connection.
type protocolHandler struct {
	requestChannel   *SignerProtocolRequestChannel
	responseChannel  *chan *datastructures.SignerResponse
	serialConnection *io.ReadWriteCloser
	config           *signerProtocolConfig
}

// UnExpectedAcknowledgeByte is returned when the peer answers with a byte
// other than shared.AckByte where an acknowledgement is expected.
type UnExpectedAcknowledgeByte struct {
	ResponseByte byte
}

// UnExpectedHandshakeByte is returned when the peer starts a response with a
// byte other than shared.HandshakeByte.
type UnExpectedHandshakeByte struct {
	ResponseByte byte
}

func (e UnExpectedHandshakeByte) Error() string {
	return fmt.Sprintf("unexpected handshake byte 0x%x instead of 0x%x", e.ResponseByte, shared.HandshakeByte)
}

func (e UnExpectedAcknowledgeByte) Error() string {
	return fmt.Sprintf("unexpected acknowledge byte 0x%x instead of 0x%x", e.ResponseByte, shared.AckByte)
}

// Close closes the response channel and the request channel. Closing a
// channel twice panics, so Close must be called at most once, and it must not
// race with HandleSignerProtocol publishing a response. The serial connection
// itself is not closed here.
func (ph *protocolHandler) Close() error {
	close(*ph.responseChannel)
	ph.requestChannel.SafeClose()
	return nil
}

// NewProtocolHandler creates a SignerProtocolHandler that reads requests from
// requests, exchanges them over serialConnection and publishes the decoded
// responses on response.
func NewProtocolHandler(
	requests *SignerProtocolRequestChannel,
	response *chan *datastructures.SignerResponse,
	serialConnection *io.ReadWriteCloser,
	config *signerProtocolConfig,
) SignerProtocolHandler {
	return &protocolHandler{
		requestChannel:   requests,
		responseChannel:  response,
		serialConnection: serialConnection,
		config:           config,
	}
}

// HandleSignerProtocol processes requests from the request channel one at a
// time. It returns nil once the request channel has been closed and drained,
// or the first error of any request/response exchange.
func (ph *protocolHandler) HandleSignerProtocol() error {
	// Ranging stops when the channel is closed. A plain receive would keep
	// yielding nil requests after close and dereference them.
	for request := range ph.requestChannel.C {
		log.Tracef("handle request %+v", request)
		if err := ph.handleRequest(request); err != nil {
			return err
		}
	}
	return nil
}

// handleRequest runs one complete exchange: handshake, request transmission,
// response handshake, response read/decode and publication of the response.
func (ph *protocolHandler) handleRequest(request *datastructures.SignerRequest) error {
	if err := ph.sendHandshake(); err != nil {
		if ackErr, ok := err.(UnExpectedAcknowledgeByte); ok {
			log.Errorf("unexpected acknowledge byte: 0x%x", ackErr.ResponseByte)
			// TODO drain input
		}
		return err
	}

	requestBytes := request.Serialize()
	if err := ph.sendRequest(&requestBytes); err != nil {
		return err
	}

	if err := ph.waitForResponseHandshake(); err != nil {
		return err
	}

	lengthBytes, responseBytes, checksum, err := ph.readResponse()
	if err != nil {
		return err
	}

	response, err := datastructures.SignerResponseFromData(*lengthBytes, *responseBytes, checksum)
	if err != nil {
		return err
	}

	*ph.responseChannel <- response
	return nil
}

// sendHandshake writes shared.HandshakeByte and expects exactly one
// shared.AckByte in reply; any other reply yields UnExpectedAcknowledgeByte.
// The reply is read with a plain Read, i.e. without a timeout.
func (ph *protocolHandler) sendHandshake() error {
	conn := *ph.serialConnection

	bytesWritten, err := conn.Write([]byte{shared.HandshakeByte})
	if err != nil {
		return err
	}
	log.Tracef("wrote %d bytes of handshake info", bytesWritten)

	data := make([]byte, 1)
	bytesRead, err := conn.Read(data)
	if err != nil {
		return err
	}
	log.Tracef("%d bytes read", bytesRead)

	if bytesRead != 1 || data[0] != shared.AckByte {
		log.Warnf("received invalid handshake byte 0x%x", data[0])
		return UnExpectedAcknowledgeByte{data[0]}
	}
	return nil
}

// sendRequest transmits the serialized request, its XOR checksum and
// shared.MagicTrailer, then waits up to 20 seconds for the peer's verdict:
// shared.AckByte ends the exchange, shared.ResendByte triggers a complete
// retransmission and any other byte is returned as UnExpectedAcknowledgeByte.
func (ph *protocolHandler) sendRequest(requestBytes *[]byte) error {
	conn := *ph.serialConnection
	// The request bytes do not change between retransmissions, so neither
	// does their checksum.
	checksum := datastructures.CalculateXorCheckSum([][]byte{*requestBytes})

	for {
		length, err := conn.Write(*requestBytes)
		if err != nil {
			return err
		}
		log.Tracef("wrote %d request bytes", length)

		if length, err = conn.Write([]byte{checksum}); err != nil {
			return err
		}
		log.Tracef("wrote %d checksum bytes", length)

		if length, err = conn.Write([]byte(shared.MagicTrailer)); err != nil {
			return err
		}
		log.Tracef("wrote %d trailer bytes", length)

		header, err := shared.ReceiveBytes(ph.serialConnection, 1, 20*time.Second)
		if err != nil {
			return err
		}
		// Guard against an empty read instead of panicking on header[0].
		if len(header) == 0 {
			return fmt.Errorf("no acknowledge byte received after sending request")
		}

		switch header[0] {
		case shared.AckByte:
			return nil
		case shared.ResendByte:
			// Peer asked for a retransmission: send the whole request again.
		default:
			return UnExpectedAcknowledgeByte{header[0]}
		}
	}
}

// waitForResponseHandshake waits up to 120 seconds for the peer's
// shared.HandshakeByte and acknowledges it with shared.AckByte.
func (ph *protocolHandler) waitForResponseHandshake() error {
	data, err := shared.ReceiveBytes(ph.serialConnection, 1, 120*time.Second)
	if err != nil {
		return err
	}
	// Guard against an empty read instead of panicking on data[0].
	if len(data) == 0 {
		return fmt.Errorf("no response handshake byte received")
	}
	if len(data) != 1 || data[0] != shared.HandshakeByte {
		log.Warnf("received invalid handshake byte 0x%x", data[0])
		return UnExpectedHandshakeByte{data[0]}
	}
	return shared.SendByte(ph.serialConnection, shared.AckByte)
}

// readResponse reads one response frame, validates it and acknowledges it
// with shared.AckByte. The frame layout is
//
//	3 length bytes | data | 1 checksum byte | shared.MagicTrailer
//
// where the checksum is the XOR checksum over the length and data bytes. It
// returns the length bytes, the data bytes and the received checksum. Every
// read waits up to 5 seconds.
func (ph *protocolHandler) readResponse() (*[]byte, *[]byte, byte, error) {
	const lengthSize = 3
	trailerSize := len(shared.MagicTrailer)

	var frame bytes.Buffer
	dataLength := -1

	for {
		chunk, err := shared.ReceiveBytes(ph.serialConnection, ph.config.BufferSize, 5*time.Second)
		if err != nil {
			return nil, nil, 0, err
		}
		frame.Write(chunk)

		if dataLength < 0 {
			if frame.Len() < lengthSize {
				continue
			}
			dataLength = datastructures.Decode24BitLength(frame.Bytes()[:lengthSize])
			log.Tracef("expect to read %d data bytes", dataLength)
		}

		checksumIndex := lengthSize + dataLength
		frameSize := checksumIndex + 1 + trailerSize
		if frame.Len() < frameSize {
			continue
		}
		// Surplus bytes can never form a valid frame; fail now instead of
		// reading until the receive times out.
		if frame.Len() > frameSize {
			return nil, nil, 0, fmt.Errorf("received %d bytes, expected a response frame of %d bytes", frame.Len(), frameSize)
		}

		allBytes := frame.Bytes()
		trailer := string(allBytes[checksumIndex+1:])
		if trailer != shared.MagicTrailer {
			return nil, nil, 0, fmt.Errorf("invalid trailer bytes: %v", trailer)
		}

		// Full slice expressions cap capacity so appending to one returned
		// slice cannot overwrite the bytes of the other.
		lengthBytes := allBytes[0:lengthSize:lengthSize]
		dataBytes := allBytes[lengthSize:checksumIndex:checksumIndex]
		checkSum := allBytes[checksumIndex]

		calculatedChecksum := datastructures.CalculateXorCheckSum([][]byte{lengthBytes, dataBytes})
		if calculatedChecksum != checkSum {
			return nil, nil, 0, fmt.Errorf("calculated checksum mismatch 0x%x vs 0x%x", calculatedChecksum, checkSum)
		}

		if err := shared.SendByte(ph.serialConnection, shared.AckByte); err != nil {
			return nil, nil, 0, err
		}
		return &lengthBytes, &dataBytes, checkSum, nil
	}
}