cacert-gosigner/client/protocol/protocol.go

304 lines
7.3 KiB
Go

package protocol
import (
"bytes"
"errors"
"fmt"
"io"
"sync"
"time"
log "github.com/sirupsen/logrus"
"git.cacert.org/cacert-gosigner/datastructures"
"git.cacert.org/cacert-gosigner/shared"
)
// Timeouts, in seconds, for waiting on data from the signer. The call sites
// multiply them by time.Second.
const (
	// waitForHandShake bounds the wait for the signer's response handshake
	// after a request has been accepted.
	waitForHandShake = 120
	// waitForHeader bounds the wait for the reply byte after a request frame
	// has been sent.
	waitForHeader = 20
	// waitForData bounds each receive call while reading a response frame.
	waitForData = 5
)

// bufferSize is the default value of SignerProtocolConfig.BufferSize.
const bufferSize = 2048
// SignerProtocolHandler is the client-side driver of the signer protocol as
// returned by NewProtocolHandler.
type SignerProtocolHandler interface {
	// HandleSignerProtocol runs the request/response exchange with the signer
	// and returns the error that stopped it, if any.
	HandleSignerProtocol() error

	// Close (from io.Closer) shuts the handler down.
	io.Closer
}
// SignerProtocolConfig holds tunable settings of a protocol handler.
type SignerProtocolConfig struct {
	// BufferSize is the byte count requested from each receive call while a
	// response frame is read. NewSignerProtocolConfig sets it to bufferSize.
	BufferSize int
}
// NewSignerProtocolConfig returns a configuration populated with the package
// defaults, i.e. BufferSize set to bufferSize.
func NewSignerProtocolConfig() *SignerProtocolConfig {
	cfg := new(SignerProtocolConfig)
	cfg.BufferSize = bufferSize
	return cfg
}
// SignerProtocolRequestChannel wraps the channel on which signer requests are
// delivered to a protocol handler, together with the state needed to close it
// at most once (see SafeClose).
type SignerProtocolRequestChannel struct {
	// C carries the requests; it is ranged over by the protocol handler.
	C chan *datastructures.SignerRequest
	// closed records whether SafeClose has already closed C.
	closed bool
	// mutex guards closed and serialises SafeClose/IsClosed.
	mutex sync.Mutex
}
// NewSignerProtocolRequestChannel creates an open request channel whose
// underlying Go channel has a buffer capacity of one request.
func NewSignerProtocolRequestChannel() *SignerProtocolRequestChannel {
	requests := make(chan *datastructures.SignerRequest, 1)
	return &SignerProtocolRequestChannel{C: requests}
}
// SafeClose closes C unless SafeClose has already closed it, so it may be
// called any number of times. The mutex serialises concurrent callers.
func (rc *SignerProtocolRequestChannel) SafeClose() {
	rc.mutex.Lock()
	defer rc.mutex.Unlock()

	if rc.closed {
		return
	}
	close(rc.C)
	rc.closed = true
}
// IsClosed reports whether SafeClose has closed the channel.
func (rc *SignerProtocolRequestChannel) IsClosed() bool {
	rc.mutex.Lock()
	state := rc.closed
	rc.mutex.Unlock()
	return state
}
// protocolHandler implements SignerProtocolHandler on top of a serial
// connection. It reads requests from requestChannel and publishes decoded
// responses on *responseChannel.
type protocolHandler struct {
	requestChannel   *SignerProtocolRequestChannel
	responseChannel  *chan *datastructures.SignerResponse
	serialConnection io.ReadWriteCloser
	config           *SignerProtocolConfig
	// closeOnce makes Close idempotent: closing the response channel a
	// second time would panic.
	closeOnce sync.Once
}

// UnExpectedAcknowledgeByte is returned when the signer answers with a byte
// other than shared.AckByte.
type UnExpectedAcknowledgeByte struct {
	ResponseByte byte
}

// UnExpectedHandshakeByte is returned when the signer sends a byte other than
// shared.HandshakeByte where a handshake is expected.
type UnExpectedHandshakeByte struct {
	ResponseByte byte
}

// Error implements the error interface.
func (e UnExpectedHandshakeByte) Error() string {
	return fmt.Sprintf("unexpected handshake byte 0x%x instead of 0x%x", e.ResponseByte, shared.HandshakeByte)
}

// Error implements the error interface.
func (e UnExpectedAcknowledgeByte) Error() string {
	return fmt.Sprintf("unexpected acknowledge byte 0x%x instead of 0x%x", e.ResponseByte, shared.AckByte)
}

// Close closes the response channel and the request channel. It is safe to
// call more than once: only the first call closes anything, and every call
// returns nil.
//
// NOTE(review): Close does not synchronise with HandleSignerProtocol. If that
// loop is sending a response when the response channel is closed, the send
// panics. Confirm that callers stop the loop before calling Close.
func (ph *protocolHandler) Close() error {
	ph.closeOnce.Do(func() {
		close(*ph.responseChannel)
		ph.requestChannel.SafeClose()
	})
	return nil
}
// NewProtocolHandler builds a SignerProtocolHandler that reads requests from
// requests, writes decoded responses to *response and talks to the signer over
// serialConnection using the settings in config.
func NewProtocolHandler(
	requests *SignerProtocolRequestChannel,
	response *chan *datastructures.SignerResponse,
	serialConnection io.ReadWriteCloser,
	config *SignerProtocolConfig,
) SignerProtocolHandler {
	handler := &protocolHandler{
		config:           config,
		serialConnection: serialConnection,
		requestChannel:   requests,
		responseChannel:  response,
	}
	return handler
}
// HandleSignerProtocol serves requests from the request channel until that
// channel is closed. For every request it:
//
//  1. sends the handshake and expects an ACK (sendHandshake),
//  2. sends the serialized request (sendRequest),
//  3. waits for the signer's response handshake (waitForResponseHandshake),
//  4. reads and validates the response frame (readResponse), and
//  5. forwards the decoded response to the response channel.
//
// It returns nil once the request channel is closed and drained. Otherwise it
// returns the first error encountered, and requests still queued are not
// processed.
func (ph *protocolHandler) HandleSignerProtocol() error {
	for request := range ph.requestChannel.C {
		log.Tracef("handle request %+v", request)

		if err := ph.sendHandshake(); err != nil {
			// sendHandshake returns UnExpectedAcknowledgeByte by value, so the
			// errors.As target must be that value type to ever match.
			var unexpected UnExpectedAcknowledgeByte
			if errors.As(err, &unexpected) {
				log.Errorf("unexpected handshake byte: 0x%x", unexpected.ResponseByte)
			}
			return err
		}

		requestBytes := request.Serialize()
		if err := ph.sendRequest(&requestBytes); err != nil {
			return err
		}
		if err := ph.waitForResponseHandshake(); err != nil {
			return err
		}

		lengthBytes, responseBytes, checksum, err := ph.readResponse()
		if err != nil {
			return err
		}

		response, err := datastructures.SignerResponseFromData(*lengthBytes, *responseBytes, checksum)
		if err != nil {
			return fmt.Errorf("could not create response: %w", err)
		}
		*ph.responseChannel <- response
	}
	return nil
}
// sendHandshake writes a single shared.HandshakeByte to the serial connection
// and reads one reply byte, which must be shared.AckByte.
//
// If a different byte is received, it returns an UnExpectedAcknowledgeByte
// (by value). It returns a wrapped error if writing or reading fails, or if
// no reply byte was read.
func (ph *protocolHandler) sendHandshake() error {
	bytesWritten, err := ph.serialConnection.Write([]byte{shared.HandshakeByte})
	if err != nil {
		return fmt.Errorf("could not send handshake byte: %w", err)
	}
	log.Tracef("wrote %d bytes of handshake info", bytesWritten)

	// The buffer needs room for the reply byte: Read into a zero-length slice
	// never transfers any data, and indexing such a slice panics.
	data := make([]byte, 1)
	bytesRead, err := ph.serialConnection.Read(data)
	if err != nil {
		return fmt.Errorf("could not receive ACK byte: %w", err)
	}
	log.Tracef("%d bytes read", bytesRead)

	if bytesRead == 0 {
		return errors.New("could not receive ACK byte: no data read")
	}
	if data[0] != shared.AckByte {
		log.Warnf("received invalid ACK byte 0x%x", data[0])
		return UnExpectedAcknowledgeByte{data[0]}
	}
	return nil
}
// sendRequest transmits a request frame and then waits, for up to
// waitForHeader seconds, for the signer's one-byte reply. The frame consists
// of the request bytes, their XOR checksum and shared.MagicTrailer.
//
// Reply handling:
//   - shared.AckByte completes the exchange.
//   - shared.ResendByte causes the whole frame to be sent again.
//   - Any other byte yields an UnExpectedAcknowledgeByte.
//
// Write/receive failures and a missing reply byte yield wrapped errors.
//
// NOTE(review): there is no upper bound on the number of resends.
func (ph *protocolHandler) sendRequest(requestBytes *[]byte) error {
	// Checksum and trailer depend only on the request, so build them once
	// rather than on every resend.
	checksum := []byte{datastructures.CalculateXorCheckSum([][]byte{*requestBytes})}
	trailer := []byte(shared.MagicTrailer)

	for {
		n, err := ph.serialConnection.Write(*requestBytes)
		if err != nil {
			return fmt.Errorf("could not send request bytes: %w", err)
		}
		log.Tracef("wrote %d request bytes", n)

		if n, err = ph.serialConnection.Write(checksum); err != nil {
			return fmt.Errorf("could not send checksum byte: %w", err)
		}
		log.Tracef("wrote %d checksum bytes", n)

		if n, err = ph.serialConnection.Write(trailer); err != nil {
			return fmt.Errorf("could not send trailer bytes: %w", err)
		}
		log.Tracef("wrote %d trailer bytes", n)

		header, err := shared.ReceiveBytes(ph.serialConnection, 1, waitForHeader*time.Second)
		if err != nil {
			return fmt.Errorf("could not read header bytes: %w", err)
		}
		// Guard the index below; an empty result would otherwise panic.
		if len(header) == 0 {
			return errors.New("could not read header bytes: no data received")
		}

		switch header[0] {
		case shared.AckByte:
			return nil
		case shared.ResendByte:
			// The signer asked for the frame again; transmit it once more.
			continue
		default:
			return UnExpectedAcknowledgeByte{header[0]}
		}
	}
}
// waitForResponseHandshake waits up to waitForHandShake seconds for the signer
// to send exactly one shared.HandshakeByte, and acknowledges it with
// shared.AckByte.
//
// It returns an UnExpectedHandshakeByte if a different byte (or more than one
// byte) is received, and a wrapped error if receiving fails, nothing is
// received, or the ACK cannot be sent.
func (ph *protocolHandler) waitForResponseHandshake() error {
	data, err := shared.ReceiveBytes(ph.serialConnection, 1, waitForHandShake*time.Second)
	if err != nil {
		return fmt.Errorf("could not receive handshake byte: %w", err)
	}
	// Without this check the error path below would index an empty slice
	// and panic.
	if len(data) == 0 {
		return errors.New("could not receive handshake byte: no data received")
	}
	if len(data) != 1 || data[0] != shared.HandshakeByte {
		log.Warnf("received invalid handshake byte 0x%x", data[0])
		return UnExpectedHandshakeByte{data[0]}
	}

	if err = shared.SendBytes(ph.serialConnection, []byte{shared.AckByte}); err != nil {
		return fmt.Errorf("could not send ACK: %w", err)
	}
	return nil
}
// readResponse receives one response frame from the signer, validates it and
// acknowledges it with shared.AckByte.
//
// The frame layout, as parsed below, is:
//
//	length field  shared.LengthFieldSize bytes, decoded with datastructures.Decode24BitLength
//	data          <length> bytes
//	checksum      shared.CheckSumFieldSize bytes; only its first byte is compared
//	trailer       shared.MagicTrailer
//
// Data is requested from shared.ReceiveBytes in chunks of ph.config.BufferSize
// bytes, with a waitForData-second timeout per call. The XOR checksum is
// computed over the length field and the data.
//
// It returns the length field, the data bytes and the received checksum byte.
// It returns an error in these cases:
//   - receiving fails;
//   - more bytes arrive than the length field announced;
//   - the trailer or the checksum does not match;
//   - the ACK cannot be sent.
func (ph *protocolHandler) readResponse() (*[]byte, *[]byte, byte, error) {
	var (
		received bytes.Buffer
		// dataLength stays negative until the length field has been received.
		dataLength = -1
	)

	trailerOffset := shared.LengthFieldSize + shared.CheckSumFieldSize
	frameOverhead := trailerOffset + len(shared.MagicTrailer)

	for {
		chunk, err := shared.ReceiveBytes(ph.serialConnection, ph.config.BufferSize, waitForData*time.Second)
		if err != nil {
			return nil, nil, 0, fmt.Errorf("could not receive response: %w", err)
		}
		received.Write(chunk)

		if dataLength < 0 {
			if received.Len() < shared.LengthFieldSize {
				continue
			}
			dataLength = datastructures.Decode24BitLength(received.Bytes()[:shared.LengthFieldSize])
			log.Tracef("expect to read %d data bytes", dataLength)
		}

		frameLength := dataLength + frameOverhead
		if received.Len() < frameLength {
			continue
		}
		if received.Len() > frameLength {
			// More bytes than announced can never form a valid frame; fail now
			// instead of waiting for a receive timeout.
			return nil, nil, 0, fmt.Errorf("received %d response bytes but expected %d", received.Len(), frameLength)
		}

		allBytes := received.Bytes()
		trailer := string(allBytes[trailerOffset+dataLength:])
		if trailer != shared.MagicTrailer {
			return nil, nil, 0, fmt.Errorf("invalid trailer bytes: %v", trailer)
		}

		lengthBytes := allBytes[:shared.LengthFieldSize]
		dataBytes := allBytes[shared.LengthFieldSize : shared.LengthFieldSize+dataLength]
		checkSum := allBytes[shared.LengthFieldSize+dataLength]
		calculatedChecksum := datastructures.CalculateXorCheckSum([][]byte{lengthBytes, dataBytes})
		if calculatedChecksum != checkSum {
			return nil, nil, 0, fmt.Errorf("calculated checksum mismatch 0x%x vs 0x%x", calculatedChecksum, checkSum)
		}

		if err = shared.SendBytes(ph.serialConnection, []byte{shared.AckByte}); err != nil {
			return nil, nil, 0, fmt.Errorf("could not send ACK byte: %w", err)
		}
		return &lengthBytes, &dataBytes, checkSum, nil
	}
}