Skip to content

Commit

Permalink
allow pgreba to start when postgres is down (#20)
Browse files Browse the repository at this point in the history
* defer panic

* refactored Close and added getDB method to lazily eval db connection

* use getDB

* IsInRecovery returns false in error state

* gofmt

* remove unneeded extra space
  • Loading branch information
schinns authored Dec 22, 2020
1 parent cc023df commit 2a84e8f
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 21 deletions.
95 changes: 78 additions & 17 deletions data_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,21 +129,40 @@ type ReplicationDataSource interface {

// Postgres connection impl of replication data source.
type pgDataSource struct {
DB *sqlx.DB
cfg *config.Config
db *sqlx.DB
cfg *config.Config
dbMutex sync.Mutex
}

func NewPgReplicationDataSource(config *config.Config) (ReplicationDataSource, error) {
db, err := sqlConnect(fmt.Sprintf("host=%s port=%s database=%s user=%s sslmode=%s binary_parameters=%s", config.Host, config.Port, config.Database, config.User, config.Sslmode, config.BinaryParameters))
func NewPgReplicationDataSource(config *config.Config) ReplicationDataSource {
return &pgDataSource{cfg: config, dbMutex: sync.Mutex{}}
}

func (ds *pgDataSource) Close() error {
ds.dbMutex.Lock()
defer ds.dbMutex.Unlock()
if ds.db == nil {
return nil
}
err := ds.db.Close()
ds.db = nil
return err
}

func (ds *pgDataSource) getDB() (*sqlx.DB, error) {
ds.dbMutex.Lock()
defer ds.dbMutex.Unlock()
if ds.db != nil {
return ds.db, nil
}
db, err := sqlConnect(fmt.Sprintf("host=%s port=%s database=%s user=%s sslmode=%s binary_parameters=%s", ds.cfg.Host, ds.cfg.Port, ds.cfg.Database, ds.cfg.User, ds.cfg.Sslmode, ds.cfg.BinaryParameters))
if err != nil {
fmt.Println("Error creating a connection pool.")
return nil, err
}
ds.db = db

return &pgDataSource{DB: db, cfg: config}, nil
}

func (ds *pgDataSource) Close() error {
return ds.DB.Close()
return db, nil
}

func (ds *pgDataSource) GetNodeInfo() (*NodeInfo, error) {
Expand Down Expand Up @@ -177,7 +196,12 @@ FROM
FROM pg_catalog.pg_stat_get_wal_senders() w,
pg_catalog.pg_stat_get_activity(pid)) AS ri
`
rows, err := ds.DB.Queryx(sql)
db, dbErr := ds.getDB()
if dbErr != nil {
return nil, dbErr
}

rows, err := db.Queryx(sql)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -249,8 +273,13 @@ FROM
}

func (ds *pgDataSource) getUpstreamConnInfo() (string, error) {
db, dbErr := ds.getDB()
if dbErr != nil {
return "", dbErr
}

stats := PgStatWalReceiver{}
err := ds.DB.Get(&stats, "select * from pg_stat_wal_receiver;")
err := db.Get(&stats, "select * from pg_stat_wal_receiver;")
if err != nil {
return "", err
}
Expand Down Expand Up @@ -295,8 +324,13 @@ func (ds *pgDataSource) getPgCurrentWalLsn(role string) (string, error) {
}
return pgCurrentWalLsn, nil
} else {
db, dbErr := ds.getDB()
if dbErr != nil {
return "", dbErr
}

var pgCurrentWalLsn string
err := ds.DB.Get(&pgCurrentWalLsn, "select pg_current_wal_lsn()")
err := db.Get(&pgCurrentWalLsn, "select pg_current_wal_lsn()")
if err != nil {
return "", err
}
Expand All @@ -305,42 +339,69 @@ func (ds *pgDataSource) getPgCurrentWalLsn(role string) (string, error) {
}

func (ds *pgDataSource) getPgLastWalReplayLsn() (string, error) {
db, dbErr := ds.getDB()
if dbErr != nil {
return "", dbErr
}

pgLastWalLsn := null.String{}
err := ds.DB.Get(&pgLastWalLsn, "select pg_last_wal_replay_lsn()")
err := db.Get(&pgLastWalLsn, "select pg_last_wal_replay_lsn()")
if err != nil {
return "", err
}
return pgLastWalLsn.String, nil
}

func (ds *pgDataSource) getPgWalLsnDiff(currentLsn string, lastLsn string) (int64, error) {
db, dbErr := ds.getDB()
if dbErr != nil {
return 0, dbErr
}

var byteLag int64

query := fmt.Sprintf("select pg_wal_lsn_diff('%s', '%s')", currentLsn, lastLsn)
err := ds.DB.Get(&byteLag, query)

err := db.Get(&byteLag, query)
if err != nil {
return 0, err
}
return byteLag, nil
}

func (ds *pgDataSource) IsInRecovery() (bool, error) {
db, dbErr := ds.getDB()
if dbErr != nil {
return false, dbErr
}

var isInRecovery bool
err := ds.DB.Get(&isInRecovery, "select pg_catalog.pg_is_in_recovery()")

err := db.Get(&isInRecovery, "select pg_catalog.pg_is_in_recovery()")
return isInRecovery, err
}

func (ds *pgDataSource) GetPgStatReplication() ([]*PgStatReplication, error) {
stats := []*PgStatReplication{}
// TODO: Make this only grab required fields.
err := ds.DB.Select(&stats, "select * from pg_stat_replication")
db, dbErr := ds.getDB()
if dbErr != nil {
return nil, dbErr
}

err := db.Select(&stats, "select * from pg_stat_replication")
return stats, err
}

func (ds *pgDataSource) GetPgReplicationSlots() ([]*PgReplicationSlot, error) {
slots := []*PgReplicationSlot{}
// TODO: Make this only grab required fields.
err := ds.DB.Select(&slots, "select * from pg_replication_slots")
db, dbErr := ds.getDB()
if dbErr != nil {
return nil, dbErr
}

err := db.Select(&slots, "select * from pg_replication_slots")
return slots, err
}

Expand Down
6 changes: 2 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,8 @@ func main() {
if err != nil {
panic(err)
}
ds, err := NewPgReplicationDataSource(cfg)
if err != nil {
panic(err)
}

ds := NewPgReplicationDataSource(cfg)
defer ds.Close()

// Wrap the data source in a caching layer to prevent
Expand Down

0 comments on commit 2a84e8f

Please sign in to comment.