239 lines
6.4 KiB
Go
239 lines
6.4 KiB
Go
package protocol
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"git.cacert.org/cacert-gosigner/datastructures"
|
|
"git.cacert.org/cacert-gosigner/shared"
|
|
log "github.com/sirupsen/logrus"
|
|
"io"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// SignerProtocolHandler runs the signer protocol and can be shut down through
// its io.Closer method. protocolHandler is the implementation in this package.
type SignerProtocolHandler interface {
	// HandleSignerProtocol processes signer requests and reports the error
	// that ended processing, if any.
	HandleSignerProtocol() error

	io.Closer
}
|
|
|
|
// signerProtocolConfig holds tunables for the signer protocol handler.
type signerProtocolConfig struct {
	// BufferSize is the byte count requested per read while collecting a
	// response frame (passed to shared.ReceiveBytes by readResponse).
	BufferSize int
}

// NewSignerProtocolConfig returns a fresh configuration with BufferSize set to
// its default of 2048 bytes.
func NewSignerProtocolConfig() *signerProtocolConfig {
	cfg := new(signerProtocolConfig)
	cfg.BufferSize = 2048
	return cfg
}
|
|
|
|
type SignerProtocolRequestChannel struct {
|
|
C chan *datastructures.SignerRequest
|
|
closed bool
|
|
mutex sync.Mutex
|
|
}
|
|
|
|
func NewSignerProtocolRequestChannel() *SignerProtocolRequestChannel {
|
|
return &SignerProtocolRequestChannel{C: make(chan *datastructures.SignerRequest, 1)}
|
|
}
|
|
|
|
func (rc *SignerProtocolRequestChannel) SafeClose() {
|
|
rc.mutex.Lock()
|
|
defer rc.mutex.Unlock()
|
|
if !rc.closed {
|
|
close(rc.C)
|
|
rc.closed = true
|
|
}
|
|
}
|
|
|
|
func (rc *SignerProtocolRequestChannel) IsClosed() bool {
|
|
rc.mutex.Lock()
|
|
defer rc.mutex.Unlock()
|
|
return rc.closed
|
|
}
|
|
|
|
type protocolHandler struct {
|
|
requestChannel *SignerProtocolRequestChannel
|
|
responseChannel *chan *datastructures.SignerResponse
|
|
serialConnection *io.ReadWriteCloser
|
|
config *signerProtocolConfig
|
|
}
|
|
|
|
type UnExpectedAcknowledgeByte struct {
|
|
ResponseByte byte
|
|
}
|
|
type UnExpectedHandshakeByte struct {
|
|
ResponseByte byte
|
|
}
|
|
|
|
func (e UnExpectedHandshakeByte) Error() string {
|
|
return fmt.Sprintf("unexpected handshake byte 0x%x instead of 0x%x", e.ResponseByte, shared.HandshakeByte)
|
|
}
|
|
func (e UnExpectedAcknowledgeByte) Error() string {
|
|
return fmt.Sprintf("unexpected acknowledge byte 0x%x instead of 0x%x", e.ResponseByte, shared.AckByte)
|
|
}
|
|
|
|
func (ph *protocolHandler) Close() error {
|
|
close(*ph.responseChannel)
|
|
ph.requestChannel.SafeClose()
|
|
return nil
|
|
}
|
|
|
|
func NewProtocolHandler(requests *SignerProtocolRequestChannel, response *chan *datastructures.SignerResponse, serialConnection *io.ReadWriteCloser, config *signerProtocolConfig) SignerProtocolHandler {
|
|
return &protocolHandler{
|
|
requestChannel: requests,
|
|
responseChannel: response,
|
|
serialConnection: serialConnection,
|
|
config: config,
|
|
}
|
|
}
|
|
|
|
func (ph *protocolHandler) HandleSignerProtocol() error {
|
|
for {
|
|
select {
|
|
case request := <-ph.requestChannel.C:
|
|
log.Tracef("handle request %+v", request)
|
|
var err error
|
|
var lengthBytes, responseBytes *[]byte
|
|
var checksum byte
|
|
|
|
if err = ph.sendHandshake(); err != nil {
|
|
switch err.(type) {
|
|
case UnExpectedAcknowledgeByte:
|
|
log.Errorf("unexpected handshake byte: 0x%x", err.(UnExpectedAcknowledgeByte).ResponseByte)
|
|
// TODO drain input
|
|
}
|
|
return err
|
|
}
|
|
requestBytes := request.Serialize()
|
|
if err = ph.sendRequest(&requestBytes); err != nil {
|
|
return err
|
|
}
|
|
if err = ph.waitForResponseHandshake(); err != nil {
|
|
return err
|
|
}
|
|
if lengthBytes, responseBytes, checksum, err = ph.readResponse(); err != nil {
|
|
return err
|
|
}
|
|
response, err := datastructures.SignerResponseFromData(*lengthBytes, *responseBytes, checksum)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*ph.responseChannel <- response
|
|
}
|
|
}
|
|
}
|
|
|
|
func (ph *protocolHandler) sendHandshake() (err error) {
|
|
var bytesWritten, bytesRead int
|
|
if bytesWritten, err = (*ph.serialConnection).Write([]byte{shared.HandshakeByte}); err != nil {
|
|
return
|
|
} else {
|
|
log.Tracef("wrote %d bytes of handshake info", bytesWritten)
|
|
}
|
|
data := make([]byte, 1)
|
|
if bytesRead, err = (*ph.serialConnection).Read(data); err != nil {
|
|
return
|
|
}
|
|
log.Tracef("%d bytes read", bytesRead)
|
|
if bytesRead != 1 || data[0] != shared.AckByte {
|
|
log.Warnf("received invalid handshake byte 0x%x", data[0])
|
|
return UnExpectedAcknowledgeByte{data[0]}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (ph *protocolHandler) sendRequest(requestBytes *[]byte) error {
|
|
for {
|
|
if length, err := (*ph.serialConnection).Write(*requestBytes); err != nil {
|
|
return err
|
|
} else {
|
|
log.Tracef("wrote %d request bytes", length)
|
|
}
|
|
|
|
if length, err := (*ph.serialConnection).Write([]byte{
|
|
datastructures.CalculateXorCheckSum([][]byte{*requestBytes}),
|
|
}); err != nil {
|
|
return err
|
|
} else {
|
|
log.Tracef("wrote %d checksum bytes", length)
|
|
}
|
|
|
|
if length, err := (*ph.serialConnection).Write([]byte(shared.MagicTrailer)); err != nil {
|
|
return err
|
|
} else {
|
|
log.Tracef("wrote %d trailer bytes", length)
|
|
}
|
|
header, err := shared.ReceiveBytes(ph.serialConnection, 1, 20*time.Second)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
switch header[0] {
|
|
case shared.AckByte:
|
|
return nil
|
|
case shared.ResendByte:
|
|
default:
|
|
return UnExpectedAcknowledgeByte{header[0]}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (ph *protocolHandler) waitForResponseHandshake() (err error) {
|
|
data, err := shared.ReceiveBytes(ph.serialConnection, 1, 120*time.Second)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(data) != 1 || data[0] != shared.HandshakeByte {
|
|
log.Warnf("received invalid handshake byte 0x%x", data[0])
|
|
return UnExpectedHandshakeByte{data[0]}
|
|
}
|
|
if err = shared.SendByte(ph.serialConnection, shared.AckByte); err != nil {
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (ph *protocolHandler) readResponse() (*[]byte, *[]byte, byte, error) {
|
|
dataLength := -1
|
|
var lengthBuffer = bytes.NewBuffer(make([]byte, 0))
|
|
var byteBuffer = bytes.NewBuffer(make([]byte, 0))
|
|
|
|
for {
|
|
readBuffer, err := shared.ReceiveBytes(ph.serialConnection, ph.config.BufferSize, 5*time.Second)
|
|
if err != nil {
|
|
return nil, nil, 0, err
|
|
}
|
|
bytesRead := len(readBuffer)
|
|
if bytesRead > 0 {
|
|
byteBuffer.Write(readBuffer[0:bytesRead])
|
|
for _, b := range readBuffer {
|
|
if lengthBuffer.Len() < 3 {
|
|
lengthBuffer.WriteByte(b)
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if dataLength < 0 && lengthBuffer.Len() == 3 {
|
|
dataLength = datastructures.Decode24BitLength(lengthBuffer.Bytes())
|
|
log.Tracef("expect to read %d data bytes", dataLength)
|
|
}
|
|
if dataLength == byteBuffer.Len()-4-len(shared.MagicTrailer) {
|
|
allBytes := byteBuffer.Bytes()
|
|
trailer := string(allBytes[4+dataLength:])
|
|
if trailer != shared.MagicTrailer {
|
|
return nil, nil, 0, fmt.Errorf("invalid trailer bytes: %v", trailer)
|
|
}
|
|
lengthBytes := allBytes[0:3]
|
|
dataBytes := allBytes[3 : 3+dataLength]
|
|
checkSum := allBytes[3+dataLength]
|
|
calculatedChecksum := datastructures.CalculateXorCheckSum([][]byte{lengthBytes, dataBytes})
|
|
if calculatedChecksum != checkSum {
|
|
return nil, nil, 0, fmt.Errorf("calculated checksum mismatch 0x%x vs 0x%x", calculatedChecksum, checkSum)
|
|
}
|
|
if err := shared.SendByte(ph.serialConnection, shared.AckByte); err != nil {
|
|
return nil, nil, 0, err
|
|
}
|
|
|
|
return &lengthBytes, &dataBytes, checkSum, nil
|
|
}
|
|
}
|
|
}
|