diff --git a/go.mod b/go.mod index e870e266..1a1652d1 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/pion/interceptor -go 1.20 +go 1.21 require ( github.com/pion/logging v0.2.3 diff --git a/internal/test/mock_stream.go b/internal/test/mock_stream.go index bf96e31b..e791ac8a 100644 --- a/internal/test/mock_stream.go +++ b/internal/test/mock_stream.go @@ -41,6 +41,7 @@ type RTPWithError struct { // RTCPWithError is used to send a batch of rtcp packets or an error on a channel type RTCPWithError struct { Packets []rtcp.Packet + Attr interceptor.Attributes Err error } @@ -107,21 +108,21 @@ func NewMockStream(info *interceptor.StreamInfo, i interceptor.Interceptor) *Moc go func() { buf := make([]byte, 1500) for { - i, _, err := s.rtcpReader.Read(buf, interceptor.Attributes{}) + i, attr, err := s.rtcpReader.Read(buf, interceptor.Attributes{}) if err != nil { if !errors.Is(err, io.EOF) { - s.rtcpInModified <- RTCPWithError{Err: err} + s.rtcpInModified <- RTCPWithError{Attr: attr, Err: err} } return } pkts, err := rtcp.Unmarshal(buf[:i]) if err != nil { - s.rtcpInModified <- RTCPWithError{Err: err} + s.rtcpInModified <- RTCPWithError{Attr: attr, Err: err} return } - s.rtcpInModified <- RTCPWithError{Packets: pkts} + s.rtcpInModified <- RTCPWithError{Attr: attr, Packets: pkts} } }() go func() { diff --git a/pkg/bwe/acknowledgment.go b/pkg/bwe/acknowledgment.go new file mode 100644 index 00000000..7a241d92 --- /dev/null +++ b/pkg/bwe/acknowledgment.go @@ -0,0 +1,38 @@ +package bwe + +import ( + "fmt" + "time" +) + +type ECN uint8 + +const ( + //nolint:misspell + // ECNNonECT signals Non ECN-Capable Transport, Non-ECT + ECNNonECT ECN = iota // 00 + + //nolint:misspell + // ECNECT1 signals ECN Capable Transport, ECT(0) + ECNECT1 // 01 + + //nolint:misspell + // ECNECT0 signals ECN Capable Transport, ECT(1) + ECNECT0 // 10 + + // ECNCE signals ECN Congestion Encountered, CE + ECNCE // 11 +) + +type Acknowledgment struct { + SeqNr int64 + Size uint16 + Departure time.Time + Arrived bool + Arrival time.Time + ECN ECN +} + +func (a Acknowledgment) String() string { + return fmt.Sprintf("seq=%v, departure=%v, arrival=%v", a.SeqNr, a.Departure, a.Arrival) +} diff --git a/pkg/bwe/arrival_group_accumulator.go b/pkg/bwe/arrival_group_accumulator.go new file mode 100644 index 00000000..088d3fda --- /dev/null +++ b/pkg/bwe/arrival_group_accumulator.go @@ -0,0 +1,48 @@ +package bwe + +import ( + "time" +) + +type arrivalGroup []Acknowledgment + +type arrivalGroupAccumulator struct { + next arrivalGroup + burstInterval time.Duration + maxBurstDuration time.Duration +} + +func newArrivalGroupAccumulator() *arrivalGroupAccumulator { + return &arrivalGroupAccumulator{ + next: make([]Acknowledgment, 0), + burstInterval: 5 * time.Millisecond, + maxBurstDuration: 100 * time.Millisecond, + } +} + +func (a *arrivalGroupAccumulator) onPacketAcked(ack Acknowledgment) arrivalGroup { + if len(a.next) == 0 { + a.next = append(a.next, ack) + return nil + } + + if ack.Departure.Sub(a.next[0].Departure) < a.burstInterval { + a.next = append(a.next, ack) + return nil + } + + sendTimeDelta := ack.Departure.Sub(a.next[0].Departure) + arrivalTimeDeltaLast := ack.Arrival.Sub(a.next[len(a.next)-1].Arrival) + arrivalTimeDeltaFirst := ack.Arrival.Sub(a.next[0].Arrival) + propagationDelta := arrivalTimeDeltaFirst - sendTimeDelta + + if propagationDelta < 0 && arrivalTimeDeltaLast <= a.burstInterval && arrivalTimeDeltaFirst < a.maxBurstDuration { + a.next = append(a.next, ack) + return nil + } + + group := make(arrivalGroup, len(a.next)) + copy(group, a.next) + a.next = arrivalGroup{ack} + return group +} diff --git a/pkg/bwe/arrival_group_accumulator_test.go b/pkg/bwe/arrival_group_accumulator_test.go new file mode 100644 index 00000000..e88cb3b8 --- /dev/null +++ b/pkg/bwe/arrival_group_accumulator_test.go @@ -0,0 +1,239 @@ +package bwe + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestArrivalGroupAccumulator(t *testing.T) { + triggerNewGroupElement := Acknowledgment{ + Departure: time.Time{}.Add(time.Second), + Arrival: time.Time{}.Add(time.Second), + } + cases := []struct { + name string + log []Acknowledgment + exp []arrivalGroup + }{ + { + name: "emptyCreatesNoGroups", + log: []Acknowledgment{}, + exp: []arrivalGroup{}, + }, + { + name: "createsSingleElementGroup", + log: []Acknowledgment{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(time.Millisecond), + }, + }, + }, + }, + { + name: "createsTwoElementGroup", + log: []Acknowledgment{ + { + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + }}, + }, + { + name: "createsTwoArrivalGroups1", + log: []Acknowledgment{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(24 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + }, + { + { + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(24 * time.Millisecond), + }, + }, + }, + }, + { + name: "createsTwoArrivalGroups2", + log: []Acknowledgment{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(30 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + }, + { + { + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(30 * time.Millisecond), + }, + }, + }, + }, + { + name: "ignoresOutOfOrderPackets", + log: []Acknowledgment{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(34 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(8 * time.Millisecond), + Arrival: time.Time{}.Add(30 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + }, + { + { + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(34 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(8 * time.Millisecond), + Arrival: time.Time{}.Add(30 * time.Millisecond), + }, + }, + }, + }, + { + name: "newGroupBecauseOfInterDepartureTime", + log: []Acknowledgment{ + { + SeqNr: 0, + Departure: time.Time{}, + Arrival: time.Time{}.Add(4 * time.Millisecond), + }, + { + SeqNr: 1, + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(4 * time.Millisecond), + }, + { + SeqNr: 2, + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + { + SeqNr: 3, + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + SeqNr: 0, + Departure: time.Time{}, + Arrival: time.Time{}.Add(4 * time.Millisecond), + }, + { + SeqNr: 1, + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(4 * time.Millisecond), + }, + }, + { + { + SeqNr: 2, + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + { + SeqNr: 3, + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + }, + }, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + aga := newArrivalGroupAccumulator() + received := []arrivalGroup{} + for _, ack := range tc.log { + next := aga.onPacketAcked(ack) + if next != nil { + received = append(received, next) + } + } + assert.Equal(t, tc.exp, received) + }) + } +} diff --git a/pkg/bwe/delay_rate_controller.go b/pkg/bwe/delay_rate_controller.go new file mode 100644 index 00000000..e62c8e73 --- /dev/null +++ b/pkg/bwe/delay_rate_controller.go @@ -0,0 +1,86 @@ +package bwe + +import ( + "time" + + "github.com/pion/logging" +) + +const maxSamples = 1000 + +type DelayRateController struct { + log logging.LeveledLogger + aga *arrivalGroupAccumulator + last arrivalGroup + kf *kalmanFilter + od *overuseDetector + rc *rateController + latestUsage usage + samples int +} + +func NewDelayRateController(initialRate int) *DelayRateController { + return &DelayRateController{ + log: logging.NewDefaultLoggerFactory().NewLogger("bwe_delay_rate_controller"), + aga: newArrivalGroupAccumulator(), + last: []Acknowledgment{}, + kf: newKalmanFilter(), + od: newOveruseDetector(true), + rc: newRateController(initialRate), + latestUsage: 0, + samples: 0, + } +} + +func (c *DelayRateController) OnPacketAcked(ack Acknowledgment) { + next := c.aga.onPacketAcked(ack) + if next == nil { + return + } + if len(next) == 0 { + // ignore empty groups, should never occur + return + } + if len(c.last) == 0 { + c.last = next + return + } + + prevSize := groupSize(c.last) + nextSize := groupSize(next) + sizeDelta := nextSize - prevSize + + interArrivalTime := next[len(next)-1].Arrival.Sub(c.last[len(c.last)-1].Arrival) + interDepartureTime := next[len(next)-1].Departure.Sub(c.last[len(c.last)-1].Departure) + interGroupDelay := interArrivalTime - interDepartureTime + estimate := c.kf.update(float64(interGroupDelay.Milliseconds()), float64(sizeDelta)) + c.samples++ + c.latestUsage = c.od.update(ack.Arrival, estimate, c.samples) + c.last = next + c.log.Tracef( + "ts=%v.%06d, seq=%v, size=%v, interArrivalTime=%v, interDepartureTime=%v, interGroupDelay=%v, estimate=%v, threshold=%v, usage=%v, state=%v", + c.last[0].Departure.UTC().Format("2006/01/02 15:04:05"), + c.last[0].Departure.UTC().Nanosecond()/1e3, + next[0].SeqNr, + nextSize, + interArrivalTime.Microseconds(), + interDepartureTime.Microseconds(), + interGroupDelay.Microseconds(), + estimate, + c.od.delayThreshold, + int(c.latestUsage), + int(c.rc.s), + ) +} + +func (c *DelayRateController) Update(ts time.Time, lastDeliveryRate int, rtt time.Duration) int { + return c.rc.update(ts, c.latestUsage, lastDeliveryRate, rtt) +} + +func groupSize(group arrivalGroup) int { + sum := 0 + for _, ack := range group { + sum += int(ack.Size) + } + return sum +} diff --git a/pkg/bwe/delivery_rate_estimator.go b/pkg/bwe/delivery_rate_estimator.go new file mode 100644 index 00000000..9d18e3a7 --- /dev/null +++ b/pkg/bwe/delivery_rate_estimator.go @@ -0,0 +1,87 @@ +package bwe + +import ( + "container/heap" + "time" +) + +type deliveryRateHeapItem struct { + arrival time.Time + size int +} + +type deliveryRateHeap []deliveryRateHeapItem + +// Len implements heap.Interface. +func (d deliveryRateHeap) Len() int { + return len(d) +} + +// Less implements heap.Interface. +func (d deliveryRateHeap) Less(i int, j int) bool { + return d[i].arrival.Before(d[j].arrival) +} + +// Pop implements heap.Interface. +func (d *deliveryRateHeap) Pop() any { + old := *d + n := len(old) + x := old[n-1] + *d = old[0 : n-1] + return x +} + +// Push implements heap.Interface. +func (d *deliveryRateHeap) Push(x any) { + *d = append(*d, x.(deliveryRateHeapItem)) +} + +// Swap implements heap.Interface. +func (d deliveryRateHeap) Swap(i int, j int) { + d[i], d[j] = d[j], d[i] +} + +type deliveryRateEstimator struct { + window time.Duration + latestArrival time.Time + history *deliveryRateHeap +} + +func newDeliveryRateEstimator(window time.Duration) *deliveryRateEstimator { + return &deliveryRateEstimator{ + window: window, + latestArrival: time.Time{}, + history: &deliveryRateHeap{}, + } +} + +func (e *deliveryRateEstimator) OnPacketAcked(arrival time.Time, size int) { + if arrival.After(e.latestArrival) { + e.latestArrival = arrival + } + heap.Push(e.history, deliveryRateHeapItem{ + arrival: arrival, + size: size, + }) +} + +func (e *deliveryRateEstimator) GetRate() int { + deadline := e.latestArrival.Add(-e.window) + for len(*e.history) > 0 && (*e.history)[0].arrival.Before(deadline) { + heap.Pop(e.history) + } + earliest := e.latestArrival + sum := 0 + for _, i := range *e.history { + if i.arrival.Before(earliest) { + earliest = i.arrival + } + sum += i.size + } + d := e.latestArrival.Sub(earliest) + if d == 0 { + return 0 + } + rate := 8 * float64(sum) / d.Seconds() + return int(rate) +} diff --git a/pkg/bwe/delivery_rate_estimator_test.go b/pkg/bwe/delivery_rate_estimator_test.go new file mode 100644 index 00000000..15ea3665 --- /dev/null +++ b/pkg/bwe/delivery_rate_estimator_test.go @@ -0,0 +1,74 @@ +package bwe + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDeliveryRateEstimator(t *testing.T) { + type ack struct { + arrival time.Time + size int + } + cases := []struct { + window time.Duration + acks []ack + expectedRate int + }{ + { + window: 0, + acks: []ack{}, + expectedRate: 0, + }, + { + window: time.Second, + acks: []ack{}, + expectedRate: 0, + }, + { + window: time.Second, + acks: []ack{ + {time.Time{}, 1200}, + }, + expectedRate: 0, + }, + { + window: time.Second, + acks: []ack{ + {time.Time{}.Add(time.Millisecond), 1200}, + }, + expectedRate: 0, + }, + { + window: time.Second, + acks: []ack{ + {time.Time{}.Add(time.Second), 1200}, + {time.Time{}.Add(1500 * time.Millisecond), 1200}, + {time.Time{}.Add(2 * time.Second), 1200}, + }, + expectedRate: 28800, + }, + { + window: time.Second, + acks: []ack{ + {time.Time{}.Add(500 * time.Millisecond), 1200}, + {time.Time{}.Add(time.Second), 1200}, + {time.Time{}.Add(1500 * time.Millisecond), 1200}, + {time.Time{}.Add(2 * time.Second), 1200}, + }, + expectedRate: 28800, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + e := newDeliveryRateEstimator(tc.window) + for _, ack := range tc.acks { + e.OnPacketAcked(ack.arrival, ack.size) + } + assert.Equal(t, tc.expectedRate, e.GetRate()) + }) + } +} diff --git a/pkg/bwe/deprecated_bwe_api.go b/pkg/bwe/deprecated_bwe_api.go new file mode 100644 index 00000000..15e367ff --- /dev/null +++ b/pkg/bwe/deprecated_bwe_api.go @@ -0,0 +1,126 @@ +package bwe + +import ( + "errors" + "sync" + "time" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/pkg/cc" + "github.com/pion/interceptor/pkg/ccfb" + "github.com/pion/rtcp" +) + +// GCCFactory creates a new cc.BandwidthEstimator +func GCCFactory() (cc.BandwidthEstimator, error) { + return &GCC{ + lock: sync.Mutex{}, + sbwe: NewSendSideController(1_000_000, 100_000, 100_000_000), + rate: 1_000_000, + }, nil +} + +// GCC implements cc.BandwidthEstimator +type GCC struct { + lock sync.Mutex + sbwe *SendSideController + rate int + updateCB func(int) +} + +// AddStream implements cc.BandwidthEstimator. +// Called by cc.Interceptor +func (g *GCC) AddStream(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { + return writer +} + +// Close implements cc.BandwidthEstimator. +// Called by cc.Interceptor +func (g *GCC) Close() error { + // GCC does not need to be closed + return nil +} + +// GetStats implements cc.BandwidthEstimator. +// Called by application +func (g *GCC) GetStats() map[string]interface{} { + g.lock.Lock() + defer g.lock.Unlock() + return map[string]interface{}{ + "warning": "GetStats is deprecated", + "lossTargetBitrate": 0, + "averageLoss": 0, + "delayTargetBitrate": 0, + "delayMeasurement": 0, + "delayEstimate": 0, + "delayThreshold": 0, + "usage": 0, + "state": 0, + } +} + +// GetTargetBitrate implements cc.BandwidthEstimator. +// Called by application +func (g *GCC) GetTargetBitrate() int { + g.lock.Lock() + defer g.lock.Unlock() + return g.rate +} + +// OnTargetBitrateChange implements cc.BandwidthEstimator. +// Called by application +func (g *GCC) OnTargetBitrateChange(f func(bitrate int)) { + g.lock.Lock() + defer g.lock.Unlock() + g.updateCB = f +} + +// WriteRTCP implements cc.BandwidthEstimator. +// Called by cc.Interceptor +func (g *GCC) WriteRTCP(_ []rtcp.Packet, attr interceptor.Attributes) error { + reports, ok := attr.Get(ccfb.CCFBAttributesKey).([]ccfb.Report) + if !ok { + return errors.New("warning: GCC requires CCFB interceptor to be configured before the CC interceptor") + } + now := time.Now() + for _, report := range reports { + acks, rtt := readReport(report) + g.update(now, rtt, acks) + } + return nil +} + +func (g *GCC) update(now time.Time, rtt time.Duration, acks []Acknowledgment) { + g.lock.Lock() + defer g.lock.Unlock() + oldRate := g.rate + + g.rate = g.sbwe.OnAcks(now, rtt, acks) + + if oldRate != g.rate && g.updateCB != nil { + g.updateCB(g.rate) + } +} + +func readReport(report ccfb.Report) ([]Acknowledgment, time.Duration) { + acks := []Acknowledgment{} + latestAcked := Acknowledgment{} + for _, prs := range report.SSRCToPacketReports { + for _, pr := range prs { + ack := Acknowledgment{ + SeqNr: pr.SeqNr, + Size: pr.Size, + Departure: pr.Departure, + Arrived: pr.Arrived, + Arrival: pr.Arrival, + ECN: ECN(pr.ECN), + } + if ack.Arrival.After(latestAcked.Arrival) { + latestAcked = ack + } + acks = append(acks, ack) + } + } + rtt := MeasureRTT(report.Departure, report.Arrival, latestAcked.Departure, latestAcked.Arrival) + return acks, rtt +} diff --git a/pkg/bwe/exponential_moving_average.go b/pkg/bwe/exponential_moving_average.go new file mode 100644 index 00000000..9971a1e7 --- /dev/null +++ b/pkg/bwe/exponential_moving_average.go @@ -0,0 +1,19 @@ +package bwe + +type exponentialMovingAverage struct { + initialized bool + alpha float64 + average float64 + variance float64 +} + +func (a *exponentialMovingAverage) update(sample float64) { + if !a.initialized { + a.average = sample + a.initialized = true + } else { + delta := sample - a.average + a.average = a.alpha*sample + (1-a.alpha)*a.average + a.variance = (1-a.alpha)*a.variance + a.alpha*(1-a.alpha)*(delta*delta) + } +} diff --git a/pkg/bwe/exponential_moving_average_test.go b/pkg/bwe/exponential_moving_average_test.go new file mode 100644 index 00000000..89976b51 --- /dev/null +++ b/pkg/bwe/exponential_moving_average_test.go @@ -0,0 +1,122 @@ +package bwe + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +// python to generate test cases: +// import numpy as np +// import pandas as pd +// data = np.random.randint(1, 10, size=10) +// df = pd.DataFrame(data) +// expectedAvg = df.ewm(alpha=0.9, adjust=False).mean() +// expectedVar = df.ewm(alpha=0.9, adjust=False).var(bias=True) + +func TestExponentialMovingAverage(t *testing.T) { + cases := []struct { + alpha float64 + updates []float64 + expectedAvg []float64 + expectedVar []float64 + }{ + { + alpha: 0.9, + updates: []float64{}, + expectedAvg: []float64{}, + expectedVar: []float64{}, + }, + { + alpha: 0.9, + updates: []float64{1, 2, 3, 4}, + expectedAvg: []float64{ + 1.000, + 1.900, + 2.890, + 3.889, + }, + expectedVar: []float64{ + 0.000000, + 0.090000, + 0.117900, + 0.122679, + }, + }, + { + alpha: 0.9, + updates: []float64{8, 8, 5, 1, 3, 1, 8, 2, 8, 9}, + expectedAvg: []float64{ + 8.000000, + 8.000000, + 5.300000, + 1.430000, + 2.843000, + 1.184300, + 7.318430, + 2.531843, + 7.453184, + 8.845318, + }, + expectedVar: []float64{ + 0.000000, + 0.000000, + 0.810000, + 1.745100, + 0.396351, + 0.345334, + 4.215372, + 2.967250, + 2.987792, + 0.514117, + }, + }, + { + alpha: 0.9, + updates: []float64{7, 5, 6, 7, 3, 6, 8, 9, 5, 5}, + expectedAvg: []float64{ + 7.000000, + 5.200000, + 5.920000, + 6.892000, + 3.389200, + 5.738920, + 7.773892, + 8.877389, + 5.387739, + 5.038774, + }, + expectedVar: []float64{ + 0.000000, + 0.360000, + 0.093600, + 0.114336, + 1.374723, + 0.750937, + 0.535217, + 0.188822, + 1.371955, + 0.150726, + }, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + a := exponentialMovingAverage{ + alpha: tc.alpha, + average: 0, + variance: 0, + } + avgs := []float64{} + vars := []float64{} + for _, u := range tc.updates { + a.update(u) + avgs = append(avgs, a.average) + vars = append(vars, a.variance) + } + assert.InDeltaSlice(t, tc.expectedAvg, avgs, 0.001) + assert.InDeltaSlice(t, tc.expectedVar, vars, 0.001) + }) + } +} diff --git a/pkg/bwe/kalman.go b/pkg/bwe/kalman.go new file mode 100644 index 00000000..5aed2eb3 --- /dev/null +++ b/pkg/bwe/kalman.go @@ -0,0 +1,92 @@ +package bwe + +import ( + "math" +) + +type kalmanFilter struct { + state [2]float64 // [slope, offset] + + processNoise [2]float64 + e [2][2]float64 + avgNoise float64 + varNoise float64 +} + +type kalmanOption func(*kalmanFilter) + +func initSlope(e float64) kalmanOption { + return func(k *kalmanFilter) { + k.state[0] = e + } +} + +func newKalmanFilter(opts ...kalmanOption) *kalmanFilter { + kf := &kalmanFilter{ + state: [2]float64{8.0 / 512.0, 0}, + processNoise: [2]float64{1e-13, 1e-3}, + e: [2][2]float64{{100.0, 0}, {0, 1e-1}}, + varNoise: 50.0, + } + for _, opt := range opts { + opt(kf) + } + return kf +} + +func (k *kalmanFilter) update(timeDelta float64, sizeDelta float64) float64 { + k.e[0][0] += k.processNoise[0] + k.e[1][1] += k.processNoise[1] + + h := [2]float64{sizeDelta, 1.0} + Eh := [2]float64{ + k.e[0][0]*h[0] + k.e[0][1]*h[1], + k.e[1][0]*h[0] + k.e[1][1]*h[1], + } + residual := timeDelta - (k.state[0]*h[0] + k.state[1]*h[1]) + + maxResidual := 3.0 * math.Sqrt(k.varNoise) + if math.Abs(residual) < maxResidual { + k.updateNoiseEstimate(residual, timeDelta) + } else { + if residual < 0 { + k.updateNoiseEstimate(-maxResidual, timeDelta) + } else { + k.updateNoiseEstimate(maxResidual, timeDelta) + } + } + + denom := k.varNoise + h[0]*Eh[0] + h[1]*Eh[1] + + K := [2]float64{ + Eh[0] / denom, Eh[1] / denom, + } + + IKh := [2][2]float64{ + {1.0 - K[0]*h[0], -K[0] * h[1]}, + {-K[1] * h[0], 1.0 - K[1]*h[1]}, + } + + e00 := k.e[0][0] + e01 := k.e[0][1] + + k.e[0][0] = e00*IKh[0][0] + k.e[1][0]*IKh[0][1] + k.e[0][1] = e01*IKh[0][0] + k.e[1][1]*IKh[0][1] + k.e[1][0] = e00*IKh[1][0] + k.e[1][0]*IKh[1][1] + k.e[1][1] = e01*IKh[1][0] + k.e[1][1]*IKh[1][1] + + k.state[0] = k.state[0] + K[0]*residual + k.state[1] = k.state[1] + K[1]*residual + + return k.state[1] +} + +func (k *kalmanFilter) updateNoiseEstimate(residual float64, timeDelta float64) { + alpha := 0.002 + beta := math.Pow(1-alpha, timeDelta*30.0/1000.0) + k.avgNoise = beta*k.avgNoise + (1-beta)*residual + k.varNoise = beta*k.varNoise + (1-beta)*(k.avgNoise-residual)*(k.avgNoise-residual) + if k.varNoise < 1 { + k.varNoise = 1 + } +} diff --git a/pkg/bwe/loss_rate_controller.go b/pkg/bwe/loss_rate_controller.go new file mode 100644 index 00000000..22a37d01 --- /dev/null +++ b/pkg/bwe/loss_rate_controller.go @@ -0,0 +1,53 @@ +package bwe + +type LossRateController struct { + bitrate int + min, max float64 + + packetsSinceLastUpdate int + arrivedSinceLastUpdate int + lostSinceLastUpdate int +} + +func NewLossRateController(initialRate, minRate, maxRate int) *LossRateController { + return &LossRateController{ + bitrate: initialRate, + min: float64(minRate), + max: float64(maxRate), + packetsSinceLastUpdate: 0, + arrivedSinceLastUpdate: 0, + lostSinceLastUpdate: 0, + } +} + +func (l *LossRateController) OnPacketAcked() { + l.packetsSinceLastUpdate++ + l.arrivedSinceLastUpdate++ +} + +func (l *LossRateController) OnPacketLost() { + l.packetsSinceLastUpdate++ + l.lostSinceLastUpdate++ +} + +func (l *LossRateController) Update(lastDeliveryRate int) int { + lossRate := float64(l.lostSinceLastUpdate) / float64(l.packetsSinceLastUpdate) + var target float64 + if lossRate > 0.1 { + target = float64(l.bitrate) * (1 - 0.5*lossRate) + target = max(target, l.min) + } else if lossRate < 0.02 { + target = float64(l.bitrate) * 1.05 + target = max(min(target, 1.5*float64(lastDeliveryRate)), float64(l.bitrate)) + target = min(target, l.max) + } + if target != 0 { + l.bitrate = int(target) + } + + l.packetsSinceLastUpdate = 0 + l.arrivedSinceLastUpdate = 0 + l.lostSinceLastUpdate = 0 + + return l.bitrate +} diff --git a/pkg/bwe/loss_rate_controller_test.go b/pkg/bwe/loss_rate_controller_test.go new file mode 100644 index 00000000..2871e54f --- /dev/null +++ b/pkg/bwe/loss_rate_controller_test.go @@ -0,0 +1,86 @@ +package bwe + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLossRateController(t *testing.T) { + cases := []struct { + init, min, max int + acked int + lost int + deliveredRate int + expectedRate int + }{ + {}, // all zeros + { + init: 100_000, + min: 100_000, + max: 1_000_000, + acked: 0, + lost: 0, + deliveredRate: 0, + expectedRate: 100_000, + }, + { + init: 100_000, + min: 100_000, + max: 1_000_000, + acked: 99, + lost: 1, + deliveredRate: 100_000, + expectedRate: 105_000, + }, + { + init: 100_000, + min: 100_000, + max: 1_000_000, + acked: 99, + lost: 1, + deliveredRate: 90_000, + expectedRate: 105_000, + }, + { + init: 100_000, + min: 100_000, + max: 1_000_000, + acked: 95, + lost: 5, + deliveredRate: 99_000, + expectedRate: 100_000, + }, + { + init: 100_000, + min: 50_000, + max: 1_000_000, + acked: 89, + lost: 11, + deliveredRate: 90_000, + expectedRate: 94_500, + }, + { + init: 100_000, + min: 100_000, + max: 1_000_000, + acked: 89, + lost: 11, + deliveredRate: 90_000, + expectedRate: 100_000, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + lrc := NewLossRateController(tc.init, tc.min, tc.max) + for i := 0; i < tc.acked; i++ { + lrc.OnPacketAcked() + } + for i := 0; i < tc.lost; i++ { + lrc.OnPacketLost() + } + assert.Equal(t, tc.expectedRate, lrc.Update(tc.deliveredRate)) + }) + } +} diff --git a/pkg/bwe/overuse_detector.go b/pkg/bwe/overuse_detector.go new file mode 100644 index 00000000..c5a3c512 --- /dev/null +++ b/pkg/bwe/overuse_detector.go @@ -0,0 +1,84 @@ +package bwe + +import ( + "math" + "time" +) + +const ( + kU = 0.01 + kD = 0.00018 + + maxNumDeltas = 60 +) + +type overuseDetector struct { + adaptiveThreshold bool + overUseTimeThreshold time.Duration + delayThreshold float64 + lastEstimate time.Duration + lastUpdate time.Time + firstOverUse time.Time + inOveruse bool + lastUsage usage +} + +func newOveruseDetector(adaptive bool) *overuseDetector { + return &overuseDetector{ + adaptiveThreshold: adaptive, + overUseTimeThreshold: 10 * time.Millisecond, + delayThreshold: 12.5, + lastEstimate: 0, + lastUpdate: time.Time{}, + firstOverUse: time.Time{}, + inOveruse: false, + } +} + +func (d *overuseDetector) update(ts time.Time, trend float64, numDeltas int) usage { + if numDeltas < 2 { + return usageNormal + } + modifiedTrend := float64(min(numDeltas, maxNumDeltas)) * trend + + if modifiedTrend > d.delayThreshold { + if d.firstOverUse.IsZero() { + d.firstOverUse = ts + } + if ts.Sub(d.firstOverUse) > d.overUseTimeThreshold { + d.firstOverUse = time.Time{} + d.lastUsage = usageOver + } + } else if modifiedTrend < -d.delayThreshold { + d.firstOverUse = time.Time{} + d.lastUsage = usageUnder + } else { + d.firstOverUse = time.Time{} + d.lastUsage = usageNormal + } + if d.adaptiveThreshold { + d.adaptThreshold(ts, modifiedTrend) + } + return d.lastUsage +} + +func (d *overuseDetector) adaptThreshold(ts time.Time, modifiedTrend float64) { + if d.lastUpdate.IsZero() { + d.lastUpdate = ts + } + if math.Abs(modifiedTrend) > d.delayThreshold+15 { + d.lastUpdate = ts + return + } + k := kU + if math.Abs(modifiedTrend) < d.delayThreshold { + k = kD + } + delta := ts.Sub(d.lastUpdate) + if delta > 100*time.Millisecond { + delta = 100 * time.Millisecond + } + d.delayThreshold += k * (math.Abs(modifiedTrend) - d.delayThreshold) * float64(delta.Milliseconds()) + d.delayThreshold = max(min(d.delayThreshold, 600.0), 6.0) + d.lastUpdate = ts +} diff --git a/pkg/bwe/overuse_detector_test.go b/pkg/bwe/overuse_detector_test.go new file mode 100644 index 00000000..5764c759 --- /dev/null +++ b/pkg/bwe/overuse_detector_test.go @@ -0,0 +1,186 @@ +package bwe + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestOveruseDetectorUpdate(t *testing.T) { + type estimate struct { + ts time.Time + estimate float64 + numDeltas int + } + cases := []struct { + name string + adaptive bool + values []estimate + expected []usage + }{ + { + name: "noEstimateNoUsageStatic", + adaptive: false, + values: []estimate{}, + expected: []usage{}, + }, + { + name: "overuseStatic", + adaptive: false, + values: []estimate{ + {time.Time{}, 1.0, 1}, + {time.Time{}.Add(5 * time.Millisecond), 20, 2}, + {time.Time{}.Add(20 * time.Millisecond), 30, 3}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "normaluseStatic", + adaptive: false, + values: []estimate{{estimate: 0}}, + expected: []usage{usageNormal}, + }, + { + name: "underuseStatic", + adaptive: false, + values: []estimate{{time.Time{}, -20, 2}}, + expected: []usage{usageUnder}, + }, + { + name: "noOverUseBeforeDelayStatic", + adaptive: false, + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20, 1}, + {time.Time{}.Add(2 * time.Millisecond), 30, 2}, + {time.Time{}.Add(30 * time.Millisecond), 50, 3}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "noOverUseIfEstimateDecreasedStatic", + adaptive: false, + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20, 1}, + {time.Time{}.Add(10 * time.Millisecond), 40, 2}, + {time.Time{}.Add(30 * time.Millisecond), 50, 3}, + {time.Time{}.Add(35 * time.Millisecond), 3, 4}, + }, + expected: []usage{usageNormal, usageNormal, usageOver, usageNormal}, + }, + { + name: "noEstimateNoUsageAdaptive", + adaptive: true, + values: []estimate{}, + expected: []usage{}, + }, + { + name: "overuseAdaptive", + adaptive: true, + values: []estimate{ + {time.Time{}, 1, 1}, + {time.Time{}.Add(5 * time.Millisecond), 20, 2}, + {time.Time{}.Add(20 * time.Millisecond), 30, 3}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "normaluseAdaptive", + adaptive: true, + values: []estimate{{estimate: 0}}, + expected: []usage{usageNormal}, + }, + { + name: "underuseAdaptive", + adaptive: true, + values: []estimate{{time.Time{}, -20, 2}}, + expected: []usage{usageUnder}, + }, + { + name: "noOverUseBeforeDelayAdaptive", + adaptive: true, + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20, 1}, + {time.Time{}.Add(2 * time.Millisecond), 30, 2}, + {time.Time{}.Add(30 * time.Millisecond), 50, 3}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "noOverUseIfEstimateDecreasedAdaptive", + adaptive: true, + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20, 1}, + {time.Time{}.Add(10 * time.Millisecond), 40, 2}, + {time.Time{}.Add(30 * time.Millisecond), 50, 3}, + {time.Time{}.Add(35 * time.Millisecond), 3, 4}, + }, + expected: []usage{usageNormal, usageNormal, usageOver, usageNormal}, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + od := newOveruseDetector(tc.adaptive) + received := []usage{} + for _, e := range tc.values { + usage := od.update(e.ts, e.estimate, e.numDeltas) + received = append(received, usage) + } + assert.Equal(t, tc.expected, received) + }) + } +} + +func TestOveruseDetectorAdaptThreshold(t *testing.T) { + cases := []struct { + name string + od *overuseDetector + ts time.Time + estimate float64 + expectedThreshold float64 + }{ + { + name: "minThreshold", + od: &overuseDetector{}, + ts: time.Time{}, + estimate: 0, + expectedThreshold: 6, + }, + { + name: "increase", + od: &overuseDetector{ + delayThreshold: 12.5, + lastUpdate: time.Time{}.Add(time.Second), + }, + ts: time.Time{}.Add(2 * time.Second), + estimate: 25, + expectedThreshold: 25, + }, + { + name: "maxThreshold", + od: &overuseDetector{ + delayThreshold: 6, + lastUpdate: time.Time{}, + }, + ts: time.Time{}.Add(time.Second), + estimate: 6.1, + expectedThreshold: 6, + }, + { + name: "decrease", + od: &overuseDetector{ + delayThreshold: 12.5, + lastUpdate: time.Time{}, + }, + ts: time.Time{}.Add(time.Second), + estimate: 1, + expectedThreshold: 12.5, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tc.od.adaptThreshold(tc.ts, tc.estimate) + assert.Equal(t, tc.expectedThreshold, tc.od.delayThreshold) + }) + } +} diff --git a/pkg/bwe/rate_controller.go b/pkg/bwe/rate_controller.go new file mode 100644 index 00000000..e51a48c8 --- /dev/null +++ b/pkg/bwe/rate_controller.go @@ -0,0 +1,76 @@ +package bwe + +import ( + "math" + "time" +) + +type rateController struct { + s state + rate int + + decreaseFactor float64 // (beta) + lastUpdate time.Time + lastDecrease *exponentialMovingAverage +} + +func newRateController(initialRate int) *rateController { + return &rateController{ + s: stateIncrease, + rate: initialRate, + decreaseFactor: 0.85, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{}, + } +} + +func (c *rateController) update(ts time.Time, u usage, deliveredRate int, rtt time.Duration) int { + nextState := c.s.transition(u) + c.s = nextState + + if c.s == stateIncrease { + var target float64 + if c.canIncreaseMultiplicatively(float64(deliveredRate)) { + window := ts.Sub(c.lastUpdate) + target = c.multiplicativeIncrease(float64(c.rate), window) + } else { + bitsPerFrame := float64(c.rate) / 30.0 + packetsPerFrame := math.Ceil(bitsPerFrame / (1200 * 8)) + expectedPacketSizeBits := bitsPerFrame / packetsPerFrame + target = c.additiveIncrease(float64(c.rate), int(expectedPacketSizeBits), rtt) + } + c.rate = int(max(min(target, 1.5*float64(deliveredRate)), float64(c.rate))) + } + + if c.s == stateDecrease { + c.rate = int(c.decreaseFactor * float64(deliveredRate)) + c.lastDecrease.update(float64(c.rate)) + } + + c.lastUpdate = ts + + return c.rate +} + +func (c *rateController) canIncreaseMultiplicatively(deliveredRate float64) bool { + if c.lastDecrease.average == 0 { + return true + } + stdDev := math.Sqrt(c.lastDecrease.variance) + lower := c.lastDecrease.average - 3*stdDev + upper := c.lastDecrease.average + 3*stdDev + return deliveredRate < lower || deliveredRate > upper +} + +func (c *rateController) multiplicativeIncrease(rate float64, window time.Duration) float64 { + exponent := min(window.Seconds(), 1.0) + eta := math.Pow(1.08, exponent) + target := eta * rate + return target +} + +func (c *rateController) additiveIncrease(rate float64, expectedPacketSizeBits int, window time.Duration) float64 { + alpha := 0.5 * min(window.Seconds(), 1.0) + target := rate + max(1000, alpha*float64(expectedPacketSizeBits)) + return target +} diff --git a/pkg/bwe/rate_controller_test.go b/pkg/bwe/rate_controller_test.go new file mode 100644 index 00000000..7f647dc9 --- /dev/null +++ b/pkg/bwe/rate_controller_test.go @@ -0,0 +1,143 @@ +package bwe + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRateController(t *testing.T) { + cases := []struct { + name string + rc rateController + ts time.Time + u usage + delivered int + rtt time.Duration + expectedRate int + }{ + { + name: "zero", + rc: rateController{ + s: 0, + rate: 0, + decreaseFactor: 0, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{}, + }, + ts: time.Time{}, + u: 0, + delivered: 0, + rtt: 0, + expectedRate: 0, + }, + { + name: "multiplicativeIncrease", + rc: rateController{ + s: stateIncrease, + rate: 100, + decreaseFactor: 0.9, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{}, + }, + ts: time.Time{}.Add(time.Second), + u: usageNormal, + delivered: 100, + rtt: 0, + expectedRate: 108, + }, + { + name: "minimumAdditiveIncrease", + rc: rateController{ + s: stateIncrease, + rate: 100_000, + decreaseFactor: 0.9, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{ + average: 100_000, + }, + }, + ts: time.Time{}.Add(time.Second), + u: usageNormal, + delivered: 100_000, + rtt: 20 * time.Millisecond, + expectedRate: 101_000, + }, + { + name: "additiveIncrease", + rc: rateController{ + s: stateIncrease, + rate: 1_000_000, + decreaseFactor: 0.9, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{ + average: 1_000_000, + }, + }, + ts: time.Time{}.Add(time.Second), + u: usageNormal, + delivered: 1_000_000, + rtt: 2000 * time.Millisecond, + expectedRate: 1_004166, + }, + { + name: "minimumAdditiveIncreaseAppLimited", + rc: rateController{ + s: stateIncrease, + rate: 100_000, + decreaseFactor: 0.9, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{ + average: 100_000, + }, + }, + ts: time.Time{}.Add(time.Second), + u: usageNormal, + delivered: 50_000, + rtt: 20 * time.Millisecond, + expectedRate: 100_000, + }, + { + name: "additiveIncreaseAppLimited", + rc: rateController{ + s: stateIncrease, + rate: 1_000_000, + decreaseFactor: 0.9, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{ + average: 1_000_000, + }, + }, + ts: time.Time{}.Add(time.Second), + u: usageNormal, + delivered: 100_000, + rtt: 2000 * time.Millisecond, + expectedRate: 1_000_000, + }, + { + name: "decrease", + rc: rateController{ + s: stateDecrease, + rate: 1_000_000, + decreaseFactor: 0.9, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{ + average: 1_000_000, + }, + }, + ts: time.Time{}.Add(time.Second), + u: usageOver, + delivered: 1_000_000, + rtt: 2000 * time.Millisecond, + expectedRate: 900_000, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + res := tc.rc.update(tc.ts, tc.u, tc.delivered, tc.rtt) + assert.Equal(t, tc.expectedRate, res) + }) + } +} diff --git a/pkg/bwe/rtt.go b/pkg/bwe/rtt.go new file mode 100644 index 00000000..ad448f81 --- /dev/null +++ b/pkg/bwe/rtt.go @@ -0,0 +1,8 @@ +package bwe + +import "time" + +func MeasureRTT(reportSent, reportReceived, latestAckedSent, latestAckedArrival time.Time) time.Duration { + pendingTime := reportSent.Sub(latestAckedArrival) + return reportReceived.Sub(latestAckedSent) - pendingTime +} diff --git a/pkg/bwe/send_side_bwe.go b/pkg/bwe/send_side_bwe.go new file mode 100644 index 00000000..91855469 --- /dev/null +++ b/pkg/bwe/send_side_bwe.go @@ -0,0 +1,50 @@ +package bwe + +import ( + "time" + + "github.com/pion/logging" +) + +type SendSideController struct { + log logging.LeveledLogger + dre *deliveryRateEstimator + lbc *LossRateController + drc *DelayRateController + rate int +} + +func NewSendSideController(initialRate, minRate, maxRate int) *SendSideController { + return &SendSideController{ + log: logging.NewDefaultLoggerFactory().NewLogger("bwe_send_side_controller"), + dre: newDeliveryRateEstimator(time.Second), + lbc: NewLossRateController(initialRate, minRate, maxRate), + drc: NewDelayRateController(initialRate), + rate: initialRate, + } +} + +func (c *SendSideController) OnAcks(arrival time.Time, rtt time.Duration, acks []Acknowledgment) int { + if len(acks) == 0 { + return c.rate + } + + for _, ack := range acks { + if ack.Arrived { + c.lbc.OnPacketAcked() + if !ack.Arrival.IsZero() { + c.dre.OnPacketAcked(ack.Arrival, int(ack.Size)) + c.drc.OnPacketAcked(ack) + } + } else { + c.lbc.OnPacketLost() + } + } + + delivered := c.dre.GetRate() + lossTarget := c.lbc.Update(delivered) + delayTarget := c.drc.Update(arrival, delivered, rtt) + c.rate = min(lossTarget, delayTarget) + c.log.Tracef("rtt=%v, delivered=%v, lossTarget=%v, delayTarget=%v, target=%v", rtt.Nanoseconds(), delivered, lossTarget, delayTarget, c.rate) + return c.rate +} diff --git a/pkg/bwe/state.go b/pkg/bwe/state.go new file mode 100644 index 00000000..167a1f64 --- /dev/null +++ b/pkg/bwe/state.go @@ -0,0 +1,59 @@ +package bwe + +import "fmt" + +type state int + +const ( + stateDecrease state = -1 + stateHold state = 0 + stateIncrease state = 1 +) + +func (s state) transition(u usage) state { + switch s { + case stateHold: + switch u { + case usageOver: + return stateDecrease + case usageNormal: + return stateIncrease + case usageUnder: + return stateHold + } + + case stateIncrease: + switch u { + case usageOver: + return stateDecrease + case usageNormal: + return stateIncrease + case usageUnder: + return stateHold + } + + case stateDecrease: + switch u { + case usageOver: + return stateDecrease + case usageNormal: + return stateHold + case usageUnder: + return stateHold + } + } + return stateIncrease +} + +func (s state) String() string { + switch s { + case stateIncrease: + return "increase" + case stateDecrease: + return "decrease" + case stateHold: + return "hold" + default: + return fmt.Sprintf("invalid state: %d", s) + } +} diff --git a/pkg/bwe/state_test.go b/pkg/bwe/state_test.go new file mode 100644 index 00000000..ecc8d1a0 --- /dev/null +++ b/pkg/bwe/state_test.go @@ -0,0 +1,27 @@ +package bwe + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestState(t *testing.T) { + t.Run("hold", func(t *testing.T) { + assert.Equal(t, stateDecrease, stateHold.transition(usageOver)) + assert.Equal(t, stateIncrease, stateHold.transition(usageNormal)) + assert.Equal(t, stateHold, stateHold.transition(usageUnder)) + }) + + t.Run("increase", func(t *testing.T) { + assert.Equal(t, stateDecrease, stateIncrease.transition(usageOver)) + assert.Equal(t, stateIncrease, stateIncrease.transition(usageNormal)) + assert.Equal(t, stateHold, stateIncrease.transition(usageUnder)) + }) + + t.Run("decrease", func(t *testing.T) { + assert.Equal(t, stateDecrease, stateDecrease.transition(usageOver)) + assert.Equal(t, stateHold, stateDecrease.transition(usageNormal)) + assert.Equal(t, stateHold, stateDecrease.transition(usageUnder)) + }) +} diff --git a/pkg/bwe/usage.go b/pkg/bwe/usage.go new file mode 100644 index 00000000..7d0a0b48 --- /dev/null +++ b/pkg/bwe/usage.go @@ -0,0 +1,24 @@ +package bwe + +import "fmt" + +type usage int + +const ( + usageUnder usage = -1 + usageNormal usage = 0 + usageOver usage = 1 +) + +func (u usage) String() string { + switch u { + case usageOver: + return "overuse" + case usageUnder: + return "underuse" + case usageNormal: + return "normal" + default: + return fmt.Sprintf("invalid usage: %d", u) + } +} diff --git a/pkg/cc/interceptor.go b/pkg/cc/interceptor.go index 252ab29f..bbae7826 100644 --- a/pkg/cc/interceptor.go +++ b/pkg/cc/interceptor.go @@ -7,6 +7,7 @@ package cc import ( "github.com/pion/interceptor" + "github.com/pion/interceptor/pkg/ccfb" "github.com/pion/interceptor/pkg/gcc" "github.com/pion/rtcp" ) @@ -38,6 +39,7 @@ type InterceptorFactory struct { opts []Option bweFactory func() (BandwidthEstimator, error) addPeerConnection NewPeerConnectionCallback + ccfbFactory *ccfb.InterceptorFactory } // NewInterceptor returns a new CC interceptor factory @@ -47,10 +49,15 @@ func NewInterceptor(factory BandwidthEstimatorFactory, opts ...Option) (*Interce return gcc.NewSendSideBWE() } } + ccfbFactory, err := ccfb.NewInterceptor() + if err != nil { + return nil, err + } return &InterceptorFactory{ opts: opts, bweFactory: factory, addPeerConnection: nil, + ccfbFactory: ccfbFactory, }, nil } @@ -69,12 +76,10 @@ func (f *InterceptorFactory) NewInterceptor(id string) (interceptor.Interceptor, i := &Interceptor{ NoOp: interceptor.NoOp{}, estimator: bwe, - feedback: make(chan []rtcp.Packet), - close: make(chan struct{}), } for _, opt := range f.opts { - if err := opt(i); err != nil { + if err = opt(i); err != nil { return nil, err } } @@ -82,15 +87,18 @@ func (f *InterceptorFactory) NewInterceptor(id string) (interceptor.Interceptor, if f.addPeerConnection != nil { f.addPeerConnection(id, i.estimator) } - return i, nil + + ccfb, err := f.ccfbFactory.NewInterceptor(id) + if err != nil { + return nil, err + } + return interceptor.NewChain([]interceptor.Interceptor{ccfb, i}), nil } // Interceptor implements Google Congestion Control type Interceptor struct { interceptor.NoOp estimator BandwidthEstimator - feedback chan []rtcp.Packet - close chan struct{} } // BindRTCPReader lets you modify any incoming RTCP packets. It is called once diff --git a/pkg/ccfb/ccfb_receiver.go b/pkg/ccfb/ccfb_receiver.go new file mode 100644 index 00000000..dd11198c --- /dev/null +++ b/pkg/ccfb/ccfb_receiver.go @@ -0,0 +1,59 @@ +package ccfb + +import ( + "time" + + "github.com/pion/interceptor/internal/ntp" + "github.com/pion/rtcp" +) + +type acknowledgement struct { + seqNr uint16 + arrived bool + arrival time.Time + ecn rtcp.ECN +} + +func convertCCFB(ts time.Time, feedback *rtcp.CCFeedbackReport) (time.Time, map[uint32][]acknowledgement) { + if feedback == nil { + return time.Time{}, nil + } + result := map[uint32][]acknowledgement{} + referenceTime := ntp.ToTime32(feedback.ReportTimestamp, ts) + for _, rb := range feedback.ReportBlocks { + result[rb.MediaSSRC] = convertMetricBlock(referenceTime, rb.BeginSequence, rb.MetricBlocks) + } + return referenceTime, result +} + +func convertMetricBlock(reference time.Time, seqNrOffset uint16, blocks []rtcp.CCFeedbackMetricBlock) []acknowledgement { + reports := make([]acknowledgement, len(blocks)) + for i, mb := range blocks { + if mb.Received { + arrival := time.Time{} + + // RFC 8888 states: If the measurement is unavailable or if the + // arrival time of the RTP packet is after the time represented by + // the RTS field, then an ATO value of 0x1FFF MUST be reported for + // the packet. In that case, we set a zero time.Time value. + if mb.ArrivalTimeOffset != 0x1FFF { + delta := time.Duration((float64(mb.ArrivalTimeOffset) / 1024.0) * float64(time.Second)) + arrival = reference.Add(-delta) + } + reports[i] = acknowledgement{ + seqNr: seqNrOffset + uint16(i), // nolint:gosec + arrived: true, + arrival: arrival, + ecn: mb.ECN, + } + } else { + reports[i] = acknowledgement{ + seqNr: seqNrOffset + uint16(i), // nolint:gosec + arrived: false, + arrival: time.Time{}, + ecn: 0, + } + } + } + return reports +} diff --git a/pkg/ccfb/ccfb_receiver_test.go b/pkg/ccfb/ccfb_receiver_test.go new file mode 100644 index 00000000..18a248c1 --- /dev/null +++ b/pkg/ccfb/ccfb_receiver_test.go @@ -0,0 +1,193 @@ +package ccfb + +import ( + "fmt" + "testing" + "time" + + "github.com/pion/interceptor/internal/ntp" + "github.com/pion/rtcp" + "github.com/stretchr/testify/assert" +) + +func TestConvertCCFB(t *testing.T) { + timeZero := time.Now() + cases := []struct { + ts time.Time + feedback *rtcp.CCFeedbackReport + expect map[uint32][]acknowledgement + expectTS time.Time + }{ + {}, + { + ts: timeZero.Add(2 * time.Second), + feedback: &rtcp.CCFeedbackReport{ + SenderSSRC: 1, + ReportBlocks: []rtcp.CCFeedbackReportBlock{ + { + MediaSSRC: 2, + BeginSequence: 17, + MetricBlocks: []rtcp.CCFeedbackMetricBlock{ + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 512, + }, + }, + }, + }, + ReportTimestamp: ntp.ToNTP32(timeZero.Add(time.Second)), + }, + expect: map[uint32][]acknowledgement{ + 2: { + { + seqNr: 17, + arrived: true, + arrival: timeZero.Add(500 * time.Millisecond), + ecn: 0, + }, + }, + }, + expectTS: timeZero.Add(time.Second), + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + resTS, res := convertCCFB(tc.ts, tc.feedback) + + assert.InDelta(t, tc.expectTS.UnixNano(), resTS.UnixNano(), float64(time.Millisecond.Nanoseconds())) + + // Can't directly check equality since arrival timestamp conversions + // may be slightly off due to ntp conversions. + assert.Equal(t, len(tc.expect), len(res)) + for i, acks := range tc.expect { + for j, ack := range acks { + assert.Equal(t, ack.seqNr, res[i][j].seqNr) + assert.Equal(t, ack.arrived, res[i][j].arrived) + assert.Equal(t, ack.ecn, res[i][j].ecn) + assert.InDelta(t, ack.arrival.UnixNano(), res[i][j].arrival.UnixNano(), float64(time.Millisecond.Nanoseconds())) + } + } + }) + } +} + +func TestConvertMetricBlock(t *testing.T) { + cases := []struct { + ts time.Time + reference time.Time + seqNrOffset uint16 + blocks []rtcp.CCFeedbackMetricBlock + expected []acknowledgement + }{ + { + ts: time.Time{}, + reference: time.Time{}, + seqNrOffset: 0, + blocks: []rtcp.CCFeedbackMetricBlock{}, + expected: []acknowledgement{}, + }, + { + ts: time.Time{}.Add(2 * time.Second), + reference: time.Time{}.Add(time.Second), + seqNrOffset: 3, + blocks: []rtcp.CCFeedbackMetricBlock{ + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 512, + }, + { + Received: false, + ECN: 0, + ArrivalTimeOffset: 0, + }, + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 0, + }, + }, + expected: []acknowledgement{ + { + seqNr: 3, + arrived: true, + arrival: time.Time{}.Add(500 * time.Millisecond), + ecn: 0, + }, + { + seqNr: 4, + arrived: false, + arrival: time.Time{}, + ecn: 0, + }, + { + seqNr: 5, + arrived: true, + arrival: time.Time{}.Add(time.Second), + ecn: 0, + }, + }, + }, + { + ts: time.Time{}.Add(2 * time.Second), + reference: time.Time{}.Add(time.Second), + seqNrOffset: 3, + blocks: []rtcp.CCFeedbackMetricBlock{ + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 512, + }, + { + Received: false, + ECN: 0, + ArrivalTimeOffset: 0, + }, + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 0, + }, + { + Received: true, + ECN: 0, + ArrivalTimeOffset: 0x1FFF, + }, + }, + expected: []acknowledgement{ + { + seqNr: 3, + arrived: true, + arrival: time.Time{}.Add(500 * time.Millisecond), + ecn: 0, + }, + { + seqNr: 4, + arrived: false, + arrival: time.Time{}, + ecn: 0, + }, + { + seqNr: 5, + arrived: true, + arrival: time.Time{}.Add(time.Second), + ecn: 0, + }, + { + seqNr: 6, + arrived: true, + arrival: time.Time{}, + ecn: 0, + }, + }, + }, + } + + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + res := convertMetricBlock(tc.reference, tc.seqNrOffset, tc.blocks) + assert.Equal(t, tc.expected, res) + }) + } +} diff --git a/pkg/ccfb/duplicate_ack_filter.go b/pkg/ccfb/duplicate_ack_filter.go new file mode 100644 index 00000000..79f8f6db --- /dev/null +++ b/pkg/ccfb/duplicate_ack_filter.go @@ -0,0 +1,29 @@ +package ccfb + +// DuplicateAckFilter is a helper to remove duplicate acks from a Report. +type DuplicateAckFilter struct { + highestAckedBySSRC map[uint32]int64 +} + +// NewDuplicateAckFilter creates a new DuplicateAckFilter +func NewDuplicateAckFilter() *DuplicateAckFilter { + return &DuplicateAckFilter{ + highestAckedBySSRC: make(map[uint32]int64), + } +} + +// Filter filters duplicate acks. It filters out all acks for packets with a +// sequence number smaller than the highest seen sequence number for each SSRC. +func (f *DuplicateAckFilter) Filter(reports Report) { + for ssrc, prs := range reports.SSRCToPacketReports { + n := 0 + for _, report := range prs { + if highest, ok := f.highestAckedBySSRC[ssrc]; !ok || report.SeqNr > highest { + f.highestAckedBySSRC[ssrc] = report.SeqNr + prs[n] = report + n++ + } + } + reports.SSRCToPacketReports[ssrc] = prs[:n] + } +} diff --git a/pkg/ccfb/duplicate_ack_filter_test.go b/pkg/ccfb/duplicate_ack_filter_test.go new file mode 100644 index 00000000..20e4d6f8 --- /dev/null +++ b/pkg/ccfb/duplicate_ack_filter_test.go @@ -0,0 +1,106 @@ +package ccfb + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDuplicateAckFilter(t *testing.T) { + cases := []struct { + in []Report + expect []Report + }{ + { + in: []Report{}, + expect: []Report{}, + }, + { + in: []Report{ + { + SSRCToPacketReports: map[uint32][]PacketReport{ + 0: {}, + }, + }, + }, + expect: []Report{ + { + Arrival: time.Time{}, + Departure: time.Time{}, + SSRCToPacketReports: map[uint32][]PacketReport{ + 0: {}, + }, + }, + }, + }, + { + in: []Report{ + { + SSRCToPacketReports: map[uint32][]PacketReport{ + 0: { + { + SeqNr: 1, + }, + { + SeqNr: 2, + }, + }, + }, + }, + { + SSRCToPacketReports: map[uint32][]PacketReport{ + 0: { + { + SeqNr: 1, + }, + { + SeqNr: 2, + }, + { + SeqNr: 3, + }, + }, + }, + }, + }, + expect: []Report{ + { + Arrival: time.Time{}, + Departure: time.Time{}, + SSRCToPacketReports: map[uint32][]PacketReport{ + 0: { + { + SeqNr: 1, + }, + { + SeqNr: 2, + }, + }, + }, + }, + { + Arrival: time.Time{}, + Departure: time.Time{}, + SSRCToPacketReports: map[uint32][]PacketReport{ + 0: { + { + SeqNr: 3, + }, + }, + }, + }, + }, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + daf := NewDuplicateAckFilter() + for i, m := range tc.in { + daf.Filter(m) + assert.Equal(t, tc.expect[i], m) + } + }) + } +} diff --git a/pkg/ccfb/history.go b/pkg/ccfb/history.go new file mode 100644 index 00000000..9c144702 --- /dev/null +++ b/pkg/ccfb/history.go @@ -0,0 +1,110 @@ +package ccfb + +import ( + "container/list" + "errors" + "sync" + "time" + + "github.com/pion/interceptor/internal/sequencenumber" + "github.com/pion/rtcp" +) + +var errSequenceNumberWentBackwards = errors.New("sequence number went backwards") + +// PacketReport contains departure and arrival information about an acknowledged +// packet. +type PacketReport struct { + SeqNr int64 + Size int + Departure time.Time + Arrived bool + Arrival time.Time + ECN rtcp.ECN +} + +type sentPacket struct { + seqNr int64 + size int + departure time.Time +} + +type historyList struct { + lock sync.Mutex + size int + evictList *list.List + seqNrToPacket map[int64]*list.Element + sentSeqNr *sequencenumber.Unwrapper + ackedSeqNr *sequencenumber.Unwrapper +} + +func newHistoryList(size int) *historyList { + return &historyList{ + lock: sync.Mutex{}, + size: size, + evictList: list.New(), + seqNrToPacket: make(map[int64]*list.Element), + sentSeqNr: &sequencenumber.Unwrapper{}, + ackedSeqNr: &sequencenumber.Unwrapper{}, + } +} + +func (h *historyList) add(seqNr uint16, size int, departure time.Time) error { + h.lock.Lock() + defer h.lock.Unlock() + + sn := h.sentSeqNr.Unwrap(seqNr) + last := h.evictList.Back() + if last != nil { + if p, ok := last.Value.(sentPacket); ok && sn < p.seqNr { + return errSequenceNumberWentBackwards + } + } + ent := h.evictList.PushBack(sentPacket{ + seqNr: sn, + size: size, + departure: departure, + }) + h.seqNrToPacket[sn] = ent + + if h.evictList.Len() > h.size { + h.removeOldest() + } + return nil +} + +// Must be called while holding the lock +func (h *historyList) removeOldest() { + if ent := h.evictList.Front(); ent != nil { + v := h.evictList.Remove(ent) + if sp, ok := v.(sentPacket); ok { + delete(h.seqNrToPacket, sp.seqNr) + } + } +} + +func (h *historyList) getReportForAck(acks []acknowledgement) []PacketReport { + h.lock.Lock() + defer h.lock.Unlock() + + reports := []PacketReport{} + for _, pr := range acks { + sn := h.ackedSeqNr.Unwrap(pr.seqNr) + ent, ok := h.seqNrToPacket[sn] + // Ignore report for unknown packets (migth have been dropped from + // history) + if ok { + if ack, ok := ent.Value.(sentPacket); ok { + reports = append(reports, PacketReport{ + SeqNr: sn, + Size: ack.size, + Departure: ack.departure, + Arrived: pr.arrived, + Arrival: pr.arrival, + ECN: pr.ecn, + }) + } + } + } + return reports +} diff --git a/pkg/ccfb/history_test.go b/pkg/ccfb/history_test.go new file mode 100644 index 00000000..c500242e --- /dev/null +++ b/pkg/ccfb/history_test.go @@ -0,0 +1,114 @@ +package ccfb + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestHistory(t *testing.T) { + t.Run("errorOnDecreasingSeqNr", func(t *testing.T) { + h := newHistoryList(200) + assert.NoError(t, h.add(10, 1200, time.Now())) + assert.NoError(t, h.add(11, 1200, time.Now())) + assert.Error(t, h.add(9, 1200, time.Now())) + }) + + t.Run("getReportForAck", func(t *testing.T) { + cases := []struct { + outgoing []struct { + seqNr uint16 + size int + ts time.Time + } + acks []acknowledgement + expectedReport []PacketReport + expectedHistorySize int + }{ + { + outgoing: []struct { + seqNr uint16 + size int + ts time.Time + }{}, + acks: []acknowledgement{}, + expectedReport: []PacketReport{}, + expectedHistorySize: 0, + }, + { + outgoing: []struct { + seqNr uint16 + size int + ts time.Time + }{ + {0, 1200, time.Time{}.Add(1 * time.Millisecond)}, + {1, 1200, time.Time{}.Add(2 * time.Millisecond)}, + {2, 1200, time.Time{}.Add(3 * time.Millisecond)}, + {3, 1200, time.Time{}.Add(4 * time.Millisecond)}, + }, + acks: []acknowledgement{}, + expectedReport: []PacketReport{}, + expectedHistorySize: 4, + }, + { + outgoing: []struct { + seqNr uint16 + size int + ts time.Time + }{ + {0, 1200, time.Time{}.Add(1 * time.Millisecond)}, + {1, 1200, time.Time{}.Add(2 * time.Millisecond)}, + {2, 1200, time.Time{}.Add(3 * time.Millisecond)}, + {3, 1200, time.Time{}.Add(4 * time.Millisecond)}, + }, + acks: []acknowledgement{ + {1, true, time.Time{}.Add(3 * time.Millisecond), 0}, + {2, false, time.Time{}, 0}, + {3, true, time.Time{}.Add(5 * time.Millisecond), 0}, + }, + expectedReport: []PacketReport{ + {1, 1200, time.Time{}.Add(2 * time.Millisecond), true, time.Time{}.Add(3 * time.Millisecond), 0}, + {2, 1200, time.Time{}.Add(3 * time.Millisecond), false, time.Time{}, 0}, + {3, 1200, time.Time{}.Add(4 * time.Millisecond), true, time.Time{}.Add(5 * time.Millisecond), 0}, + }, + expectedHistorySize: 4, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + h := newHistoryList(200) + for _, op := range tc.outgoing { + assert.NoError(t, h.add(op.seqNr, op.size, op.ts)) + } + prl := h.getReportForAck(tc.acks) + assert.Equal(t, tc.expectedReport, prl) + assert.Equal(t, tc.expectedHistorySize, len(h.seqNrToPacket)) + assert.Equal(t, tc.expectedHistorySize, h.evictList.Len()) + }) + } + }) + + t.Run("garbageCollection", func(t *testing.T) { + h := newHistoryList(200) + + for i := uint16(0); i < 300; i++ { + assert.NoError(t, h.add(i, 1200, time.Time{}.Add(time.Duration(i)*time.Millisecond))) + } + + acks := []acknowledgement{} + for i := uint16(200); i < 290; i++ { + acks = append(acks, acknowledgement{ + seqNr: i, + arrived: true, + arrival: time.Time{}.Add(time.Duration(500+i) * time.Millisecond), + ecn: 0, + }) + } + prl := h.getReportForAck(acks) + assert.Len(t, prl, 90) + assert.Equal(t, 200, len(h.seqNrToPacket)) + assert.Equal(t, 200, h.evictList.Len()) + }) +} diff --git a/pkg/ccfb/interceptor.go b/pkg/ccfb/interceptor.go new file mode 100644 index 00000000..001fdaae --- /dev/null +++ b/pkg/ccfb/interceptor.go @@ -0,0 +1,231 @@ +// Package ccfb implements feedback aggregation for CCFB and TWCC packets. +package ccfb + +import ( + "sync" + "time" + + "github.com/pion/interceptor" + "github.com/pion/logging" + "github.com/pion/rtcp" + "github.com/pion/rtp" +) + +const transportCCURI = "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01" + +type ccfbAttributesKeyType uint32 + +// CCFBAttributesKey is the key which can be used to retrieve the Report objects +// from the interceptor.Attributes +const CCFBAttributesKey ccfbAttributesKeyType = iota + +// A Report contains Arrival and Departure (from the remote end) times of a RTCP +// feedback packet (CCFB or TWCC) and a list of PacketReport for all +// acknowledged packets that were still in the history. +type Report struct { + Arrival time.Time + Departure time.Time + SSRCToPacketReports map[uint32][]PacketReport +} + +type history interface { + add(seqNr uint16, size int, departure time.Time) error + getReportForAck([]acknowledgement) []PacketReport +} + +// Option can be used to set initial options on CCFB interceptors +type Option func(*Interceptor) error + +// HistorySize sets the size of the history of outgoing packets. +func HistorySize(size int) Option { + return func(i *Interceptor) error { + i.historySize = size + return nil + } +} + +func timeFactory(f func() time.Time) Option { + return func(i *Interceptor) error { + i.timestamp = f + return nil + } +} + +func historyFactory(f func(int) history) Option { + return func(i *Interceptor) error { + i.historyFactory = f + return nil + } +} + +func ccfbConverterFactory(f func(ts time.Time, feedback *rtcp.CCFeedbackReport) (time.Time, map[uint32][]acknowledgement)) Option { + return func(i *Interceptor) error { + i.convertCCFB = f + return nil + } +} + +func twccConverterFactory(f func(feedback *rtcp.TransportLayerCC) (time.Time, map[uint32][]acknowledgement)) Option { + return func(i *Interceptor) error { + i.convertTWCC = f + return nil + } +} + +// InterceptorFactory is a factory for CCFB interceptors +type InterceptorFactory struct { + opts []Option +} + +// NewInterceptor returns a new CCFB InterceptorFactory +func NewInterceptor(opts ...Option) (*InterceptorFactory, error) { + return &InterceptorFactory{ + opts: opts, + }, nil +} + +// NewInterceptor returns a new ccfb.Interceptor +func (f *InterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { + i := &Interceptor{ + NoOp: interceptor.NoOp{}, + lock: sync.Mutex{}, + log: logging.NewDefaultLoggerFactory().NewLogger("ccfb_interceptor"), + timestamp: time.Now, + convertCCFB: convertCCFB, + convertTWCC: convertTWCC, + ssrcToHistory: make(map[uint32]history), + historySize: 200, + historyFactory: func(size int) history { + return newHistoryList(size) + }, + } + for _, opt := range f.opts { + if err := opt(i); err != nil { + return nil, err + } + } + return i, nil +} + +// Interceptor implements a congestion control feedback receiver. It keeps track +// of outgoing packets and reads incoming feedback reports (CCFB or TWCC). For +// each incoming feedback report, it will add an entry to the interceptor +// attributes, which can be read from the `RTCPReader` +// (`webrtc.RTPSender.Read`). For each acknowledgement included in the feedback +// report and for which there still is an entry in the history of outgoing +// packets, a PacketReport will be added to the ccfb.Report map. The map +// contains a list of packets for each outgoing SSRC if CCFB is used. The map +// contains a single entry with SSRC=0 if TWCC is used. +type Interceptor struct { + interceptor.NoOp + lock sync.Mutex + log logging.LeveledLogger + timestamp func() time.Time + convertCCFB func(ts time.Time, feedback *rtcp.CCFeedbackReport) (time.Time, map[uint32][]acknowledgement) + convertTWCC func(feedback *rtcp.TransportLayerCC) (time.Time, map[uint32][]acknowledgement) + ssrcToHistory map[uint32]history + historySize int + historyFactory func(int) history +} + +// BindLocalStream implements interceptor.Interceptor. +func (i *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { + var twccHdrExtID uint8 + var useTWCC bool + for _, e := range info.RTPHeaderExtensions { + if e.URI == transportCCURI { + twccHdrExtID = uint8(e.ID) // nolint:gosec + useTWCC = true + break + } + } + + i.lock.Lock() + defer i.lock.Unlock() + + ssrc := info.SSRC + if useTWCC { + ssrc = 0 + } + i.ssrcToHistory[ssrc] = i.historyFactory(i.historySize) + + return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + i.lock.Lock() + defer i.lock.Unlock() + + // If we are using TWCC, we use the sequence number from the TWCC header + // extension and save all TWCC sequence numbers with the same SSRC (0). + // If we are not using TWCC, we save a history per SSRC and use the + // normal RTP sequence numbers. + ssrc := header.SSRC + seqNr := header.SequenceNumber + if useTWCC { + var twccHdrExt rtp.TransportCCExtension + if err := twccHdrExt.Unmarshal(header.GetExtension(twccHdrExtID)); err != nil { + i.log.Warnf("CCFB configured for TWCC, but failed to get TWCC header extension from outgoing packet. Falling back to saving history for CCFB feedback reports. err: %v", err) + if _, ok := i.ssrcToHistory[ssrc]; !ok { + i.ssrcToHistory[ssrc] = i.historyFactory(i.historySize) + } + } else { + seqNr = twccHdrExt.TransportSequence + ssrc = 0 + } + } + if err := i.ssrcToHistory[ssrc].add(seqNr, header.MarshalSize()+len(payload), i.timestamp()); err != nil { + return 0, err + } + return writer.Write(header, payload, attributes) + }) +} + +// BindRTCPReader implements interceptor.Interceptor. +func (i *Interceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader { + return interceptor.RTCPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + n, attr, err := reader.Read(b, a) + if err != nil { + return n, attr, err + } + now := i.timestamp() + + buf := make([]byte, n) + copy(buf, b[:n]) + + if attr == nil { + attr = make(interceptor.Attributes) + } + + res := []Report{} + + pkts, err := attr.GetRTCPPackets(buf) + if err != nil { + return n, attr, err + } + for _, pkt := range pkts { + var reportLists map[uint32][]acknowledgement + var reportDeparture time.Time + switch fb := pkt.(type) { + case *rtcp.CCFeedbackReport: + reportDeparture, reportLists = i.convertCCFB(now, fb) + case *rtcp.TransportLayerCC: + reportDeparture, reportLists = i.convertTWCC(fb) + default: + } + ssrcToPrl := map[uint32][]PacketReport{} + for ssrc, reportList := range reportLists { + prl := i.ssrcToHistory[ssrc].getReportForAck(reportList) + if _, ok := ssrcToPrl[ssrc]; !ok { + ssrcToPrl[ssrc] = prl + } else { + ssrcToPrl[ssrc] = append(ssrcToPrl[ssrc], prl...) + } + } + res = append(res, Report{ + Arrival: now, + Departure: reportDeparture, + SSRCToPacketReports: ssrcToPrl, + }) + } + attr.Set(CCFBAttributesKey, res) + return n, attr, err + }) +} diff --git a/pkg/ccfb/interceptor_test.go b/pkg/ccfb/interceptor_test.go new file mode 100644 index 00000000..4c912fdb --- /dev/null +++ b/pkg/ccfb/interceptor_test.go @@ -0,0 +1,327 @@ +package ccfb + +import ( + "fmt" + "testing" + "time" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/test" + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/stretchr/testify/assert" +) + +type mockHistoryAddEntry struct { + seqNr uint16 + size int + departure time.Time +} + +type mockHistory struct { + added []mockHistoryAddEntry + report []PacketReport +} + +// add implements history. +func (m *mockHistory) add(seqNr uint16, size int, departure time.Time) error { + m.added = append(m.added, mockHistoryAddEntry{ + seqNr: seqNr, + size: size, + departure: departure, + }) + return nil +} + +// getReportForAck implements history. +func (m *mockHistory) getReportForAck(_ []acknowledgement) []PacketReport { + return m.report +} + +func TestInterceptor(t *testing.T) { + mockTimestamp := time.Time{}.Add(17 * time.Second) + t.Run("writeRTP", func(t *testing.T) { + type addPkt struct { + pkt *rtp.Packet + ext *rtp.TransportCCExtension + } + cases := []struct { + add []addPkt + twcc bool + expect *mockHistory + }{ + { + add: []addPkt{}, + expect: &mockHistory{ + added: []mockHistoryAddEntry{}, + }, + }, + { + add: []addPkt{ + { + pkt: &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + SequenceNumber: 137, + }, + }, + }, + }, + expect: &mockHistory{ + added: []mockHistoryAddEntry{ + {137, 12, mockTimestamp}, + }, + }, + }, + { + add: []addPkt{ + { + pkt: &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + SequenceNumber: 137, + }, + }, + ext: &rtp.TransportCCExtension{ + TransportSequence: 16, + }, + }, + }, + twcc: true, + expect: &mockHistory{ + added: []mockHistoryAddEntry{ + {16, 20, mockTimestamp}, + }, + }, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + mt := func() time.Time { + return mockTimestamp + } + mh := &mockHistory{ + added: []mockHistoryAddEntry{}, + } + f, err := NewInterceptor( + historyFactory(func(_ int) history { + return mh + }), + timeFactory(mt), + ) + assert.NoError(t, err) + + i, err := f.NewInterceptor("") + assert.NoError(t, err) + + info := &interceptor.StreamInfo{} + if tc.twcc { + info.RTPHeaderExtensions = append(info.RTPHeaderExtensions, interceptor.RTPHeaderExtension{ + URI: transportCCURI, + ID: 2, + }) + } + stream := test.NewMockStream(info, i) + + for _, pkt := range tc.add { + if pkt.ext != nil { + ext, err := pkt.ext.Marshal() + assert.NoError(t, err) + assert.NoError(t, pkt.pkt.SetExtension(2, ext)) + } + assert.NoError(t, stream.WriteRTP(pkt.pkt)) + } + + assert.Equal(t, tc.expect, mh) + }) + } + }) + + t.Run("missingTWCCHeaderExtension", func(t *testing.T) { + mt := func() time.Time { + return mockTimestamp + } + mh := &mockHistory{ + added: []mockHistoryAddEntry{}, + } + f, err := NewInterceptor( + historyFactory(func(_ int) history { + return mh + }), + timeFactory(mt), + ) + assert.NoError(t, err) + + i, err := f.NewInterceptor("") + assert.NoError(t, err) + + info := &interceptor.StreamInfo{} + info.RTPHeaderExtensions = append(info.RTPHeaderExtensions, interceptor.RTPHeaderExtension{ + URI: transportCCURI, + ID: 2, + }) + stream := test.NewMockStream(info, i) + + err = stream.WriteRTP(&rtp.Packet{ + Header: rtp.Header{SequenceNumber: 3}, + Payload: []byte{}, + }) + assert.NoError(t, err) + assert.Equal(t, mh.added, []mockHistoryAddEntry{{ + seqNr: 3, + size: 12, + departure: mockTimestamp, + }}) + }) + + t.Run("readRTCP", func(t *testing.T) { + cases := []struct { + mh *mockHistory + rtcp rtcp.Packet + }{ + { + mh: &mockHistory{ + report: []PacketReport{}, + }, + rtcp: &rtcp.CCFeedbackReport{}, + }, + { + mh: &mockHistory{ + report: []PacketReport{ + { + SeqNr: 3, + Size: 12, + Departure: mockTimestamp, + Arrived: true, + Arrival: mockTimestamp, + ECN: 0, + }, + }, + }, + rtcp: &rtcp.CCFeedbackReport{}, + }, + { + mh: &mockHistory{ + report: []PacketReport{}, + }, + rtcp: &rtcp.TransportLayerCC{ + Header: rtcp.Header{ + Padding: false, + Count: rtcp.FormatTCC, + Type: rtcp.TypeTransportSpecificFeedback, + Length: 6, + }, + SenderSSRC: 1, + MediaSSRC: 2, + BaseSequenceNumber: 3, + PacketStatusCount: 0, + ReferenceTime: 5, + FbPktCount: 6, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.RunLengthChunk{ + Type: rtcp.RunLengthChunkType, + PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta, + RunLength: 3, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + {Type: 0, Delta: 0}, + {Type: 0, Delta: 0}, + {Type: 0, Delta: 0}, + }, + }, + }, + { + mh: &mockHistory{ + report: []PacketReport{ + { + SeqNr: 3, + Size: 12, + Departure: mockTimestamp, + Arrived: true, + Arrival: mockTimestamp, + ECN: 0, + }, + }, + }, + rtcp: &rtcp.TransportLayerCC{ + Header: rtcp.Header{ + Padding: false, + Count: rtcp.FormatTCC, + Type: rtcp.TypeTransportSpecificFeedback, + Length: 6, + }, + SenderSSRC: 0, + MediaSSRC: 0, + BaseSequenceNumber: 0, + PacketStatusCount: 0, + ReferenceTime: 0, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.RunLengthChunk{ + Type: rtcp.RunLengthChunkType, + PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta, + RunLength: 3, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + {Type: 0, Delta: 0}, + {Type: 0, Delta: 0}, + {Type: 0, Delta: 0}, + }, + }, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + mt := func() time.Time { + return mockTimestamp + } + mockCCFBConverter := func(_ time.Time, _ *rtcp.CCFeedbackReport) (time.Time, map[uint32][]acknowledgement) { + return mockTimestamp, map[uint32][]acknowledgement{ + 0: {}, + } + } + mockTWCCConverter := func(_ *rtcp.TransportLayerCC) (time.Time, map[uint32][]acknowledgement) { + return mockTimestamp, map[uint32][]acknowledgement{ + 0: {}, + } + } + f, err := NewInterceptor( + historyFactory(func(_ int) history { + return tc.mh + }), + timeFactory(mt), + ccfbConverterFactory(mockCCFBConverter), + twccConverterFactory(mockTWCCConverter), + ) + assert.NoError(t, err) + + i, err := f.NewInterceptor("") + assert.NoError(t, err) + + info := &interceptor.StreamInfo{} + if _, ok := tc.rtcp.(*rtcp.TransportLayerCC); ok { + info.RTPHeaderExtensions = append(info.RTPHeaderExtensions, interceptor.RTPHeaderExtension{ + URI: transportCCURI, + ID: 2, + }) + } + stream := test.NewMockStream(info, i) + + stream.ReceiveRTCP([]rtcp.Packet{tc.rtcp}) + + report := <-stream.ReadRTCP() + + assert.NoError(t, report.Err) + + prlsInterface, ok := report.Attr[CCFBAttributesKey] + assert.True(t, ok) + prls, ok := prlsInterface.([]Report) + assert.True(t, ok) + assert.Len(t, prls, 1) + assert.Equal(t, tc.mh.report, prls[0].SSRCToPacketReports[0]) + }) + } + }) +} diff --git a/pkg/ccfb/twcc_receiver.go b/pkg/ccfb/twcc_receiver.go new file mode 100644 index 00000000..98af8bde --- /dev/null +++ b/pkg/ccfb/twcc_receiver.go @@ -0,0 +1,88 @@ +package ccfb + +import ( + "time" + + "github.com/pion/rtcp" +) + +func convertTWCC(feedback *rtcp.TransportLayerCC) (time.Time, map[uint32][]acknowledgement) { + if feedback == nil { + return time.Time{}, nil + } + var acks []acknowledgement + + nextTimestamp := time.Time{}.Add(time.Duration(feedback.ReferenceTime) * 64 * time.Millisecond) + reportDeparture := nextTimestamp + recvDeltaIndex := 0 + + offset := 0 + for _, pc := range feedback.PacketChunks { + switch chunk := pc.(type) { + case *rtcp.RunLengthChunk: + for i := uint16(0); i < chunk.RunLength; i++ { + seqNr := feedback.BaseSequenceNumber + uint16(offset) // nolint:gosec + offset++ + switch chunk.PacketStatusSymbol { + case rtcp.TypeTCCPacketNotReceived: + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: false, + arrival: time.Time{}, + ecn: 0, + }) + case rtcp.TypeTCCPacketReceivedSmallDelta, rtcp.TypeTCCPacketReceivedLargeDelta: + delta := feedback.RecvDeltas[recvDeltaIndex] + nextTimestamp = nextTimestamp.Add(time.Duration(delta.Delta) * time.Microsecond) + recvDeltaIndex++ + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: true, + arrival: nextTimestamp, + ecn: 0, + }) + case rtcp.TypeTCCPacketReceivedWithoutDelta: + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: true, + arrival: time.Time{}, + ecn: 0, + }) + } + } + case *rtcp.StatusVectorChunk: + for _, s := range chunk.SymbolList { + seqNr := feedback.BaseSequenceNumber + uint16(offset) // nolint:gosec + offset++ + switch s { + case rtcp.TypeTCCPacketNotReceived: + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: false, + arrival: time.Time{}, + ecn: 0, + }) + case rtcp.TypeTCCPacketReceivedSmallDelta, rtcp.TypeTCCPacketReceivedLargeDelta: + delta := feedback.RecvDeltas[recvDeltaIndex] + nextTimestamp = nextTimestamp.Add(time.Duration(delta.Delta) * time.Microsecond) + recvDeltaIndex++ + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: true, + arrival: nextTimestamp, + ecn: 0, + }) + case rtcp.TypeTCCPacketReceivedWithoutDelta: + acks = append(acks, acknowledgement{ + seqNr: seqNr, + arrived: true, + arrival: time.Time{}, + ecn: 0, + }) + } + } + } + } + + return reportDeparture, map[uint32][]acknowledgement{0: acks} +} diff --git a/pkg/ccfb/twcc_receiver_test.go b/pkg/ccfb/twcc_receiver_test.go new file mode 100644 index 00000000..8c820041 --- /dev/null +++ b/pkg/ccfb/twcc_receiver_test.go @@ -0,0 +1,125 @@ +package ccfb + +import ( + "fmt" + "testing" + "time" + + "github.com/pion/rtcp" + "github.com/stretchr/testify/assert" +) + +func TestConvertTWCC(t *testing.T) { + // timeZero := time.Now() + cases := []struct { + feedback *rtcp.TransportLayerCC + expect map[uint32][]acknowledgement + expectTS time.Time + }{ + {}, + { + // ts: timeZero.Add(2 * time.Second), + feedback: &rtcp.TransportLayerCC{ + SenderSSRC: 1, + MediaSSRC: 2, + BaseSequenceNumber: 178, + PacketStatusCount: 0, + ReferenceTime: 3, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{}, + RecvDeltas: []*rtcp.RecvDelta{}, + }, + expect: map[uint32][]acknowledgement{ + 0: nil, + }, + expectTS: time.Time{}.Add(3 * 64 * time.Millisecond), + }, + { + // ts: timeZero.Add(2 * time.Second), + feedback: &rtcp.TransportLayerCC{ + SenderSSRC: 1, + MediaSSRC: 2, + BaseSequenceNumber: 178, + PacketStatusCount: 18, + ReferenceTime: 3, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.RunLengthChunk{ + PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta, + RunLength: 3, + }, + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeOneBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedLargeDelta, + rtcp.TypeTCCPacketReceivedLargeDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 1000}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 1000}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 1000}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 1000}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 1000}, + {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 1000}, + {Type: rtcp.TypeTCCPacketReceivedLargeDelta, Delta: 1000}, + {Type: rtcp.TypeTCCPacketReceivedLargeDelta, Delta: 1000}, + }, + }, + expect: map[uint32][]acknowledgement{ + 0: { + // first run length chunk + {seqNr: 178, arrived: true, arrival: time.Time{}.Add(3*64*time.Millisecond + 1*time.Millisecond), ecn: 0}, + {seqNr: 179, arrived: true, arrival: time.Time{}.Add(3*64*time.Millisecond + 2*time.Millisecond), ecn: 0}, + {seqNr: 180, arrived: true, arrival: time.Time{}.Add(3*64*time.Millisecond + 3*time.Millisecond), ecn: 0}, + + // first status vector chunk + {seqNr: 181, arrived: true, arrival: time.Time{}.Add(3*64*time.Millisecond + 4*time.Millisecond), ecn: 0}, + {seqNr: 182, arrived: true, arrival: time.Time{}.Add(3*64*time.Millisecond + 5*time.Millisecond), ecn: 0}, + {seqNr: 183, arrived: true, arrival: time.Time{}.Add(3*64*time.Millisecond + 6*time.Millisecond), ecn: 0}, + {seqNr: 184, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 185, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 186, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 187, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 188, arrived: false, arrival: time.Time{}, ecn: 0}, + + // second status vector chunk + {seqNr: 189, arrived: true, arrival: time.Time{}.Add(3*64*time.Millisecond + 7*time.Millisecond), ecn: 0}, + {seqNr: 190, arrived: true, arrival: time.Time{}.Add(3*64*time.Millisecond + 8*time.Millisecond), ecn: 0}, + {seqNr: 191, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 192, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 193, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 194, arrived: false, arrival: time.Time{}, ecn: 0}, + {seqNr: 195, arrived: false, arrival: time.Time{}, ecn: 0}, + }, + }, + expectTS: time.Time{}.Add(3 * 64 * time.Millisecond), + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + resTS, res := convertTWCC(tc.feedback) + assert.Equal(t, tc.expect, res) + assert.Equal(t, tc.expectTS, resTS) + }) + } +}