diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 751b54a08..b59b3f94d 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -790,6 +790,9 @@ func (this *Migrator) initiateInspector() (err error) { if this.migrationContext.CliMasterPassword != "" { this.migrationContext.ApplierConnectionConfig.Password = this.migrationContext.CliMasterPassword } + if err := this.migrationContext.ApplierConnectionConfig.RegisterTLSConfig(); err != nil { + return err + } this.migrationContext.Log.Infof("Master forced to be %+v", *this.migrationContext.ApplierConnectionConfig.ImpliedKey) } // validate configs diff --git a/go/logic/throttler.go b/go/logic/throttler.go index 929be9ce2..7fe9026e4 100644 --- a/go/logic/throttler.go +++ b/go/logic/throttler.go @@ -215,8 +215,10 @@ func (this *Throttler) collectControlReplicasLag() { } lagResults := make(chan *mysql.ReplicationLagResult, instanceKeyMap.Len()) for replicaKey := range *instanceKeyMap { - connectionConfig := this.migrationContext.InspectorConnectionConfig.Duplicate() - connectionConfig.Key = replicaKey + connectionConfig := this.migrationContext.InspectorConnectionConfig.DuplicateCredentials(replicaKey) + if err := connectionConfig.RegisterTLSConfig(); err != nil { + return &mysql.ReplicationLagResult{Err: err} + } lagResult := &mysql.ReplicationLagResult{Key: connectionConfig.Key} go func() { diff --git a/go/mysql/connection.go b/go/mysql/connection.go index 33bde2b62..f728fc7fe 100644 --- a/go/mysql/connection.go +++ b/go/mysql/connection.go @@ -52,6 +52,16 @@ func (this *ConnectionConfig) DuplicateCredentials(key InstanceKey) *ConnectionC TransactionIsolation: this.TransactionIsolation, Charset: this.Charset, } + + if this.tlsConfig != nil { + config.tlsConfig = &tls.Config{ + ServerName: key.Hostname, + Certificates: this.tlsConfig.Certificates, + RootCAs: this.tlsConfig.RootCAs, + InsecureSkipVerify: this.tlsConfig.InsecureSkipVerify, + } + } + config.ImpliedKey = &config.Key return config } @@ -103,7 +113,20 @@ func (this *ConnectionConfig) UseTLS(caCertificatePath, clientCertificate, clien InsecureSkipVerify: allowInsecure, } - return mysql.RegisterTLSConfig(TLS_CONFIG_KEY, this.tlsConfig) + return this.RegisterTLSConfig() +} + +func (this *ConnectionConfig) RegisterTLSConfig() error { + if this.tlsConfig == nil { + return nil + } + if this.tlsConfig.ServerName == "" { + return errors.New("tlsConfig.ServerName cannot be empty") + } + + var tlsOption = GetDBTLSConfigKey(this.tlsConfig.ServerName) + + return mysql.RegisterTLSConfig(tlsOption, this.tlsConfig) } func (this *ConnectionConfig) TLSConfig() *tls.Config { @@ -122,7 +145,7 @@ func (this *ConnectionConfig) GetDBUri(databaseName string) string { // simplify construction of the DSN below. tlsOption := "false" if this.tlsConfig != nil { - tlsOption = TLS_CONFIG_KEY + tlsOption = GetDBTLSConfigKey(this.tlsConfig.ServerName) } if this.Charset == "" { @@ -142,3 +165,7 @@ func (this *ConnectionConfig) GetDBUri(databaseName string) string { return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", this.User, this.Password, hostname, this.Key.Port, databaseName, strings.Join(connectionParams, "&")) } + +func GetDBTLSConfigKey(tlsServerName string) string { + return fmt.Sprintf("%s-%s", TLS_CONFIG_KEY, tlsServerName) +} diff --git a/go/mysql/connection_test.go b/go/mysql/connection_test.go index 7859c9354..59761b047 100644 --- a/go/mysql/connection_test.go +++ b/go/mysql/connection_test.go @@ -52,7 +52,10 @@ func TestDuplicateCredentials(t *testing.T) { require.Equal(t, 3310, dup.ImpliedKey.Port) require.Equal(t, "gromit", dup.User) require.Equal(t, "penguin", dup.Password) - require.Equal(t, c.tlsConfig, dup.tlsConfig) + require.Equal(t, "otherhost", dup.tlsConfig.ServerName) + require.Equal(t, c.tlsConfig.Certificates, dup.tlsConfig.Certificates) + require.Equal(t, c.tlsConfig.RootCAs, dup.tlsConfig.RootCAs) + require.Equal(t, c.tlsConfig.InsecureSkipVerify, dup.tlsConfig.InsecureSkipVerify) require.Equal(t, c.TransactionIsolation, dup.TransactionIsolation) require.Equal(t, c.Charset, dup.Charset) } @@ -72,6 +75,7 @@ func TestDuplicate(t *testing.T) { require.Equal(t, 3306, dup.ImpliedKey.Port) require.Equal(t, "gromit", dup.User) require.Equal(t, "penguin", dup.Password) + require.Equal(t, c.tlsConfig, dup.tlsConfig) require.Equal(t, transactionIsolation, dup.TransactionIsolation) require.Equal(t, "utf8mb4", dup.Charset) } @@ -95,10 +99,17 @@ func TestGetDBUriWithTLSSetup(t *testing.T) { c.User = "gromit" c.Password = "penguin" c.Timeout = 1.2345 - c.tlsConfig = &tls.Config{} + c.tlsConfig = &tls.Config{ + ServerName: c.Key.Hostname, + } c.TransactionIsolation = transactionIsolation c.Charset = "utf8mb4_general_ci,utf8_general_ci,latin1" uri := c.GetDBUri("test") - require.Equal(t, `gromit:penguin@tcp(myhost:3306)/test?autocommit=true&interpolateParams=true&charset=utf8mb4_general_ci,utf8_general_ci,latin1&tls=ghost&transaction_isolation="REPEATABLE-READ"&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri) + require.Equal(t, `gromit:penguin@tcp(myhost:3306)/test?autocommit=true&interpolateParams=true&charset=utf8mb4_general_ci,utf8_general_ci,latin1&tls=ghost-myhost&transaction_isolation="REPEATABLE-READ"&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri) +} + +func TestGetDBTLSConfigKey(t *testing.T) { + configKey := GetDBTLSConfigKey("myhost") + require.Equal(t, "ghost-myhost", configKey) } diff --git a/go/mysql/utils.go b/go/mysql/utils.go index c69a3f255..2c860f106 100644 --- a/go/mysql/utils.go +++ b/go/mysql/utils.go @@ -128,8 +128,11 @@ func GetMasterConnectionConfigSafe(connectionConfig *ConnectionConfig, visitedKe if !masterKey.IsValid() { return connectionConfig, nil } - masterConfig = connectionConfig.Duplicate() - masterConfig.Key = *masterKey + + masterConfig = connectionConfig.DuplicateCredentials(*masterKey) + if err := masterConfig.RegisterTLSConfig(); err != nil { + return nil, err + } log.Debugf("Master of %+v is %+v", connectionConfig.Key, masterConfig.Key) if visitedKeys.HasKey(masterConfig.Key) {