Configure golangci-lint and fix warnings
This commit is contained in:
parent
ecd1846975
commit
2e467b3d2e
20 changed files with 915 additions and 559 deletions
|
@ -2,6 +2,7 @@ package protocol
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
@ -13,17 +14,24 @@ import (
|
|||
"git.cacert.org/cacert-gosigner/shared"
|
||||
)
|
||||
|
||||
const (
|
||||
waitForHeader = 20
|
||||
waitForData = 5
|
||||
waitForHandShake = 120
|
||||
bufferSize = 2048
|
||||
)
|
||||
|
||||
type SignerProtocolHandler interface {
|
||||
io.Closer
|
||||
HandleSignerProtocol() error
|
||||
}
|
||||
|
||||
type signerProtocolConfig struct {
|
||||
type SignerProtocolConfig struct {
|
||||
BufferSize int
|
||||
}
|
||||
|
||||
func NewSignerProtocolConfig() *signerProtocolConfig {
|
||||
return &signerProtocolConfig{BufferSize: 2048}
|
||||
func NewSignerProtocolConfig() *SignerProtocolConfig {
|
||||
return &SignerProtocolConfig{BufferSize: bufferSize}
|
||||
}
|
||||
|
||||
type SignerProtocolRequestChannel struct {
|
||||
|
@ -39,6 +47,7 @@ func NewSignerProtocolRequestChannel() *SignerProtocolRequestChannel {
|
|||
func (rc *SignerProtocolRequestChannel) SafeClose() {
|
||||
rc.mutex.Lock()
|
||||
defer rc.mutex.Unlock()
|
||||
|
||||
if !rc.closed {
|
||||
close(rc.C)
|
||||
rc.closed = true
|
||||
|
@ -48,6 +57,7 @@ func (rc *SignerProtocolRequestChannel) SafeClose() {
|
|||
func (rc *SignerProtocolRequestChannel) IsClosed() bool {
|
||||
rc.mutex.Lock()
|
||||
defer rc.mutex.Unlock()
|
||||
|
||||
return rc.closed
|
||||
}
|
||||
|
||||
|
@ -55,7 +65,7 @@ type protocolHandler struct {
|
|||
requestChannel *SignerProtocolRequestChannel
|
||||
responseChannel *chan *datastructures.SignerResponse
|
||||
serialConnection io.ReadWriteCloser
|
||||
config *signerProtocolConfig
|
||||
config *SignerProtocolConfig
|
||||
}
|
||||
|
||||
type UnExpectedAcknowledgeByte struct {
|
||||
|
@ -75,6 +85,7 @@ func (e UnExpectedAcknowledgeByte) Error() string {
|
|||
func (ph *protocolHandler) Close() error {
|
||||
close(*ph.responseChannel)
|
||||
ph.requestChannel.SafeClose()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -82,7 +93,7 @@ func NewProtocolHandler(
|
|||
requests *SignerProtocolRequestChannel,
|
||||
response *chan *datastructures.SignerResponse,
|
||||
serialConnection io.ReadWriteCloser,
|
||||
config *signerProtocolConfig,
|
||||
config *SignerProtocolConfig,
|
||||
) SignerProtocolHandler {
|
||||
return &protocolHandler{
|
||||
requestChannel: requests,
|
||||
|
@ -93,85 +104,115 @@ func NewProtocolHandler(
|
|||
}
|
||||
|
||||
func (ph *protocolHandler) HandleSignerProtocol() error {
|
||||
for {
|
||||
select {
|
||||
case request := <-ph.requestChannel.C:
|
||||
log.Tracef("handle request %+v", request)
|
||||
var err error
|
||||
var lengthBytes, responseBytes *[]byte
|
||||
var checksum byte
|
||||
for request := range ph.requestChannel.C {
|
||||
log.Tracef("handle request %+v", request)
|
||||
|
||||
if err = ph.sendHandshake(); err != nil {
|
||||
switch err.(type) {
|
||||
case UnExpectedAcknowledgeByte:
|
||||
log.Errorf("unexpected handshake byte: 0x%x", err.(UnExpectedAcknowledgeByte).ResponseByte)
|
||||
// TODO drain input
|
||||
}
|
||||
return err
|
||||
var (
|
||||
err error
|
||||
lengthBytes, responseBytes *[]byte
|
||||
checksum byte
|
||||
)
|
||||
|
||||
if err = ph.sendHandshake(); err != nil {
|
||||
var e *UnExpectedHandshakeByte
|
||||
if errors.As(err, &e) {
|
||||
log.Errorf("unexpected handshake byte: 0x%x", e.ResponseByte)
|
||||
}
|
||||
requestBytes := request.Serialize()
|
||||
if err = ph.sendRequest(&requestBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = ph.waitForResponseHandshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
if lengthBytes, responseBytes, checksum, err = ph.readResponse(); err != nil {
|
||||
return err
|
||||
}
|
||||
response, err := datastructures.SignerResponseFromData(*lengthBytes, *responseBytes, checksum)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*ph.responseChannel <- response
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
requestBytes := request.Serialize()
|
||||
|
||||
if err = ph.sendRequest(&requestBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = ph.waitForResponseHandshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if lengthBytes, responseBytes, checksum, err = ph.readResponse(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := datastructures.SignerResponseFromData(*lengthBytes, *responseBytes, checksum)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create response: %w", err)
|
||||
}
|
||||
|
||||
*ph.responseChannel <- response
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ph *protocolHandler) sendHandshake() (err error) {
|
||||
var bytesWritten, bytesRead int
|
||||
if bytesWritten, err = ph.serialConnection.Write([]byte{shared.HandshakeByte}); err != nil {
|
||||
return
|
||||
} else {
|
||||
log.Tracef("wrote %d bytes of handshake info", bytesWritten)
|
||||
func (ph *protocolHandler) sendHandshake() error {
|
||||
var (
|
||||
bytesWritten, bytesRead int
|
||||
err error
|
||||
)
|
||||
|
||||
data := make([]byte, 0)
|
||||
|
||||
bytesWritten, err = ph.serialConnection.Write([]byte{shared.HandshakeByte})
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not send handshake byte: %w", err)
|
||||
}
|
||||
data := make([]byte, 1)
|
||||
if bytesRead, err = ph.serialConnection.Read(data); err != nil {
|
||||
return
|
||||
|
||||
log.Tracef("wrote %d bytes of handshake info", bytesWritten)
|
||||
|
||||
bytesRead, err = ph.serialConnection.Read(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not receieve ACK byte: %w", err)
|
||||
}
|
||||
|
||||
log.Tracef("%d bytes read", bytesRead)
|
||||
|
||||
if bytesRead != 1 || data[0] != shared.AckByte {
|
||||
log.Warnf("received invalid handshake byte 0x%x", data[0])
|
||||
|
||||
return UnExpectedAcknowledgeByte{data[0]}
|
||||
}
|
||||
return
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ph *protocolHandler) sendRequest(requestBytes *[]byte) error {
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
)
|
||||
|
||||
for {
|
||||
if length, err := ph.serialConnection.Write(*requestBytes); err != nil {
|
||||
return err
|
||||
} else {
|
||||
log.Tracef("wrote %d request bytes", length)
|
||||
}
|
||||
|
||||
if length, err := ph.serialConnection.Write([]byte{
|
||||
datastructures.CalculateXorCheckSum([][]byte{*requestBytes}),
|
||||
}); err != nil {
|
||||
return err
|
||||
} else {
|
||||
log.Tracef("wrote %d checksum bytes", length)
|
||||
}
|
||||
|
||||
if length, err := ph.serialConnection.Write([]byte(shared.MagicTrailer)); err != nil {
|
||||
return err
|
||||
} else {
|
||||
log.Tracef("wrote %d trailer bytes", length)
|
||||
}
|
||||
header, err := shared.ReceiveBytes(ph.serialConnection, 1, 20*time.Second)
|
||||
n, err = ph.serialConnection.Write(*requestBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("could not send request bytes: %w", err)
|
||||
}
|
||||
|
||||
log.Tracef("wrote %d request bytes", n)
|
||||
|
||||
n, err = ph.serialConnection.Write([]byte{
|
||||
datastructures.CalculateXorCheckSum([][]byte{*requestBytes}),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not send checksum byte: %w", err)
|
||||
}
|
||||
|
||||
log.Tracef("wrote %d checksum bytes", n)
|
||||
|
||||
n, err = ph.serialConnection.Write([]byte(shared.MagicTrailer))
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not send trailer bytes: %w", err)
|
||||
}
|
||||
|
||||
log.Tracef("wrote %d trailer bytes", n)
|
||||
|
||||
header, err := shared.ReceiveBytes(ph.serialConnection, 1, waitForHeader*time.Second)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not read header bytes: %w", err)
|
||||
}
|
||||
|
||||
switch header[0] {
|
||||
case shared.AckByte:
|
||||
return nil
|
||||
|
@ -182,62 +223,78 @@ func (ph *protocolHandler) sendRequest(requestBytes *[]byte) error {
|
|||
}
|
||||
}
|
||||
|
||||
func (ph *protocolHandler) waitForResponseHandshake() (err error) {
|
||||
data, err := shared.ReceiveBytes(ph.serialConnection, 1, 120*time.Second)
|
||||
func (ph *protocolHandler) waitForResponseHandshake() error {
|
||||
data, err := shared.ReceiveBytes(ph.serialConnection, 1, waitForHandShake*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(data) != 1 || data[0] != shared.HandshakeByte {
|
||||
log.Warnf("received invalid handshake byte 0x%x", data[0])
|
||||
return UnExpectedHandshakeByte{data[0]}
|
||||
}
|
||||
if err = shared.SendBytes(ph.serialConnection, []byte{shared.AckByte}); err != nil {
|
||||
return
|
||||
return fmt.Errorf("could not receive handshake byte: %w", err)
|
||||
}
|
||||
|
||||
return
|
||||
if len(data) != 1 || data[0] != shared.HandshakeByte {
|
||||
log.Warnf("received invalid handshake byte 0x%x", data[0])
|
||||
|
||||
return UnExpectedHandshakeByte{data[0]}
|
||||
}
|
||||
|
||||
if err = shared.SendBytes(ph.serialConnection, []byte{shared.AckByte}); err != nil {
|
||||
return fmt.Errorf("could not send ACK: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ph *protocolHandler) readResponse() (*[]byte, *[]byte, byte, error) {
|
||||
dataLength := -1
|
||||
var lengthBuffer = bytes.NewBuffer(make([]byte, 0))
|
||||
var byteBuffer = bytes.NewBuffer(make([]byte, 0))
|
||||
|
||||
var (
|
||||
lengthBuffer = bytes.NewBuffer(make([]byte, 0))
|
||||
byteBuffer = bytes.NewBuffer(make([]byte, 0))
|
||||
)
|
||||
|
||||
for {
|
||||
readBuffer, err := shared.ReceiveBytes(ph.serialConnection, ph.config.BufferSize, 5*time.Second)
|
||||
readBuffer, err := shared.ReceiveBytes(ph.serialConnection, ph.config.BufferSize, waitForData*time.Second)
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
return nil, nil, 0, fmt.Errorf("could not receive response: %w", err)
|
||||
}
|
||||
|
||||
bytesRead := len(readBuffer)
|
||||
if bytesRead > 0 {
|
||||
byteBuffer.Write(readBuffer[0:bytesRead])
|
||||
|
||||
for _, b := range readBuffer {
|
||||
if lengthBuffer.Len() < 3 {
|
||||
if lengthBuffer.Len() < shared.LengthFieldSize {
|
||||
lengthBuffer.WriteByte(b)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if dataLength < 0 && lengthBuffer.Len() == 3 {
|
||||
|
||||
if dataLength < 0 && lengthBuffer.Len() == shared.LengthFieldSize {
|
||||
dataLength = datastructures.Decode24BitLength(lengthBuffer.Bytes())
|
||||
log.Tracef("expect to read %d data bytes", dataLength)
|
||||
}
|
||||
if dataLength == byteBuffer.Len()-4-len(shared.MagicTrailer) {
|
||||
|
||||
trailerOffset := shared.LengthFieldSize + shared.CheckSumFieldSize
|
||||
|
||||
if dataLength == byteBuffer.Len()-trailerOffset-len(shared.MagicTrailer) {
|
||||
allBytes := byteBuffer.Bytes()
|
||||
trailer := string(allBytes[4+dataLength:])
|
||||
|
||||
trailer := string(allBytes[trailerOffset+dataLength:])
|
||||
if trailer != shared.MagicTrailer {
|
||||
return nil, nil, 0, fmt.Errorf("invalid trailer bytes: %v", trailer)
|
||||
}
|
||||
lengthBytes := allBytes[0:3]
|
||||
dataBytes := allBytes[3 : 3+dataLength]
|
||||
checkSum := allBytes[3+dataLength]
|
||||
|
||||
lengthBytes := allBytes[0:shared.LengthFieldSize]
|
||||
dataBytes := allBytes[shared.LengthFieldSize : shared.LengthFieldSize+dataLength]
|
||||
checkSum := allBytes[shared.LengthFieldSize+dataLength]
|
||||
|
||||
calculatedChecksum := datastructures.CalculateXorCheckSum([][]byte{lengthBytes, dataBytes})
|
||||
if calculatedChecksum != checkSum {
|
||||
return nil, nil, 0, fmt.Errorf("calculated checksum mismatch 0x%x vs 0x%x", calculatedChecksum, checkSum)
|
||||
}
|
||||
|
||||
if err := shared.SendBytes(ph.serialConnection, []byte{shared.AckByte}); err != nil {
|
||||
return nil, nil, 0, err
|
||||
return nil, nil, 0, fmt.Errorf("could not send ACK byte: %w", err)
|
||||
}
|
||||
|
||||
return &lengthBytes, &dataBytes, checkSum, nil
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue