-
-
Notifications
You must be signed in to change notification settings - Fork 516
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
442 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,271 @@ | ||
package database | ||
|
||
import ( | ||
"bytes" | ||
"fmt" | ||
"io" | ||
"mime/multipart" | ||
"os" | ||
"path/filepath" | ||
"s-ui/cmd/migration" | ||
"s-ui/config" | ||
"s-ui/database/model" | ||
"s-ui/logger" | ||
"s-ui/util/common" | ||
"strings" | ||
"syscall" | ||
"time" | ||
|
||
"gorm.io/driver/sqlite" | ||
"gorm.io/gorm" | ||
) | ||
|
||
func GetDb(exclude string) ([]byte, error) { | ||
exclude_changes, exclude_stats := false, false | ||
for _, table := range strings.Split(exclude, ",") { | ||
if table == "changes" { | ||
exclude_changes = true | ||
} else if table == "stats" { | ||
exclude_stats = true | ||
} | ||
} | ||
|
||
dir, err := filepath.Abs(filepath.Dir(os.Args[0])) | ||
if err != nil { | ||
return nil, err | ||
} | ||
dbPath := dir + config.GetName() + "_" + time.Now().Format("20060102-200203") + ".db" | ||
|
||
backupDb, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
err = backupDb.AutoMigrate( | ||
&model.Setting{}, | ||
&model.Tls{}, | ||
&model.Inbound{}, | ||
&model.Outbound{}, | ||
&model.Endpoint{}, | ||
&model.User{}, | ||
&model.Stats{}, | ||
&model.Client{}, | ||
&model.Changes{}, | ||
) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
var settings []model.Setting | ||
var tls []model.Tls | ||
var inbound []model.Inbound | ||
var outbound []model.Outbound | ||
var endpoint []model.Endpoint | ||
var users []model.User | ||
var clients []model.Client | ||
var stats []model.Stats | ||
var changes []model.Changes | ||
|
||
// Perform scans and handle errors | ||
if err := db.Model(&model.Setting{}).Scan(&settings).Error; err != nil { | ||
return nil, err | ||
} | ||
if err := db.Model(&model.Tls{}).Scan(&tls).Error; err != nil { | ||
return nil, err | ||
} | ||
if err := db.Model(&model.Inbound{}).Scan(&inbound).Error; err != nil { | ||
return nil, err | ||
} | ||
if err := db.Model(&model.Outbound{}).Scan(&outbound).Error; err != nil { | ||
return nil, err | ||
} | ||
if err := db.Model(&model.Endpoint{}).Scan(&endpoint).Error; err != nil { | ||
return nil, err | ||
} | ||
if err := db.Model(&model.User{}).Scan(&users).Error; err != nil { | ||
return nil, err | ||
} | ||
if err := db.Model(&model.Client{}).Scan(&clients).Error; err != nil { | ||
return nil, err | ||
} | ||
|
||
// Save each model | ||
for _, mdl := range []interface{}{settings, tls, inbound, outbound, endpoint, users, clients} { | ||
if err := backupDb.Save(mdl).Error; err != nil { | ||
return nil, err | ||
} | ||
} | ||
|
||
if !exclude_stats { | ||
if err := db.Model(&model.Stats{}).Scan(&stats).Error; err != nil { | ||
return nil, err | ||
} | ||
if err := backupDb.Save(stats).Error; err != nil { | ||
return nil, err | ||
} | ||
} | ||
if !exclude_changes { | ||
if err := db.Model(&model.Changes{}).Scan(&changes).Error; err != nil { | ||
return nil, err | ||
} | ||
if err := backupDb.Save(changes).Error; err != nil { | ||
return nil, err | ||
} | ||
} | ||
|
||
// Update WAL | ||
err = backupDb.Exec("PRAGMA wal_checkpoint;").Error | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
bdb, _ := backupDb.DB() | ||
bdb.Close() | ||
|
||
// Open the file for reading | ||
file, err := os.Open(dbPath) | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer file.Close() | ||
defer os.Remove(dbPath) | ||
|
||
// Read the file contents | ||
fileContents, err := io.ReadAll(file) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return fileContents, nil | ||
} | ||
|
||
func ImportDB(file multipart.File) error { | ||
// Check if the file is a SQLite database | ||
isValidDb, err := IsSQLiteDB(file) | ||
if err != nil { | ||
return common.NewErrorf("Error checking db file format: %v", err) | ||
} | ||
if !isValidDb { | ||
return common.NewError("Invalid db file format") | ||
} | ||
|
||
// Reset the file reader to the beginning | ||
_, err = file.Seek(0, 0) | ||
if err != nil { | ||
return common.NewErrorf("Error resetting file reader: %v", err) | ||
} | ||
|
||
// Save the file as temporary file | ||
tempPath := fmt.Sprintf("%s.temp", config.GetDBPath()) | ||
// Remove the existing fallback file (if any) before creating one | ||
_, err = os.Stat(tempPath) | ||
if err == nil { | ||
errRemove := os.Remove(tempPath) | ||
if errRemove != nil { | ||
return common.NewErrorf("Error removing existing temporary db file: %v", errRemove) | ||
} | ||
} | ||
// Create the temporary file | ||
tempFile, err := os.Create(tempPath) | ||
if err != nil { | ||
return common.NewErrorf("Error creating temporary db file: %v", err) | ||
} | ||
defer tempFile.Close() | ||
|
||
// Remove temp file before returning | ||
defer os.Remove(tempPath) | ||
|
||
// Close old DB | ||
old_db, _ := db.DB() | ||
old_db.Close() | ||
|
||
// Save uploaded file to temporary file | ||
_, err = io.Copy(tempFile, file) | ||
if err != nil { | ||
return common.NewErrorf("Error saving db: %v", err) | ||
} | ||
|
||
// Check if we can init db or not | ||
newDb, err := gorm.Open(sqlite.Open(tempPath), &gorm.Config{}) | ||
if err != nil { | ||
return common.NewErrorf("Error checking db: %v", err) | ||
} | ||
newDb_db, _ := newDb.DB() | ||
newDb_db.Close() | ||
|
||
// Backup the current database for fallback | ||
fallbackPath := fmt.Sprintf("%s.backup", config.GetDBPath()) | ||
// Remove the existing fallback file (if any) | ||
_, err = os.Stat(fallbackPath) | ||
if err == nil { | ||
errRemove := os.Remove(fallbackPath) | ||
if errRemove != nil { | ||
return common.NewErrorf("Error removing existing fallback db file: %v", errRemove) | ||
} | ||
} | ||
// Move the current database to the fallback location | ||
err = os.Rename(config.GetDBPath(), fallbackPath) | ||
if err != nil { | ||
return common.NewErrorf("Error backing up temporary db file: %v", err) | ||
} | ||
|
||
// Remove the temporary file before returning | ||
defer os.Remove(fallbackPath) | ||
|
||
// Move temp to DB path | ||
err = os.Rename(tempPath, config.GetDBPath()) | ||
if err != nil { | ||
errRename := os.Rename(fallbackPath, config.GetDBPath()) | ||
if errRename != nil { | ||
return common.NewErrorf("Error moving db file and restoring fallback: %v", errRename) | ||
} | ||
return common.NewErrorf("Error moving db file: %v", err) | ||
} | ||
|
||
// Migrate DB | ||
migration.MigrateDb() | ||
err = InitDB(config.GetDBPath()) | ||
if err != nil { | ||
errRename := os.Rename(fallbackPath, config.GetDBPath()) | ||
if errRename != nil { | ||
return common.NewErrorf("Error migrating db and restoring fallback: %v", errRename) | ||
} | ||
return common.NewErrorf("Error migrating db: %v", err) | ||
} | ||
|
||
// Restart app | ||
err = SendSighup() | ||
if err != nil { | ||
return common.NewErrorf("Error restarting app: %v", err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func IsSQLiteDB(file io.Reader) (bool, error) { | ||
signature := []byte("SQLite format 3\x00") | ||
buf := make([]byte, len(signature)) | ||
_, err := file.Read(buf) | ||
if err != nil { | ||
return false, err | ||
} | ||
return bytes.Equal(buf, signature), nil | ||
} | ||
|
||
func SendSighup() error { | ||
// Get the current process | ||
process, err := os.FindProcess(os.Getpid()) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
// Send SIGHUP to the current process | ||
go func() { | ||
time.Sleep(3 * time.Second) | ||
err := process.Signal(syscall.SIGHUP) | ||
if err != nil { | ||
logger.Error("send signal SIGHUP failed:", err) | ||
} | ||
}() | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.