diff --git a/Dockerfile.proxy b/Dockerfile.proxy new file mode 100644 index 0000000000..c456908dd1 --- /dev/null +++ b/Dockerfile.proxy @@ -0,0 +1,31 @@ +ARG ARCH +FROM ${ARCH}ossrs/srs:ubuntu20 AS build + +COPY ./proxy /proxy + +WORKDIR /proxy + +RUN make clean && make + +############################################################ +# dist +############################################################ +FROM ${ARCH}ubuntu:focal AS dist + +WORKDIR /proxy + +COPY --from=build /proxy/srs-proxy /proxy/ +COPY ./trunk/research /proxy/static + +ENV PROXY_STATIC_FILES="/proxy/static" +ENV PROXY_LOAD_BALANCER_TYPE="memory" +ENV PROXY_RTMP_SERVER=1935 +ENV PROXY_HTTP_SERVER=8080 +ENV PROXY_HTTP_API=1985 +ENV PROXY_WEBRTC_SERVER=8000 +ENV PROXY_SRT_SERVER=10080 +ENV PROXY_SYSTEM_API=12025 + +EXPOSE 1935 8080 1985 12025 8000/udp 10080/udp + +CMD ["./srs-proxy"] \ No newline at end of file diff --git a/proxy/api.go b/proxy/api.go index 04baa92526..f7f4681a1c 100644 --- a/proxy/api.go +++ b/proxy/api.go @@ -82,7 +82,7 @@ func (v *srsHTTPAPIServer) Run(ctx context.Context) error { logger.Df(ctx, "Handle /rtc/v1/whip/ by %v", addr) mux.HandleFunc("/rtc/v1/whip/", func(w http.ResponseWriter, r *http.Request) { if err := v.rtc.HandleApiForWHIP(ctx, w, r); err != nil { - apiError(ctx, w, r, err) + apiError(ctx, w, r, err, http.StatusInternalServerError) } }) @@ -90,10 +90,15 @@ func (v *srsHTTPAPIServer) Run(ctx context.Context) error { logger.Df(ctx, "Handle /rtc/v1/whep/ by %v", addr) mux.HandleFunc("/rtc/v1/whep/", func(w http.ResponseWriter, r *http.Request) { if err := v.rtc.HandleApiForWHEP(ctx, w, r); err != nil { - apiError(ctx, w, r, err) + apiError(ctx, w, r, err, http.StatusInternalServerError) } }) + logger.Df(ctx, "Proxy /api/ to srs") + mux.HandleFunc("/api/", func(w http.ResponseWriter, r *http.Request) { + srsLoadBalancer.ProxyHTTPAPI(ctx, w, r) + }) + // Run HTTP API server. v.wg.Add(1) go func() { @@ -239,7 +244,7 @@ func (v *systemAPI) Run(ctx context.Context) error { logger.Df(ctx, "Register SRS media server, %+v", server) return nil }(); err != nil { - apiError(ctx, w, r, err) + apiError(ctx, w, r, err, http.StatusInternalServerError) } type Response struct { diff --git a/proxy/http.go b/proxy/http.go index f02af02a30..92f5942f5f 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -198,7 +198,7 @@ func (v *HTTPFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request) ctx := logger.WithContext(v.ctx) if err := v.serve(ctx, w, r); err != nil { - apiError(ctx, w, r, err) + apiError(ctx, w, r, err, http.StatusInternalServerError) } else { logger.Df(ctx, "HTTP client done") } @@ -318,7 +318,7 @@ func (v *HLSPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() if err := v.serve(v.ctx, w, r); err != nil { - apiError(v.ctx, w, r, err) + apiError(v.ctx, w, r, err, http.StatusInternalServerError) } else { logger.Df(v.ctx, "HLS client %v for %v with %v done", v.SRSProxyBackendHLSID, v.StreamURL, r.URL.Path) diff --git a/proxy/srs-api-proxy.go b/proxy/srs-api-proxy.go new file mode 100644 index 0000000000..0968a7e4b8 --- /dev/null +++ b/proxy/srs-api-proxy.go @@ -0,0 +1,308 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "encoding/json" + "io" + "net/http" + "srs-proxy/errors" + "srs-proxy/logger" + "strings" +) + +type SrsClient struct { + Id string `json:"id"` + Vhost string `json:"vhost"` + Stream string `json:"stream"` + Ip string `json:"ip"` + PageUrl string `json:"pageUrl"` + SwfUrl string `json:"swfUrl"` + TcUrl string `json:"tcUrl"` + Url string `json:"url"` + Name string `json:"name"` + Type string `json:"type"` + Publish bool `json:"publish"` + Alive float32 `json:"alive"` + SendBytes int `json:"send_bytes"` + RecvBytes int `json:"recv_bytes"` +} + +type SrsApiCodeResponse struct { + Code int `json:"code"` +} + +type SrsAPICommonResponse struct { + SrsApiCodeResponse + Server string `json:"server"` + Service string `json:"service"` + Pid string `json:"pid"` +} + +type SrsClientResponse struct { + SrsAPICommonResponse + Client SrsClient `json:"client"` +} + +type SrsClientsResponse struct { + SrsAPICommonResponse + Clients []SrsClient `json:"clients"` +} + +type SrsKbps struct { + Recv_30s uint32 `json:"recv_30s"` + Send_30s uint32 `json:"send_30s"` +} + +type SrsPublish struct { + Active bool `json:"active"` + Cid string `json:"cid"` +} + +type SrsVideo struct { + Codec string `json:"codec"` + Profile string `json:"profile"` + Level string `json:"level"` + Width uint32 `json:"width"` + Height uint32 `json:"height"` +} + +type SrsAudio struct { + Codec string `json:"codec"` + Sample_rate uint32 `json:"sample_rate"` + Channel uint8 `json:"channel"` + Profile string `json:"profile"` +} + +type SrsStream struct { + Id string `json:"id"` + Name string `json:"name"` + Vhost string `json:"vhost"` + App string `json:"app"` + TcUrl string `json:"tcUrl"` + Url string `json:"url"` + Live_ms uint64 `json:"live_ms"` + Clients uint32 `json:"clients"` + Frames uint32 `json:"frames"` + Send_bytes uint32 `json:"send_bytes"` + Recv_bytes uint32 `json:"recv_bytes"` + Kbps SrsKbps `json:"kbps"` + Publish SrsPublish `json:"publish"` + Video SrsVideo `json:"video"` + Audio SrsAudio `json:"audio"` +} + +type SrsStreamResponse struct { + SrsAPICommonResponse + Stream SrsStream `json:"stream"` +} + +type SrsStreamsResponse struct { + SrsAPICommonResponse + Streams []SrsStream `json:"streams"` +} + +type SrsHTTPApi struct { + Enabled bool `json:"enabled"` + Listen string `json:"listen"` + Crossdomain bool `json:"crossdomain"` + Raw_api SrsRawApi `json:"raw_api"` +} + +type SrsRawApi struct { + Enabled bool `json:"enabled"` + Allow_reload bool `json:"allow_reload"` + Allow_query bool `json:"allow_query"` + Allow_update bool `json:"allow_update"` +} + +type SrsRawResponse struct { + SrsApiCodeResponse + Http_api SrsHTTPApi `json:"http_api"` +} + +type SrsRawReloadResponse struct { + SrsApiCodeResponse +} + +type SrsRawReloadFetchData struct { + Err int `json:"err"` + Msg string `json:"msg"` + State int `json:"state"` + Rid string `json:"rid"` +} + +type SrsRawReloadFetchResponse struct { + SrsApiCodeResponse + Data SrsRawReloadFetchData `json:"data"` +} + +type SrsApiProxy struct { +} + +func (v *SrsApiProxy) proxySrsAPI(ctx context.Context, servers []*SRSServer, w http.ResponseWriter, r *http.Request) error { + if strings.HasPrefix(r.URL.Path, "/api/v1/clients") { + return proxySrsClientsAPI(ctx, servers, w, r) + } else if strings.HasPrefix(r.URL.Path, "/api/v1/streams") { + return proxySrsStreamsAPI(ctx, servers, w, r) + } else if strings.HasPrefix(r.URL.Path, "/api/v1/raw") { + return proxySrsRawAPI(ctx, servers, w, r) + } + return nil +} + +// handle srs clients api /api/v1/clients +func proxySrsClientsAPI(ctx context.Context, servers []*SRSServer, w http.ResponseWriter, r *http.Request) error { + defer r.Body.Close() + + clientId := "" + if strings.HasPrefix(r.URL.Path, "/api/v1/clients/") { + clientId = r.URL.Path[len("/api/v1/clients/"):] + } + logger.Df(ctx, "%v %v clientId=%v", r.Method, r.URL.Path, clientId) + + body, err := io.ReadAll(r.Body) + if err != nil { + apiError(ctx, w, r, err, http.StatusInternalServerError) + return errors.Wrapf(err, "read request body err") + } + + switch r.Method { + case http.MethodDelete: + for _, server := range servers { + if ret, err := server.ApiRequest(ctx, r, body); err == nil { + logger.Df(ctx, "response %v", string(ret)) + var res SrsApiCodeResponse + if err := json.Unmarshal(ret, &res); err == nil && res.Code == 0 { + apiResponse(ctx, w, r, res) + return nil + } + } + } + + err := errors.Errorf("clientId %v not found in server", clientId) + apiError(ctx, w, r, err, http.StatusNotFound) + return err + case http.MethodGet: + if len(clientId) > 0 { + for _, server := range servers { + var client SrsClientResponse + if ret, err := server.ApiRequest(ctx, r, body); err == nil { + if err := json.Unmarshal(ret, &client); err == nil && client.Code == 0 { + apiResponse(ctx, w, r, client) + return nil + } + } + } + } else { // get all clients + var clients SrsClientsResponse + for _, server := range servers { + var res SrsClientsResponse + if ret, err := server.ApiRequest(ctx, r, body); err == nil { + if err := json.Unmarshal(ret, &res); err == nil && res.Code == 0 { + clients.Clients = append(clients.Clients, res.Clients...) + } + } + } + + apiResponse(ctx, w, r, clients) + return nil + } + default: + logger.Df(ctx, "/api/v1/clients %v", r.Method) + } + return nil +} + +func proxySrsStreamsAPI(ctx context.Context, servers []*SRSServer, w http.ResponseWriter, r *http.Request) error { + defer r.Body.Close() + + streamId := "" + if strings.HasPrefix(r.URL.Path, "/api/v1/streams/") { + streamId = r.URL.Path[len("/api/v1/streams/"):] + } + logger.Df(ctx, "%v %v streamId=%v", r.Method, r.URL.Path, streamId) + + body, err := io.ReadAll(r.Body) + if err != nil { + apiError(ctx, w, r, err, http.StatusInternalServerError) + return errors.Wrapf(err, "read request body err") + } + if r.Method != http.MethodGet { + err := errors.Errorf("Unsupported http method type %v", r.Method) + apiError(ctx, w, r, err, http.StatusBadRequest) + return err + } + if len(streamId) > 0 { + var stream SrsStreamResponse + for _, server := range servers { + if ret, err := server.ApiRequest(ctx, r, body); err == nil { + if err := json.Unmarshal(ret, &stream); err == nil && stream.Code == 0 { + apiResponse(ctx, w, r, stream) + return nil + } + } + } + ret := SrsApiCodeResponse{ + Code: 2048, + } + apiResponse(ctx, w, r, ret) + return nil + } else { + var streams SrsStreamsResponse + for _, server := range servers { + var res SrsStreamsResponse + if ret, err := server.ApiRequest(ctx, r, body); err == nil { + if err := json.Unmarshal(ret, &res); err == nil && res.Code == 0 { + streams.Streams = append(streams.Streams, res.Streams...) + } + } + } + + apiResponse(ctx, w, r, streams) + return nil + } +} + +func proxySrsRawAPI(ctx context.Context, servers []*SRSServer, w http.ResponseWriter, r *http.Request) error { + defer r.Body.Close() + + rpc := r.URL.Query().Get("rpc") + logger.Df(ctx, "%v, rpc=%v", r.URL.Path, rpc) + body, err := io.ReadAll(r.Body) + if err != nil { + apiError(ctx, w, r, err, http.StatusInternalServerError) + return errors.Wrapf(err, "read request body err") + } + + for _, server := range servers { + if ret, err := server.ApiRequest(ctx, r, body); err == nil { + if rpc == "raw" { + // return the first success response + var raw SrsRawResponse + if err := json.Unmarshal(ret, &raw); err == nil && raw.Code == 0 { + raw.Http_api.Listen = envHttpAPI() + apiResponse(ctx, w, r, raw) + return nil + } + } else if rpc == "reload" { + var res SrsRawReloadResponse + err := json.Unmarshal(ret, &res) + logger.Df(ctx, "%v %v %v %v", server.IP, r.URL.Path, res, err) + } else if rpc == "reload-fetch" { + var res SrsRawReloadFetchResponse + err := json.Unmarshal(ret, &res) + logger.Df(ctx, "%v %v %v %v", server.IP, r.URL.Path, res, err) + } else { + var code SrsApiCodeResponse + if err := json.Unmarshal(ret, &code); err == nil { + logger.Df(ctx, "%v %v", r.URL.Path, code) + } + } + } + } + + return nil +} diff --git a/proxy/srs.go b/proxy/srs.go index d05a39c610..8428777bcf 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -4,10 +4,13 @@ package main import ( + "bytes" "context" "encoding/json" "fmt" + "io" "math/rand" + "net/http" "os" "strconv" "strings" @@ -97,6 +100,35 @@ func (v *SRSServer) Format(f fmt.State, c rune) { } } +func (v *SRSServer) ApiRequest(ctx context.Context, r *http.Request, body []byte) ([]byte, error) { + var url string + // if the v.API[0] contains ip address, e.g. 127.0.0.1:1985, then use it as the ip address + if strings.Contains(v.API[0], ":") && strings.Index(v.API[0], ":") > 0 { + url = "http://" + v.API[0] + r.URL.Path + } else { + url = "http://" + v.IP + ":" + v.API[0] + r.URL.Path + } + + if r.URL.RawQuery != "" { + url += "?" + r.URL.RawQuery + } + + if req, err := http.NewRequestWithContext(ctx, r.Method, url, bytes.NewReader(body)); err != nil { + return nil, errors.Wrapf(err, "create request to %v", url) + } else if res, err := http.DefaultClient.Do(req); err != nil { + return nil, errors.Wrapf(err, "send request to %v", url) + } else { + defer res.Body.Close() + if ret, err := io.ReadAll(res.Body); err != nil { + return nil, errors.Wrapf(err, "read http respose error") + } else if !isHttpStatusOK(res.StatusCode) { + return ret, errors.Errorf("http response status code %v", res.StatusCode) + } else { + return ret, nil + } + } +} + func NewSRSServer(opts ...func(*SRSServer)) *SRSServer { v := &SRSServer{} for _, opt := range opts { @@ -158,6 +190,8 @@ type SRSLoadBalancer interface { StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error // Load the WebRTC streaming by ufrag, the ICE username. LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) + // proxy http api to srs + ProxyHTTPAPI(ctx context.Context, w http.ResponseWriter, r *http.Request) error } // srsLoadBalancer is the global SRS load balancer. @@ -165,6 +199,7 @@ var srsLoadBalancer SRSLoadBalancer // srsMemoryLoadBalancer stores state in memory. type srsMemoryLoadBalancer struct { + *SrsApiProxy // All available SRS servers, key is server ID. servers sync.Map[string, *SRSServer] // The picked server to servce client by specified stream URL, key is stream url. @@ -287,7 +322,17 @@ func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag str } } +func (v *srsMemoryLoadBalancer) ProxyHTTPAPI(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + services := make([]*SRSServer, v.servers.Size()) + v.servers.Range(func(_ string, value *SRSServer) bool { + services = append(services, value) + return true + }) + return v.proxySrsAPI(ctx, services, w, r) +} + type srsRedisLoadBalancer struct { + *SrsApiProxy // The redis client sdk. rdb *redis.Client } @@ -528,6 +573,40 @@ func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag stri return &actual, nil } +func (v *srsRedisLoadBalancer) ProxyHTTPAPI(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + defer r.Body.Close() + // Query all servers from redis, in json string. + var serverKeys []string + if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil { + if err := json.Unmarshal(b, &serverKeys); err != nil { + return errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b)) + } + } + + // No server found, failed. + if len(serverKeys) == 0 { + err := errors.New("servers empty") + apiError(ctx, w, r, err, http.StatusInternalServerError) + return err + } + + // TODO get all SRSServer + var srsServers []*SRSServer + + for _, key := range serverKeys { + var server SRSServer + if b, err := v.rdb.Get(ctx, key).Bytes(); err == nil { + if err := json.Unmarshal(b, &server); err != nil { + return errors.Wrapf(err, "unmarshal servers %v, %v", key, string(b)) + } + srsServers = append(srsServers, &server) + logger.Df(ctx, "srsServer: %v", server) + } + } + + return v.proxySrsAPI(ctx, srsServers, w, r) +} + func (v *srsRedisLoadBalancer) redisKeyUfrag(ufrag string) string { return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag) } @@ -549,5 +628,5 @@ func (v *srsRedisLoadBalancer) redisKeyServer(serverID string) string { } func (v *srsRedisLoadBalancer) redisKeyServers() string { - return fmt.Sprintf("srs-proxy-all-servers") + return "srs-proxy-all-servers" } diff --git a/proxy/sync/map.go b/proxy/sync/map.go index 75db12f9a9..fe35dc91eb 100644 --- a/proxy/sync/map.go +++ b/proxy/sync/map.go @@ -43,3 +43,13 @@ func (m *Map[K, V]) Range(f func(key K, value V) bool) { func (m *Map[K, V]) Store(key K, value V) { m.m.Store(key, value) } + +func (m *Map[K, V]) Size() uint32 { + size := uint32(0) + m.m.Range(func(_, _ any) bool { + size++ + return true + }) + + return size +} diff --git a/proxy/utils.go b/proxy/utils.go index f3c3930762..fd84d6dd0e 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -32,7 +32,7 @@ func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, da b, err := json.Marshal(data) if err != nil { - apiError(ctx, w, r, errors.Wrapf(err, "marshal %v %v", reflect.TypeOf(data), data)) + apiError(ctx, w, r, errors.Wrapf(err, "marshal %v %v", reflect.TypeOf(data), data), http.StatusInternalServerError) return } @@ -41,10 +41,10 @@ func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, da w.Write(b) } -func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) { +func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error, code int) { logger.Wf(ctx, "HTTP API error %+v", err) w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.WriteHeader(http.StatusInternalServerError) + w.WriteHeader(code) fmt.Fprintln(w, fmt.Sprintf("%v", err)) } @@ -69,6 +69,10 @@ func apiCORS(ctx context.Context, w http.ResponseWriter, r *http.Request) bool { return false } +func isHttpStatusOK(v int) bool { + return v >= 200 && v < 300 +} + func parseGracefullyQuitTimeout() (time.Duration, error) { if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil { return 0, errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout()) @@ -250,8 +254,9 @@ func parseSRTStreamID(sid string) (host, resource string, err error) { } // parseListenEndpoint parse the listen endpoint as: -// port The tcp listen port, like 1935. -// protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935 +// +// port The tcp listen port, like 1935. +// protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935 func parseListenEndpoint(ep string) (protocol string, ip net.IP, port uint16, err error) { // If no colon in ep, it's port in string. if !strings.Contains(ep, ":") {