diff --git a/callapi/callapi.go b/callapi/callapi.go index 4140047c..df0f5cb8 100644 --- a/callapi/callapi.go +++ b/callapi/callapi.go @@ -130,6 +130,7 @@ type Client interface { // 为了解决processor和server循环依赖设计的接口 type WebSocketServerClienter interface { SendMessage(message map[string]interface{}) error + Close() error } // 根据action订阅handler处理api diff --git a/config/config.go b/config/config.go index bc1bb8cd..9a3f5ee7 100644 --- a/config/config.go +++ b/config/config.go @@ -37,6 +37,9 @@ type Settings struct { MasterID []string `yaml:"master_id,omitempty"` // 如果需要在群权限判断是管理员是,将user_id填入这里,master_id是一个文本数组 EnableWsServer bool `yaml:"enable_ws_server,omitempty"` //正向ws开关 WsServerToken string `yaml:"ws_server_token,omitempty"` //正向ws token + IdentifyFile bool `yaml:"identify_file"` // 域名校验文件 + Crt string `yaml:"crt"` + Key string `yaml:"key"` } // LoadConfig 从文件中加载配置并初始化单例配置 @@ -182,3 +185,39 @@ func GetWsServerToken() string { } return instance.Settings.WsServerToken } + +// 获取identify_file的值 +func GetIdentifyFile() bool { + mu.Lock() + defer mu.Unlock() + + if instance == nil { + log.Println("Warning: instance is nil when trying to get identify file name.") + return false + } + return instance.Settings.IdentifyFile +} + +// 获取crt路径 +func GetCrtPath() string { + mu.Lock() + defer mu.Unlock() + + if instance == nil { + log.Println("Warning: instance is nil when trying to get crt path.") + return "" + } + return instance.Settings.Crt +} + +// 获取key路径 +func GetKeyPath() string { + mu.Lock() + defer mu.Unlock() + + if instance == nil { + log.Println("Warning: instance is nil when trying to get key path.") + return "" + } + return instance.Settings.Key +} diff --git a/config_template.go b/config_template.go index 5d762815..09f0a64a 100644 --- a/config_template.go +++ b/config_template.go @@ -41,4 +41,8 @@ settings: master_id : ["1","2"] #群场景尚未开放获取管理员和列表能力,手动从日志中获取需要设置为管理,的user_id并填入(适用插件有权限判断场景) enable_ws_server: true #是否启用正向ws服务器 监听server_dir:port/ws ws_server_token : "12345" #正向ws的token 不启动正向ws可忽略 + ws_server_token : "12345" #正向ws的token 不启动正向ws可忽略 + identify_file: true #自动生成域名校验文件,在q.qq.com配置信息URL,在server_dir填入自己已备案域名,正确解析到机器人所在服务器ip地址,机器人即可发送链接 + crt: "" #证书路径 从你的域名服务商或云服务商申请签发SSL证书(qq要求SSL) + key: "" #密钥路径 Apache(crt文件、key文件)示例: "C:\\123.key" \需要双写成\\ ` diff --git a/config_template.yml b/config_template.yml index 2006abc0..7f8129e8 100644 --- a/config_template.yml +++ b/config_template.yml @@ -31,4 +31,8 @@ settings: ws_token: ["",""] #连接wss地址时服务器所需的token,如果是ws,可留空,按顺序一一对应 master_id : ["1","2"] #群场景尚未开放获取管理员和列表能力,手动从日志中获取需要设置为管理,的user_id并填入(适用插件有权限判断场景) enable_ws_server: true #是否启用正向ws服务器 监听server_dir:port/ws - ws_server_token : "12345" #正向ws的token 不启动正向ws可忽略 \ No newline at end of file + ws_server_token : "12345" #正向ws的token 不启动正向ws可忽略 + ws_server_token : "12345" #正向ws的token 不启动正向ws可忽略 + identify_file: true #自动生成域名校验文件,在q.qq.com配置信息URL,在server_dir填入自己已备案域名,正确解析到机器人所在服务器ip地址,机器人即可发送链接 + crt: "" #证书路径 从你的域名服务商或云服务商申请签发SSL证书(qq要求SSL) + key: "" #密钥路径 Apache(crt文件、key文件)示例: "C:\\123.key" \需要双写成\\ \ No newline at end of file diff --git a/echo/echo.go b/echo/echo.go index 218e0379..9f820e56 100644 --- a/echo/echo.go +++ b/echo/echo.go @@ -16,13 +16,13 @@ var globalEchoMapping = &EchoMapping{ msgIDMapping: make(map[string]string), } -func (e *EchoMapping) generateKey(appid string, s int64) string { +func (e *EchoMapping) GenerateKey(appid string, s int64) string { return appid + "_" + strconv.FormatInt(s, 10) } // 添加echo对应的类型 func AddMsgType(appid string, s int64, msgType string) { - key := globalEchoMapping.generateKey(appid, s) + key := globalEchoMapping.GenerateKey(appid, s) globalEchoMapping.mu.Lock() defer globalEchoMapping.mu.Unlock() globalEchoMapping.msgTypeMapping[key] = msgType @@ -30,7 +30,7 @@ func AddMsgType(appid string, s int64, msgType string) { // 添加echo对应的messageid func AddMsgID(appid string, s int64, msgID string) { - key := globalEchoMapping.generateKey(appid, s) + key := globalEchoMapping.GenerateKey(appid, s) globalEchoMapping.mu.Lock() defer globalEchoMapping.mu.Unlock() globalEchoMapping.msgIDMapping[key] = msgID diff --git a/gensokyo.db b/gensokyo.db new file mode 100644 index 00000000..5ab9bd33 Binary files /dev/null and b/gensokyo.db differ diff --git a/go.mod b/go.mod index 59a71915..565bbfda 100644 --- a/go.mod +++ b/go.mod @@ -45,4 +45,6 @@ require ( golang.org/x/sys v0.8.0 // indirect golang.org/x/text v0.9.0 // indirect google.golang.org/protobuf v1.30.0 // indirect + mvdan.cc/xurls v1.1.0 + mvdan.cc/xurls/v2 v2.5.0 // indirect ) diff --git a/go.sum b/go.sum index b72202cc..977f04d5 100644 --- a/go.sum +++ b/go.sum @@ -193,4 +193,8 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +mvdan.cc/xurls v1.1.0 h1:kj0j2lonKseISJCiq1Tfk+iTv65dDGCl0rTbanXJGGc= +mvdan.cc/xurls v1.1.0/go.mod h1:TNWuhvo+IqbUCmtUIb/3LJSQdrzel8loVpgFm0HikbI= +mvdan.cc/xurls/v2 v2.5.0 h1:lyBNOm8Wo71UknhUs4QTFUNNMyxy2JEIaKKo0RWOh+8= +mvdan.cc/xurls/v2 v2.5.0/go.mod h1:yQgaGQ1rFtJUzkmKiHYSSfuQxqfYmd//X6PxvholpeE= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/handlers/get_group_info.go b/handlers/get_group_info.go index 4e89d21d..86b3d383 100644 --- a/handlers/get_group_info.go +++ b/handlers/get_group_info.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "log" - "strconv" "github.com/hoshinonyaruko/gensokyo/callapi" "github.com/hoshinonyaruko/gensokyo/idmap" @@ -28,11 +27,13 @@ type OnebotGroupInfo struct { } func ConvertGuildToGroupInfo(guild *dto.Guild, GroupId string) *OnebotGroupInfo { - groupidstr, err := strconv.ParseInt(GroupId, 10, 64) + // 使用idmap.StoreIDv2映射GroupId到一个int64的值 + groupid64, err := idmap.StoreIDv2(GroupId) if err != nil { - log.Printf("groupidstr: %v", err) + log.Printf("Error storing GroupID: %v", err) return nil } + ts, err := guild.JoinedAt.Time() if err != nil { log.Printf("转换JoinedAt失败: %v", err) @@ -41,7 +42,7 @@ func ConvertGuildToGroupInfo(guild *dto.Guild, GroupId string) *OnebotGroupInfo groupCreateTime := uint32(ts.Unix()) return &OnebotGroupInfo{ - GroupID: groupidstr, + GroupID: groupid64, GroupName: guild.Name, GroupMemo: guild.Desc, GroupCreateTime: groupCreateTime, diff --git a/handlers/message_parser.go b/handlers/message_parser.go index 00609383..074cf026 100644 --- a/handlers/message_parser.go +++ b/handlers/message_parser.go @@ -10,7 +10,9 @@ import ( "github.com/hoshinonyaruko/gensokyo/callapi" "github.com/hoshinonyaruko/gensokyo/idmap" + "github.com/hoshinonyaruko/gensokyo/url" "github.com/tencent-connect/botgo/dto" + "mvdan.cc/xurls" //xurls是一个从文本提取url的库 适用于多种场景 ) var BotID string @@ -58,6 +60,7 @@ func SendResponse(client callapi.Client, err error, message *callapi.ActionMessa return nil } +// 信息处理函数 func parseMessageContent(paramsMessage callapi.ParamsContent) (string, map[string][]string) { messageText := "" @@ -156,11 +159,11 @@ func parseMessageContent(paramsMessage callapi.ParamsContent) (string, map[strin return messageText, foundItems } +// at处理和链接处理 func transformMessageText(messageText string) string { // 使用正则表达式来查找所有[CQ:at,qq=数字]的模式 re := regexp.MustCompile(`\[CQ:at,qq=(\d+)\]`) - // 使用正则表达式来替换找到的模式为<@!数字> - return re.ReplaceAllStringFunc(messageText, func(m string) string { + messageText = re.ReplaceAllStringFunc(messageText, func(m string) string { submatches := re.FindStringSubmatch(m) if len(submatches) > 1 { realUserID, err := idmap.RetrieveRowByIDv2(submatches[1]) @@ -172,6 +175,14 @@ func transformMessageText(messageText string) string { } return m }) + + // 使用xurls来查找和替换所有的URL + messageText = xurls.Relaxed.ReplaceAllStringFunc(messageText, func(originalURL string) string { + shortURL := url.GenerateShortURL(originalURL) + // 使用getBaseURL函数来获取baseUrl并与shortURL组合 + return url.GetBaseURL() + "/url/" + shortURL + }) + return messageText } // 处理at和其他定形文到onebotv11格式(cq码) diff --git a/handlers/send_guild_channel_msg.go b/handlers/send_guild_channel_msg.go index fc61b4d7..5b6f46ab 100644 --- a/handlers/send_guild_channel_msg.go +++ b/handlers/send_guild_channel_msg.go @@ -56,6 +56,11 @@ func handleSendGuildChannelMsg(client callapi.Client, api openapi.OpenAPI, apiv2 messageID = echo.GetMsgIDByKey(echoStr) log.Println("echo取频道发信息对应的message_id:", messageID) } + // 如果messageID为空,通过函数获取 + if messageID == "" { + messageID = GetMessageIDByUseridOrGroupid(config.GetAppIDStr(), channelID) + log.Println("通过GetMessageIDByUserid函数获取的message_id:", messageID) + } log.Println("频道发信息messageText:", messageText) log.Println("foundItems:", foundItems) var err error diff --git a/handlers/send_msg.go b/handlers/send_msg.go index 734a4673..38716dbb 100644 --- a/handlers/send_msg.go +++ b/handlers/send_msg.go @@ -169,6 +169,11 @@ func handleSendMsg(client callapi.Client, api openapi.OpenAPI, apiv2 openapi.Ope messageID = echo.GetMsgIDByKey(echoStr) log.Println("echo取私聊发信息对应的message_id:", messageID) } + // 如果messageID为空,通过函数获取 + if messageID == "" { + messageID = GetMessageIDByUseridOrGroupid(config.GetAppIDStr(), UserID) + log.Println("通过GetMessageIDByUserid函数获取的message_id:", messageID) + } log.Println("私聊发信息messageText:", messageText) log.Println("foundItems:", foundItems) diff --git a/handlers/send_private_msg.go b/handlers/send_private_msg.go index 91e32218..de98dda0 100644 --- a/handlers/send_private_msg.go +++ b/handlers/send_private_msg.go @@ -47,6 +47,12 @@ func handleSendPrivateMsg(client callapi.Client, api openapi.OpenAPI, apiv2 open messageID = echo.GetMsgIDByKey(echoStr) log.Println("echo取私聊发信息对应的message_id:", messageID) } + // 如果messageID仍然为空,尝试使用config.GetAppID和UserID的组合来获取messageID + // 如果messageID为空,通过函数获取 + if messageID == "" { + messageID = GetMessageIDByUseridOrGroupid(config.GetAppIDStr(), UserID) + log.Println("通过GetMessageIDByUserid函数获取的message_id:", messageID) + } log.Println("私聊发信息messageText:", messageText) log.Println("foundItems:", foundItems) @@ -66,6 +72,8 @@ func handleSendPrivateMsg(client callapi.Client, api openapi.OpenAPI, apiv2 open if err != nil { log.Printf("发送文本私聊信息失败: %v", err) } + //发送成功回执 + SendResponse(client, err, &message) } // 遍历 foundItems 并发送每种信息 @@ -85,6 +93,8 @@ func handleSendPrivateMsg(client callapi.Client, api openapi.OpenAPI, apiv2 open if err != nil { log.Printf("发送 %s 私聊信息失败: %v", key, err) } + //发送成功回执 + SendResponse(client, err, &message) } case "guild_private": //当收到发私信调用 并且来源是频道 diff --git a/idmap/cMapping.go b/idmap/cMapping.go deleted file mode 100644 index 7975f940..00000000 --- a/idmap/cMapping.go +++ /dev/null @@ -1,40 +0,0 @@ -// 访问idmaps -package idmap - -import ( - "encoding/json" - "fmt" - "net/http" -) - -type CompatibilityMapping struct{} - -// SetID 对外部API进行写操作,返回映射的值 -func (c *CompatibilityMapping) SetID(id string) (int, error) { - return c.getIDByType(id, "1") -} - -// GetOriginalID 使用映射值获取原始值 -func (c *CompatibilityMapping) GetOriginalID(mappedID string) (int, error) { - return c.getIDByType(mappedID, "2") -} - -func (c *CompatibilityMapping) getIDByType(id, typeVal string) (int, error) { - resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:15817/getid?id=%s&type=%s", id, typeVal)) - if err != nil { - return 0, err - } - defer resp.Body.Close() - - var result map[string]int - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return 0, err - } - - value, ok := result["row"] - if !ok { - return 0, fmt.Errorf("row not found in the response") - } - - return value, nil -} diff --git a/main.go b/main.go index 727f508e..969e9d9b 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log" + "net/http" "os" "os/signal" "syscall" @@ -16,6 +17,7 @@ import ( "github.com/hoshinonyaruko/gensokyo/handlers" "github.com/hoshinonyaruko/gensokyo/idmap" "github.com/hoshinonyaruko/gensokyo/server" + "github.com/hoshinonyaruko/gensokyo/url" "github.com/hoshinonyaruko/gensokyo/wsclient" "github.com/gin-gonic/gin" @@ -169,13 +171,49 @@ func main() { //是否启动服务器 shouldStartServer := !conf.Settings.Lotus || conf.Settings.EnableWsServer //如果连接到其他gensokyo,则不需要启动服务器 + var httpServer *http.Server if shouldStartServer { r := gin.Default() r.GET("/getid", server.GetIDHandler) r.POST("/uploadpic", server.UploadBase64ImageHandler(rateLimiter)) r.Static("/channel_temp", "./channel_temp") r.GET("/ws", server.WsHandlerWithDependencies(api, apiV2, p)) - r.Run("0.0.0.0:" + conf.Settings.Port) // 监听0.0.0.0地址的Port端口 + r.POST("/url", url.CreateShortURLHandler) + r.GET("/url/:shortURL", url.RedirectFromShortURLHandler) + if config.GetIdentifyFile() { + appIDStr := config.GetAppIDStr() + fileName := appIDStr + ".json" + r.GET("/"+fileName, func(c *gin.Context) { + content := fmt.Sprintf(`{"bot_appid":%d}`, config.GetAppID()) + c.Header("Content-Type", "application/json") + c.String(200, content) + }) + } + // 创建一个http.Server实例 + httpServer = &http.Server{ + Addr: "0.0.0.0:" + conf.Settings.Port, + Handler: r, + } + // 在一个新的goroutine中启动Gin服务器 + go func() { + if conf.Settings.Port == "443" { + // 使用HTTPS + crtPath := config.GetCrtPath() + keyPath := config.GetKeyPath() + if crtPath == "" || keyPath == "" { + log.Fatalf("crt or key path is missing for HTTPS") + return + } + if err := httpServer.ListenAndServeTLS(crtPath, keyPath); err != nil && err != http.ErrServerClosed { + log.Fatalf("listen (HTTPS): %s\n", err) + } + } else { + // 使用HTTP + if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("listen: %s\n", err) + } + } + }() } // 使用通道来等待信号 @@ -193,6 +231,24 @@ func main() { fmt.Printf("Error closing WebSocket connection: %v\n", err) } } + + // 关闭BoltDB数据库 + url.CloseDB() + idmap.CloseDB() + + // 在关闭WebSocket客户端之前 + for _, wsClient := range p.WsServerClients { + if err := wsClient.Close(); err != nil { + log.Printf("Error closing WebSocket server client: %v\n", err) + } + } + + // 使用一个5秒的超时优雅地关闭Gin服务器 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := httpServer.Shutdown(ctx); err != nil { + log.Fatal("Server forced to shutdown:", err) + } } // ReadyHandler 自定义 ReadyHandler 感知连接成功事件 diff --git a/server/wsserver.go b/server/wsserver.go index fd714c72..55b89bd3 100644 --- a/server/wsserver.go +++ b/server/wsserver.go @@ -105,6 +105,17 @@ func wsHandler(api openapi.OpenAPI, apiV2 openapi.OpenAPI, p *Processor.Processo log.Printf("Error sending connection success message: %v\n", err) } + // 在defer语句之前运行 + defer func() { + // 移除客户端从WsServerClients + for i, wsClient := range p.WsServerClients { + if wsClient == client { + p.WsServerClients = append(p.WsServerClients[:i], p.WsServerClients[i+1:]...) + break + } + } + }() + //退出时候的清理 defer conn.Close() for { @@ -142,3 +153,7 @@ func (c *WebSocketServerClient) SendMessage(message map[string]interface{}) erro } return c.Conn.WriteMessage(websocket.TextMessage, msgBytes) } + +func (client *WebSocketServerClient) Close() error { + return client.Conn.Close() +} diff --git a/url/shorturl.go b/url/shorturl.go new file mode 100644 index 00000000..83105a50 --- /dev/null +++ b/url/shorturl.go @@ -0,0 +1,349 @@ +package url + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "log" + "math/rand" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/boltdb/bolt" + "github.com/gin-gonic/gin" + "github.com/hoshinonyaruko/gensokyo/config" +) + +const ( + bucketName = "shortURLs" +) + +var ( + db *bolt.DB +) + +const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +const length = 6 + +func generateRandomString() string { + rand.Seed(time.Now().UnixNano()) + result := make([]byte, length) + for i := range result { + result[i] = charset[rand.Intn(len(charset))] + } + return string(result) +} + +func generateHashedString(url string) string { + hash := sha256.Sum256([]byte(url)) + return hex.EncodeToString(hash[:3]) // 取前3个字节,得到6个字符的16进制表示 +} + +func init() { + var err error + db, err = bolt.Open("gensokyo.db", 0600, nil) + if err != nil { + panic(err) + } + + // Ensure bucket exists + err = db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucketIfNotExists([]byte(bucketName)) + if err != nil { + return fmt.Errorf("failed to create or get the bucket: %v", err) + } + return nil + }) + if err != nil { + panic(fmt.Sprintf("Error initializing the database: %v", err)) + } +} + +// 验证链接是否合法 +func isValidURL(toTest string) bool { + parsedURL, err := url.ParseRequestURI(toTest) + if err != nil { + return false + } + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return false + } + + // 阻止localhost和本地IP地址 + host := parsedURL.Hostname() + localHostnames := []string{"localhost", "127.0.0.1", "::1"} + for _, localHost := range localHostnames { + if host == localHost { + return false + } + } + + // 检查是否是私有IP地址 + return !isPrivateIP(host) +} + +// 检查是否是私有IP地址 +func isPrivateIP(ipStr string) bool { + privateIPBlocks := []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + } + ip := net.ParseIP(ipStr) + for _, block := range privateIPBlocks { + _, ipnet, err := net.ParseCIDR(block) + if err != nil { + continue + } + if ipnet.Contains(ip) { + return true + } + } + return false +} + +// 检查和解码可能的Base64编码的URL +func decodeBase64IfNeeded(input string) string { + if len(input)%4 == 0 { // 一个简单的检查来看它是否可能是Base64 + decoded, err := base64.StdEncoding.DecodeString(input) + if err == nil { + return string(decoded) + } + } + return input +} + +// 生成短链接 +func GenerateShortURL(longURL string) string { + if config.GetLotusValue() { + serverDir := config.GetServer_dir() + portValue := config.GetPortValue() + url := fmt.Sprintf("http://%s:%s/url", serverDir, portValue) + + payload := map[string]string{"longURL": longURL} + jsonPayload, err := json.Marshal(payload) + if err != nil { + log.Printf("Error marshaling payload: %v", err) + return "" + } + + resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonPayload)) + if err != nil { + log.Printf("Error while generating short URL: %v", err) + return "" + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + log.Printf("Received non-200 status code: %d from server: %v", resp.StatusCode, url) + return "" + } + + var response map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&response) + if err != nil { + log.Println("Error decoding response") + return "" + } + + shortURL, ok := response["shortURL"].(string) + if !ok { + log.Println("shortURL not found or not a string in the response") + return "" + } + + return shortURL + + } else { + shortURL := generateHashedString(longURL) + + exists, err := existsInDB(shortURL) + if err != nil { + log.Printf("Error checking if shortURL exists in DB: %v", err) + return "" // 如果有错误, 返回空的短链接 + } + if exists { + for { + shortURL = generateRandomString() + exists, err := existsInDB(shortURL) + if err != nil { + log.Printf("Error checking if shortURL exists in DB: %v", err) + return "" // 如果有错误, 返回空的短链接 + } + if !exists { + break + } + } + } + + // 存储短URL和对应的长URL + err = storeURL(shortURL, longURL) + if err != nil { + log.Printf("Error storing URL in DB: %v", err) + return "" + } + + return shortURL + } +} + +func existsInDB(shortURL string) (bool, error) { + exists := false + err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(bucketName)) + v := b.Get([]byte(shortURL)) + if v != nil { + exists = true + } + return nil + }) + if err != nil { + log.Printf("Error accessing the database: %v", err) // 记录错误 + return false, err + } + return exists, nil +} + +// 从数据库获取短链接 +func getLongURLFromDB(shortURL string) (string, error) { + if config.GetLotusValue() { + serverDir := config.GetServer_dir() + portValue := config.GetPortValue() + url := fmt.Sprintf("http://%s:%s/url/%s", serverDir, portValue, shortURL) + + resp, err := http.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + log.Printf("Received non-200 status code: %d while fetching long URL from server: %v", resp.StatusCode, url) + return "", fmt.Errorf("error fetching long URL from remote server with status code: %d", resp.StatusCode) + } + + var response map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&response) + if err != nil { + return "", fmt.Errorf("error decoding response from server") + } + return response["longURL"].(string), nil + } else { + var longURL string + err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(bucketName)) + v := b.Get([]byte(shortURL)) + if v == nil { + return fmt.Errorf("URL not found") + } + longURL = string(v) + return nil + }) + return longURL, err + } +} + +// storeURL 存储长URL和对应的短URL +func storeURL(shortURL, longURL string) error { + return db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(bucketName)) + return b.Put([]byte(shortURL), []byte(longURL)) + }) +} + +// 安全性检查 +func isMalicious(decoded string) bool { + lowerDecoded := strings.ToLower(decoded) + + // 检查javascript协议,用于防止XSS + if strings.HasPrefix(lowerDecoded, "javascript:") { + return true + } + + // 检查data协议,可能被用于各种攻击 + if strings.HasPrefix(lowerDecoded, "data:") { + return true + } + + // 检查常见的HTML标签,这可能用于指示XSS攻击 + for _, tag := range []string{"<script", "<img", "<iframe", "<link", "<style"} { + if strings.Contains(lowerDecoded, tag) { + return true + } + } + + return false +} + +// 短链接服务handler +func CreateShortURLHandler(c *gin.Context) { + rawURL := c.PostForm("url") + longURL := decodeBase64IfNeeded(rawURL) + + if longURL == "" || isMalicious(longURL) || !isValidURL(longURL) { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid URL"}) + return + } + + // Generate short URL + shortURL := GenerateShortURL(longURL) + + // Construct baseUrl + serverDir := config.GetServer_dir() + portValue := config.GetPortValue() + protocol := "http" + if portValue == "443" { + protocol = "https" + } + baseUrl := protocol + "://" + serverDir + if portValue != "80" && portValue != "443" && portValue != "" { + baseUrl += ":" + portValue + } + + c.JSON(http.StatusOK, gin.H{"shortURL": baseUrl + "/url/" + shortURL}) +} + +// 短链接baseurl +func GetBaseURL() string { + serverDir := config.GetServer_dir() + portValue := config.GetPortValue() + protocol := "http" + if portValue == "443" { + protocol = "https" + } + baseUrl := protocol + "://" + serverDir + if portValue != "80" && portValue != "443" && portValue != "" { + baseUrl += ":" + portValue + } + return baseUrl +} + +// RedirectFromShortURLHandler +func RedirectFromShortURLHandler(c *gin.Context) { + shortURL := c.Param("shortURL") + + // Fetch from Bolt DB + longURL, err := getLongURLFromDB(shortURL) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "URL not found"}) + return + } + + // Ensure longURL has a scheme (http or https) + if !strings.HasPrefix(longURL, "http://") && !strings.HasPrefix(longURL, "https://") { + // Add default scheme if missing + longURL = "http://" + longURL + } + + c.Redirect(http.StatusMovedPermanently, longURL) +} + +func CloseDB() { + db.Close() +}