Allow users to define what extensions CA certs will have. Skip any files that don't have the right extension.

This commit is contained in:
Renan DelValle 2020-02-24 16:12:56 -08:00
parent 3fa2a20fe4
commit 6cdcbcb5db
No known key found for this signature in database
GPG key ID: C240AD6D6F443EC9
3 changed files with 67 additions and 21 deletions

View file

@ -18,14 +18,11 @@ package realis
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
"sync"
@ -117,6 +114,7 @@ type config struct {
logger *LevelLogger
insecureSkipVerify bool
certspath string
certExtensions map[string]struct{}
clientKey, clientCert string
options []ClientOption
debug bool
@ -229,6 +227,18 @@ func ClientCerts(clientKey, clientCert string) ClientOption {
}
}
// CertExtensions configures gorealis to consider files with the given extensions when
// loading certificates from the cert path.
func CertExtensions(extensions ...string) ClientOption {
extensionsLookup := make(map[string]struct{})
for _, ext := range extensions {
extensionsLookup[ext] = struct{}{}
}
return func(config *config) {
config.certExtensions = extensionsLookup
}
}
// ZookeeperOptions allows users to override default settings for connecting to Zookeeper.
// See zk.go for what is possible to set as an option.
func ZookeeperOptions(opts ...ZKOpt) ClientOption {
@ -311,6 +321,7 @@ func NewRealisClient(options ...ClientOption) (Realis, error) {
config.timeoutms = 10000
config.backoff = defaultBackoff
config.logger = &LevelLogger{logger: log.New(os.Stdout, "realis: ", log.Ltime|log.Ldate|log.LUTC)}
config.certExtensions = map[string]struct{}{"crt": {}, "pem": {}, "key": {}}
// Save options to recreate client if a connection error happens
config.options = options
@ -433,23 +444,6 @@ func GetDefaultClusterFromZKUrl(zkurl string) *Cluster {
}
}
func createCertPool(certPath string) (*x509.CertPool, error) {
globalRootCAs := x509.NewCertPool()
caFiles, err := ioutil.ReadDir(certPath)
if err != nil {
return nil, err
}
for _, cert := range caFiles {
caPathFile := filepath.Join(certPath, cert.Name())
caCert, err := ioutil.ReadFile(caPathFile)
if err != nil {
return nil, err
}
globalRootCAs.AppendCertsFromPEM(caCert)
}
return globalRootCAs, nil
}
// Creates a default Thrift Transport object for communications in gorealis using an HTTP Post Client
func defaultTTransport(url string, timeoutMs int, config *config) (thrift.TTransport, error) {
var transport http.Transport
@ -457,7 +451,7 @@ func defaultTTransport(url string, timeoutMs int, config *config) (thrift.TTrans
tlsConfig := &tls.Config{InsecureSkipVerify: config.insecureSkipVerify}
if config.certspath != "" {
rootCAs, err := createCertPool(config.certspath)
rootCAs, err := createCertPool(config.certspath, config.certExtensions)
if err != nil {
config.logger.Println("error occurred couldn't fetch certs")
return nil, err

40
util.go
View file

@ -1,7 +1,11 @@
package realis
import (
"crypto/x509"
"io/ioutil"
"net/url"
"os"
"path/filepath"
"strings"
"github.com/paypal/gorealis/gen-go/apache/aurora"
@ -64,6 +68,42 @@ func init() {
AwaitingPulseJobUpdateStates[status] = true
}
}
func createCertPool(path string, extensions map[string]struct{}) (*x509.CertPool, error) {
certPool := x509.NewCertPool()
_, err := os.Stat(path)
if err != nil {
return nil, errors.New("given certs path doesn't exist")
}
caFiles, err := ioutil.ReadDir(path)
if err != nil {
return nil, err
}
if len(caFiles) == 0 {
return nil, errors.New("no possible certs found in " + path)
}
for _, cert := range caFiles {
// Skip directories
if cert.IsDir() {
continue
}
// Skip any files that do not contain the right extension
if _, ok := extensions[filepath.Ext(cert.Name())]; !ok {
continue
}
caCert, err := ioutil.ReadFile(filepath.Join(path, cert.Name()))
if err != nil {
return nil, err
}
certPool.AppendCertsFromPEM(caCert)
}
return certPool, nil
}
func validateAuroraURL(location string) (string, error) {

View file

@ -100,3 +100,15 @@ func TestCurrentBatchCalculator(t *testing.T) {
assert.Equal(t, 0, curBatch)
})
}
func TestCertPoolCreator(t *testing.T) {
extensions := map[string]struct{}{"key": {}, "pem": {}, "crt": {}}
_, err := createCertPool("examples/certs", extensions)
assert.NoError(t, err)
t.Run("badDir", func(t *testing.T) {
_, err := createCertPool("idontexist", extensions)
assert.Error(t, err)
})
}