Refactor client into separate files

Add a main loop, move I/O code into io.go, move configuration into config.go.
Use shared.Decode24BitLength instead of manually decoding block lengths.
Fix response block decoding and checksum validation.
Add constants for commonly used byte values and use these in the signer and
the client.
This commit is contained in:
Jan Dittberner 2020-04-17 19:39:01 +02:00
parent 65855152ce
commit 42d1e6e991
8 changed files with 322 additions and 241 deletions

84
client/config.go Normal file
View file

@ -0,0 +1,84 @@
package main
import (
"fmt"
"io/ioutil"
"time"
"github.com/goburrow/serial"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
)
type SerialConfig struct {
Address string `yaml:"address"`
BaudRate int `yaml:"baudrate"`
DataBits int `yaml:"databits"`
StopBits int `yaml:"stopbits"`
Parity string `yaml:"parity"`
}
type ClientConfig struct {
Serial SerialConfig `yaml:"serial_config"`
Paranoid bool `yaml:"paranoid"`
Debug bool `yaml:"debug"`
GNUPGBinary string `yaml:"gnupg_bin"`
OpenSSLBinary string `yaml:"openssl_bin"`
MySQLDSN string `yaml:"mysql_dsn"`
}
var defaultConfig = ClientConfig{
Serial: SerialConfig{
Address: "/dev/ttyUSB0",
BaudRate: 115200,
DataBits: 8,
StopBits: 1,
Parity: "N",
},
Paranoid: false,
Debug: false,
OpenSSLBinary: "/usr/bin/openssl",
GNUPGBinary: "/usr/bin/gpg",
MySQLDSN: "<username>:<password>@/database?parseTime=true",
}
func generateExampleConfig(configFile string) (config *ClientConfig, err error) {
config = &defaultConfig
configBytes, err := yaml.Marshal(config)
if err != nil {
logrus.Errorf("could not generate configuration data")
return
}
logrus.Infof("example data for %s:\n\n---\n%s\n", configFile, configBytes)
return
}
func readConfig(configFile string) (config *ClientConfig, err error) {
source, err := ioutil.ReadFile(configFile)
if err != nil {
logrus.Errorf("opening configuration file failed: %v", err)
if exampleConfig, err := generateExampleConfig(configFile); err != nil {
return nil, err
} else {
logrus.Info("starting with default config")
return exampleConfig, nil
}
}
if err := yaml.Unmarshal(source, &config); err != nil {
return nil, fmt.Errorf("loading configuration file failed: %v", err)
}
return config, nil
}
func fillSerialConfig(clientConfig *ClientConfig) *serial.Config {
return &serial.Config{
Address: clientConfig.Serial.Address,
BaudRate: clientConfig.Serial.BaudRate,
DataBits: clientConfig.Serial.DataBits,
StopBits: clientConfig.Serial.StopBits,
Parity: clientConfig.Serial.Parity,
Timeout: 30 * time.Second,
}
}

98
client/io.go Normal file
View file

@ -0,0 +1,98 @@
package main
import (
"errors"
"fmt"
"github.com/goburrow/serial"
log "github.com/sirupsen/logrus"
"git.cacert.org/cacert-gosigner/datastructures"
"git.cacert.org/cacert-gosigner/shared"
)
func sendHandShake(port serial.Port) error {
log.Debug("Shaking hands ...")
if length, err := port.Write([]byte{shared.HandshakeByte}); err != nil {
return fmt.Errorf("could not write handshake byte: %v", err)
} else {
log.Tracef("wrote %d handshake bytes", length)
}
handShakeResponse := make([]byte, 1)
if length, err := port.Read(handShakeResponse); err != nil {
return fmt.Errorf("failed to read handshake response: %v", err)
} else {
log.Tracef("read %d bytes", length)
}
if handShakeResponse[0] != shared.AckByte {
return fmt.Errorf("invalid handshake response expected 0x10 received %x", handShakeResponse[0])
}
log.Debug("Handshake successful")
return nil
}
func receiveResponse(port *serial.Port, responseChan *chan datastructures.SignerResponse, errorChan *chan error) {
header, err := shared.ReceiveBytes(port, 1, 20)
if err != nil {
*errorChan <- err
return
}
if header[0] != shared.HandshakeByte {
*errorChan <- fmt.Errorf("unexpected byte 0x%x expected 0x%x", header, shared.HandshakeByte)
}
log.Tracef("received handshake byte")
if _, err := (*port).Write([]byte{shared.AckByte}); err != nil {
*errorChan <- errors.New("could not write ACK")
return
}
log.Tracef("sent ACK byte")
lengthBytes, err := shared.ReceiveBytes(port, 3, 2)
if err != nil {
*errorChan <- err
return
}
blockLength := datastructures.Decode24BitLength(lengthBytes)
log.Tracef("received block length %d", blockLength)
blockData, err := shared.ReceiveBytes(port, blockLength, 5)
if err != nil {
*errorChan <- err
return
}
log.Tracef("received bytes %v", blockData)
checkSum, err := shared.ReceiveBytes(port, 1, 2)
if err != nil {
*errorChan <- err
return
}
log.Tracef("received checksum 0x%x", checkSum[0])
trailer, err := shared.ReceiveBytes(port, len(shared.MagicTrailer), 2)
if err != nil {
*errorChan <- err
return
}
if string(trailer) != shared.MagicTrailer {
*errorChan <- errors.New("expected trailer bytes not found")
return
}
log.Tracef("received valid trailer bytes")
if _, err := (*port).Write([]byte{shared.AckByte}); err != nil {
*errorChan <- fmt.Errorf("could not write ACK byte: %v", err)
return
}
log.Tracef("sent ACK byte")
signerResponse, err := datastructures.SignerResponseFromData(lengthBytes, blockData, checkSum[0])
if err != nil {
*errorChan <- err
return
}
log.Infof("received response of type %s", signerResponse.Action)
*responseChan <- *signerResponse
}

View file

@ -1,86 +1,17 @@
package main package main
import ( import (
"encoding/binary"
"errors"
"flag" "flag"
"fmt" "fmt"
"time"
"git.cacert.org/cacert-gosigner/datastructures" "git.cacert.org/cacert-gosigner/datastructures"
"git.cacert.org/cacert-gosigner/shared" "git.cacert.org/cacert-gosigner/shared"
"io/ioutil"
"time"
"github.com/goburrow/serial" "github.com/goburrow/serial"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
) )
type SerialConfig struct {
Address string `yaml:"address"`
BaudRate int `yaml:"baudrate"`
DataBits int `yaml:"databits"`
StopBits int `yaml:"stopbits"`
Parity string `yaml:"parity"`
}
type ClientConfig struct {
Serial SerialConfig `yaml:"serial_config"`
Paranoid bool `yaml:"paranoid"`
Debug bool `yaml:"debug"`
GNUPGBinary string `yaml:"gnupg_bin"`
OpenSSLBinary string `yaml:"openssl_bin"`
MySQLDSN string `yaml:"mysql_dsn"`
}
var defaultConfig = ClientConfig{
Serial: SerialConfig{
Address: "/dev/ttyUSB0",
BaudRate: 115200,
DataBits: 8,
StopBits: 1,
Parity: "N",
},
Paranoid: false,
Debug: false,
OpenSSLBinary: "/usr/bin/openssl",
GNUPGBinary: "/usr/bin/gpg",
MySQLDSN: "<username>:<password>@/database?parseTime=true",
}
const HandshakeByte = 0x02
const AckByte = 0x10
const NackByte = 0x11
func readConfig(configFile string) (config *ClientConfig, err error) {
source, err := ioutil.ReadFile(configFile)
if err != nil {
log.Errorf("opening configuration file failed: %v", err)
if exampleConfig, err := generateExampleConfig(configFile); err != nil {
return nil, err
} else {
log.Info("starting with default config")
return exampleConfig, nil
}
}
if err := yaml.Unmarshal(source, &config); err != nil {
return nil, fmt.Errorf("loading configuration file failed: %v", err)
}
return config, nil
}
func generateExampleConfig(configFile string) (config *ClientConfig, err error) {
config = &defaultConfig
configBytes, err := yaml.Marshal(config)
if err != nil {
log.Errorf("could not generate configuration data")
return
}
log.Infof("example data for %s:\n\n---\n%s\n", configFile, configBytes)
return
}
func main() { func main() {
var configFile string var configFile string
@ -94,177 +25,127 @@ func main() {
if clientConfig, err = readConfig(configFile); err != nil { if clientConfig, err = readConfig(configFile); err != nil {
log.Panic(err) log.Panic(err)
} }
serialConfig = &serial.Config{ serialConfig = fillSerialConfig(clientConfig)
Address: clientConfig.Serial.Address,
BaudRate: clientConfig.Serial.BaudRate,
DataBits: clientConfig.Serial.DataBits,
StopBits: clientConfig.Serial.StopBits,
Parity: clientConfig.Serial.Parity,
Timeout: 30 * time.Second,
}
if clientConfig.Debug { if clientConfig.Debug {
log.SetLevel(log.DebugLevel) log.SetLevel(log.DebugLevel)
} }
log.Debugf("connecting %+v", serialConfig) log.Infof("connecting to %s", serialConfig.Address)
log.Tracef("serial parameters %v", serialConfig)
port, err := serial.Open(serialConfig) port, err := serial.Open(serialConfig)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
log.Debug("connected") log.Debug("serial port connected")
defer func() { defer func() {
err := port.Close() err := port.Close()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
log.Debug("closed") log.Debug("serial port closed")
}() }()
request := datastructures.NewNulRequest()
timeout := time.After(2700 * time.Millisecond)
errorChannel := make(chan error, 1) errorChannel := make(chan error, 1)
responseChannel := make(chan *datastructures.SignerResponse, 1) responseChannel := make(chan datastructures.SignerResponse, 1)
crlCheck := 0
go func() { log.Debug("starting main loop")
if response, err := SendRequest(port, request); err != nil {
errorChannel <- err for {
} else { requestChannel := make(chan datastructures.SignerRequest, 1)
responseChannel <- response
go HandleRequests(&port, &responseChannel, &errorChannel, &requestChannel)
log.Debug("handling GPG database ...")
// HandleGPG(&requestChannel)
log.Debug("issuing certificates ...")
// HandleCertificates(&requestChannel)
log.Debug("revoking certificates ...")
// RevokeCertificates(&requestChannel)
crlCheck++
if crlCheck%100 == 0 {
log.Debug("refresh CRLs ...")
// RefreshCRLs(&requestChannel)
} }
}()
log.Debug("send NUL request to keep connection open")
requestChannel <- *datastructures.NewNulRequest()
select { select {
case <-timeout:
log.Fatal("timeout")
case err := <-errorChannel:
log.Fatal(err)
case response := <-responseChannel: case response := <-responseChannel:
if err := Process(response); err != nil { if err := Process(response); err != nil {
log.Fatal(err) log.Error(err)
}
case err := <-errorChannel:
log.Error(err)
}
log.Debug("sleep for 2.7 seconds")
time.Sleep(2700 * time.Millisecond)
}
}
func Process(response datastructures.SignerResponse) (err error) {
log.Infof("process response of type %s", response.Action)
log.Tracef("process response %v", response)
switch response.Action {
case datastructures.ActionNul:
log.Trace("received response for NUL request")
return
default:
return fmt.Errorf("unsupported action in response 0x%x", response.Action)
}
}
func HandleRequests(port *serial.Port, responseChan *chan datastructures.SignerResponse, errorChan *chan error, requestChan *chan datastructures.SignerRequest) {
for {
select {
case request := <-*requestChan:
SendRequest(port, responseChan, errorChan, &request)
} }
} }
} }
func Process(response *datastructures.SignerResponse) error { func SendRequest(port *serial.Port, responseChan *chan datastructures.SignerResponse, errorChan *chan error, request *datastructures.SignerRequest) {
log.Debugf("process %v", response) log.Tracef("send request %v to serial port %v", *request, *port)
return nil if err := sendHandShake(*port); err != nil {
*errorChan <- err
return
} }
func sendHandShake(port serial.Port) error { requestBytes := request.Serialize()
log.Debug("Shaking hands ...") if length, err := (*port).Write(requestBytes); err != nil {
if length, err := port.Write([]byte{HandshakeByte}); err != nil { *errorChan <- err
return fmt.Errorf("could not write handshake byte: %v", err) return
} else { } else {
log.Debugf("wrote %d handshake bytes", length) log.Tracef("wrote %d request bytes", length)
} }
handShakeResponse := make([]byte, 1)
if length, err := port.Read(handShakeResponse); err != nil { if length, err := (*port).Write([]byte{datastructures.CalculateXorCheckSum([][]byte{requestBytes})}); err != nil {
return fmt.Errorf("failed to read handshake response: %v", err) *errorChan <- err
} else { return
log.Debugf("read %d bytes", length) } else {
} log.Tracef("wrote %d checksum bytes", length)
if handShakeResponse[0] != AckByte { }
return fmt.Errorf("invalid handshake response expected 0x10 received %x", handShakeResponse[0])
} if length, err := (*port).Write([]byte(shared.MagicTrailer)); err != nil {
log.Debug("Handshake successful") *errorChan <- err
return nil return
} else {
log.Tracef("wrote %d trailer bytes", length)
} }
func receiveResponse(port *serial.Port, responseChan *chan []byte, errorChan *chan error) {
header, err := shared.ReceiveBytes(port, 1, 20) header, err := shared.ReceiveBytes(port, 1, 20)
if err != nil { if err != nil {
*errorChan <- err *errorChan <- err
return return
} }
if header[0] != HandshakeByte { if header[0] != shared.AckByte {
*errorChan <- fmt.Errorf("unexpected byte 0x%x expected 0x%x", header, HandshakeByte) *errorChan <- fmt.Errorf("unexpected byte 0x%x expected 0x%x", header, shared.AckByte)
}
if _, err := (*port).Write([]byte{AckByte}); err != nil {
*errorChan <- errors.New("could not write ACK")
return return
} }
lengthBytes, err := shared.ReceiveBytes(port, 3, 2) receiveResponse(port, responseChan, errorChan)
if err != nil {
*errorChan <- err
return
}
blockLength := binary.BigEndian.Uint32([]byte{0x0, lengthBytes[0], lengthBytes[1], lengthBytes[2]})
blockData, err := shared.ReceiveBytes(port, int(blockLength), 5)
if err != nil {
*errorChan <- err
return
}
checkSum, err := shared.ReceiveBytes(port, 1, 2)
if err != nil {
*errorChan <- err
return
}
log.Debugf("block checksum is %d", checkSum)
trailer, err := shared.ReceiveBytes(port, len(shared.MagicTrailer), 2)
if err != nil {
*errorChan <- err
return
}
if string(trailer) != shared.MagicTrailer {
*errorChan <- errors.New("expected trailer bytes not found")
return
}
if _, err := (*port).Write([]byte{AckByte}); err != nil {
*errorChan <- fmt.Errorf("could not write ACK byte: %v", err)
}
*responseChan <- blockData
return
}
func SendRequest(port serial.Port, request *datastructures.SignerRequest) (response *datastructures.SignerResponse, err error) {
log.Debugf("send request %v to serial port %v", request, port)
if err = sendHandShake(port); err != nil {
return nil, err
}
requestBytes := request.Serialize()
if length, err := port.Write(requestBytes); err != nil {
log.Fatal(err)
} else {
log.Debugf("wrote %d request bytes", length)
}
if length, err := port.Write([]byte{datastructures.CalculateXorCheckSum([][]byte{requestBytes})}); err != nil {
log.Fatal(err)
} else {
log.Debugf("wrote %d checksum bytes", length)
}
if length, err := port.Write([]byte(shared.MagicTrailer)); err != nil {
log.Fatal(err)
} else {
log.Debugf("wrote %d trailer bytes", length)
}
header, err := shared.ReceiveBytes(&port, 1, 20)
if err != nil {
return nil, err
}
if header[0] != AckByte {
return nil, fmt.Errorf("unexpected byte 0x%x expected 0x%x", header, AckByte)
}
responseChan := make(chan []byte, 1)
errChan := make(chan error, 1)
go receiveResponse(&port, &responseChan, &errChan)
select {
case responseData := <-responseChan:
log.Debugf("response data: %v", responseData)
case err := <-errChan:
log.Errorf("%v", err)
}
return nil, nil
} }

View file

@ -6,6 +6,15 @@ type Action uint8
const ActionNul = Action(0) const ActionNul = Action(0)
func (a Action) String() string {
switch a {
case ActionNul:
return "NUL"
default:
return "unknown"
}
}
func encode24BitLength(data []byte) []byte { func encode24BitLength(data []byte) []byte {
lengthBytes := make([]byte, 4) lengthBytes := make([]byte, 4)
binary.BigEndian.PutUint32(lengthBytes, uint32(len(data))) binary.BigEndian.PutUint32(lengthBytes, uint32(len(data)))
@ -13,7 +22,7 @@ func encode24BitLength(data []byte) []byte {
} }
// calculate length from 24 bits of data in network byte order // calculate length from 24 bits of data in network byte order
func decode24BitLength(bytes []byte) int { func Decode24BitLength(bytes []byte) int {
return int(binary.BigEndian.Uint32([]byte{0x0, bytes[0], bytes[1], bytes[2]})) return int(binary.BigEndian.Uint32([]byte{0x0, bytes[0], bytes[1], bytes[2]}))
} }

View file

@ -25,19 +25,19 @@ type SignerRequest struct {
const protocolVersion = 1 const protocolVersion = 1
func SignerRequestFromData(lengthBytes []byte, blockData []byte, checkSum byte) (*SignerRequest, error) { func SignerRequestFromData(lengthBytes []byte, blockData []byte, checkSum byte) (*SignerRequest, error) {
headerLength := decode24BitLength(blockData[0:3]) headerLength := Decode24BitLength(blockData[0:3])
headerBytes := blockData[3 : 3+headerLength] headerBytes := blockData[3 : 3+headerLength]
contentBytes := blockData[3+headerLength:] contentBytes := blockData[3+headerLength:]
content1Length := decode24BitLength(contentBytes[0:3]) content1Length := Decode24BitLength(contentBytes[0:3])
content1 := string(contentBytes[3 : 3+content1Length]) content1 := string(contentBytes[3 : 3+content1Length])
content2Offset := 3 + content1Length content2Offset := 3 + content1Length
content2Length := decode24BitLength(contentBytes[content2Offset : content2Offset+3]) content2Length := Decode24BitLength(contentBytes[content2Offset : content2Offset+3])
content2 := string(contentBytes[3+content2Offset : 3+content2Offset+content2Length]) content2 := string(contentBytes[3+content2Offset : 3+content2Offset+content2Length])
content3Offset := 3 + content2Offset + content2Length content3Offset := 3 + content2Offset + content2Length
content3Length := decode24BitLength(contentBytes[content3Offset : content3Offset+3]) content3Length := Decode24BitLength(contentBytes[content3Offset : content3Offset+3])
content3 := string(contentBytes[3+content3Offset : 3+content3Offset+content3Length]) content3 := string(contentBytes[3+content3Offset : 3+content3Offset+content3Length])
calculated := CalculateXorCheckSum([][]byte{lengthBytes, blockData}) calculated := CalculateXorCheckSum([][]byte{lengthBytes, blockData})

View file

@ -17,20 +17,26 @@ type SignerResponse struct {
} }
func SignerResponseFromData(lengthBytes []byte, blockData []byte, checkSum byte) (*SignerResponse, error) { func SignerResponseFromData(lengthBytes []byte, blockData []byte, checkSum byte) (*SignerResponse, error) {
headerLength := decode24BitLength(lengthBytes) if len(blockData) < 3 {
headerBytes := blockData[3 : 3+headerLength] return nil, errors.New("begin of structure corrupt")
}
contentBytes := blockData[3+headerLength:] offset := 0
content1Length := decode24BitLength(contentBytes[0:3]) headerLength := Decode24BitLength(blockData[offset : offset+3])
content1 := string(contentBytes[3 : 3+content1Length]) offset += 3
headerBytes := blockData[offset : offset+headerLength]
offset += headerLength
content2Offset := 3 + content1Length content := make([]string, 3)
content2Length := decode24BitLength(contentBytes[content2Offset : content2Offset+3]) for offset < len(blockData) {
content2 := string(contentBytes[3+content2Offset : 3+content2Offset+content2Length]) dataLength := Decode24BitLength(blockData[offset : offset+3])
if len(blockData)-3 < dataLength {
content3Offset := 3 + content2Offset + content2Length return nil, errors.New("structure cut off")
content3Length := decode24BitLength(contentBytes[content3Offset : content3Offset+3]) }
content3 := string(contentBytes[3+content3Offset : 3+content3Offset+content3Length]) offset += 3
content = append(content, string(blockData[offset:offset+dataLength]))
offset += dataLength
}
calculated := CalculateXorCheckSum([][]byte{lengthBytes, blockData}) calculated := CalculateXorCheckSum([][]byte{lengthBytes, blockData})
if checkSum != calculated { if checkSum != calculated {
@ -42,9 +48,9 @@ func SignerResponseFromData(lengthBytes []byte, blockData []byte, checkSum byte)
Action: Action(headerBytes[1]), Action: Action(headerBytes[1]),
Reserved1: headerBytes[2], Reserved1: headerBytes[2],
Reserved2: headerBytes[3], Reserved2: headerBytes[3],
Content1: content1, Content1: content[0],
Content2: content2, Content2: content[1],
Content3: content3, Content3: content[2],
}, nil }, nil
} }

View file

@ -1,3 +1,7 @@
package shared package shared
const MagicTrailer = "rie4Ech7" const MagicTrailer = "rie4Ech7"
const HandshakeByte = 0x02
const AckByte = 0x10
const NackByte = 0x11

View file

@ -1,15 +1,15 @@
package main package main
import ( import (
"encoding/binary"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"git.cacert.org/cacert-gosigner/datastructures"
"git.cacert.org/cacert-gosigner/shared"
"log" "log"
"time" "time"
"git.cacert.org/cacert-gosigner/datastructures"
"git.cacert.org/cacert-gosigner/shared"
"github.com/goburrow/serial" "github.com/goburrow/serial"
) )
@ -96,7 +96,7 @@ func SendResponse(port *serial.Port, response *datastructures.SignerResponse) er
} }
tryAgain := true tryAgain := true
for ; tryAgain; { for tryAgain {
data := response.Serialize() data := response.Serialize()
if _, err := (*port).Write(data); err != nil { if _, err := (*port).Write(data); err != nil {
return err return err
@ -150,11 +150,11 @@ func Receive(port *serial.Port, commandChan *chan datastructures.SignerRequest,
*errorChan <- err *errorChan <- err
return return
} }
if header[0] != 0x02 { if header[0] != shared.HandshakeByte {
*errorChan <- fmt.Errorf("unexpected byte 0x%x expected 0x02", header) *errorChan <- fmt.Errorf("unexpected byte 0x%x expected 0x%x", header[0], shared.HandshakeByte)
} }
if _, err := (*port).Write([]byte{0x10}); err != nil { if _, err := (*port).Write([]byte{shared.AckByte}); err != nil {
*errorChan <- errors.New("could not write ACK") *errorChan <- errors.New("could not write ACK")
return return
} }
@ -164,8 +164,8 @@ func Receive(port *serial.Port, commandChan *chan datastructures.SignerRequest,
*errorChan <- err *errorChan <- err
return return
} }
blockLength := binary.BigEndian.Uint32([]byte{0x0, lengthBytes[0], lengthBytes[1], lengthBytes[2]}) blockLength := datastructures.Decode24BitLength(lengthBytes)
blockData, err := shared.ReceiveBytes(port, int(blockLength), 5) blockData, err := shared.ReceiveBytes(port, blockLength, 5)
if err != nil { if err != nil {
*errorChan <- err *errorChan <- err
return return
@ -199,4 +199,3 @@ func Receive(port *serial.Port, commandChan *chan datastructures.SignerRequest,
*commandChan <- *command *commandChan <- *command
} }