diff --git a/realis.go b/realis.go index 8f2767d..fe4c5e8 100644 --- a/realis.go +++ b/realis.go @@ -18,14 +18,11 @@ package realis import ( "context" "crypto/tls" - "crypto/x509" "encoding/base64" "fmt" - "io/ioutil" "log" "net/http" "os" - "path/filepath" "sort" "strings" "sync" @@ -117,6 +114,7 @@ type config struct { logger *LevelLogger insecureSkipVerify bool certspath string + certExtensions map[string]struct{} clientKey, clientCert string options []ClientOption debug bool @@ -229,6 +227,18 @@ func ClientCerts(clientKey, clientCert string) ClientOption { } } +// CertExtensions configures gorealis to consider files with the given extensions when +// loading certificates from the cert path. +func CertExtensions(extensions ...string) ClientOption { + extensionsLookup := make(map[string]struct{}) + for _, ext := range extensions { + extensionsLookup[ext] = struct{}{} + } + return func(config *config) { + config.certExtensions = extensionsLookup + } +} + // ZookeeperOptions allows users to override default settings for connecting to Zookeeper. // See zk.go for what is possible to set as an option. func ZookeeperOptions(opts ...ZKOpt) ClientOption { @@ -311,6 +321,7 @@ func NewRealisClient(options ...ClientOption) (Realis, error) { config.timeoutms = 10000 config.backoff = defaultBackoff config.logger = &LevelLogger{logger: log.New(os.Stdout, "realis: ", log.Ltime|log.Ldate|log.LUTC)} + config.certExtensions = map[string]struct{}{".crt": {}, ".pem": {}, ".key": {}} // Save options to recreate client if a connection error happens config.options = options @@ -433,23 +444,6 @@ func GetDefaultClusterFromZKUrl(zkurl string) *Cluster { } } -func createCertPool(certPath string) (*x509.CertPool, error) { - globalRootCAs := x509.NewCertPool() - caFiles, err := ioutil.ReadDir(certPath) - if err != nil { - return nil, err - } - for _, cert := range caFiles { - caPathFile := filepath.Join(certPath, cert.Name()) - caCert, err := ioutil.ReadFile(caPathFile) - if err != nil { - return nil, err - } - globalRootCAs.AppendCertsFromPEM(caCert) - } - return globalRootCAs, nil -} - // Creates a default Thrift Transport object for communications in gorealis using an HTTP Post Client func defaultTTransport(url string, timeoutMs int, config *config) (thrift.TTransport, error) { var transport http.Transport @@ -457,7 +451,7 @@ func defaultTTransport(url string, timeoutMs int, config *config) (thrift.TTrans tlsConfig := &tls.Config{InsecureSkipVerify: config.insecureSkipVerify} if config.certspath != "" { - rootCAs, err := createCertPool(config.certspath) + rootCAs, err := createCertPool(config.certspath, config.certExtensions) if err != nil { config.logger.Println("error occurred couldn't fetch certs") return nil, err diff --git a/util.go b/util.go index 989f8e8..19930e2 100644 --- a/util.go +++ b/util.go @@ -1,7 +1,11 @@ package realis import ( + "crypto/x509" + "io/ioutil" "net/url" + "os" + "path/filepath" "strings" "github.com/paypal/gorealis/gen-go/apache/aurora" @@ -65,6 +69,49 @@ func init() { } } +// createCertPool will attempt to load certificates into a certificate pool from a given directory. +// Only files with an extension contained in the extension map are considered. +// This function ignores any files that cannot be read successfully or cannot be added to the certPool +// successfully. +func createCertPool(path string, extensions map[string]struct{}) (*x509.CertPool, error) { + _, err := os.Stat(path) + if err != nil { + return nil, errors.Wrap(err, "unable to load certificates") + } + + caFiles, err := ioutil.ReadDir(path) + if err != nil { + return nil, err + } + + certPool := x509.NewCertPool() + loadedCerts := 0 + for _, cert := range caFiles { + // Skip directories + if cert.IsDir() { + continue + } + + // Skip any files that do not contain the right extension + if _, ok := extensions[filepath.Ext(cert.Name())]; !ok { + continue + } + + pem, err := ioutil.ReadFile(filepath.Join(path, cert.Name())) + if err != nil { + continue + } + + if certPool.AppendCertsFromPEM(pem) { + loadedCerts++ + } + } + if loadedCerts == 0 { + return nil, errors.New("no certificates were able to be successfully loaded") + } + return certPool, nil +} + func validateAuroraURL(location string) (string, error) { // If no protocol defined, assume http @@ -92,7 +139,7 @@ func validateAuroraURL(location string) (string, error) { return "", errors.Errorf("only protocols http and https are supported %v\n", u.Scheme) } - // This could theoretically be elsewhwere but we'll be strict for the sake of simplicty + // This could theoretically be elsewhere but we'll be strict for the sake of simplicity if u.Path != apiPath { return "", errors.Errorf("expected /api path %v\n", u.Path) } diff --git a/util_test.go b/util_test.go index b8341b2..2906c42 100644 --- a/util_test.go +++ b/util_test.go @@ -100,3 +100,15 @@ func TestCurrentBatchCalculator(t *testing.T) { assert.Equal(t, 0, curBatch) }) } + +func TestCertPoolCreator(t *testing.T) { + extensions := map[string]struct{}{".crt": {}} + + _, err := createCertPool("examples/certs", extensions) + assert.NoError(t, err) + + t.Run("badDir", func(t *testing.T) { + _, err := createCertPool("idontexist", extensions) + assert.Error(t, err) + }) +}