WIP - improve start and end of speech detection
Signed-off-by: Ettore Di Giacinto <[email protected]>
mudler committed Jan 9, 2025
1 parent a6c09ac commit 2d70ef6
Showing 1 changed file with 127 additions and 120 deletions.
core/http/endpoints/openai/realtime.go
@@ -497,156 +497,163 @@ type VADState int
 const (
 	StateSilence VADState = iota
 	StateSpeaking
-	StateTrailingSilence
 )
 
-// handle VAD (Voice Activity Detection)
-func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, c *websocket.Conn, done chan struct{}) {
+const (
+	// tune these thresholds to taste
+	SpeechFramesThreshold  = 3 // must see X consecutive speech results to confirm "start"
+	SilenceFramesThreshold = 5 // must see X consecutive silence results to confirm "end"
+)
+
+// handleVAD is a goroutine that listens for audio data from the client,
+// runs VAD on the audio data, and commits utterances to the conversation
+func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
 	vadContext, cancel := context.WithCancel(context.Background())
-	//var startListening time.Time
 
 	go func() {
 		<-done
 		cancel()
 	}()
 
-	vadState := VADState(StateSilence)
-	segments := []*proto.VADSegment{}
-	timeListening := time.Now()
+	ticker := time.NewTicker(300 * time.Millisecond)
+	defer ticker.Stop()
+
+	var (
+		lastSegmentCount int
+		timeOfLastNewSeg time.Time
+		speaking         bool
+	)
 
-	// Implement VAD logic here
-	// For brevity, this is a placeholder
-	// When VAD detects end of speech, generate a response
-	// TODO: use session.ModelInterface to handle VAD and cut audio and detect when to process that
 	for {
 		select {
 		case <-done:
 			return
-		default:
-			// Check if there's audio data to process
-			session.AudioBufferLock.Lock()
-
-			if len(session.InputAudioBuffer) > 0 {
-
-				if vadState == StateTrailingSilence {
-					log.Debug().Msgf("VAD detected speech that we can process")
-
-					// Commit the audio buffer as a conversation item
-					item := &Item{
-						ID:     generateItemID(),
-						Object: "realtime.item",
-						Type:   "message",
-						Status: "completed",
-						Role:   "user",
-						Content: []ConversationContent{
-							{
-								Type:  "input_audio",
-								Audio: base64.StdEncoding.EncodeToString(session.InputAudioBuffer),
-							},
-						},
-					}
-
-					// Add item to conversation
-					conversation.Lock.Lock()
-					conversation.Items = append(conversation.Items, item)
-					conversation.Lock.Unlock()
-
-					// Reset InputAudioBuffer
-					session.InputAudioBuffer = nil
-					session.AudioBufferLock.Unlock()
-
-					// Send item.created event
-					sendEvent(c, OutgoingMessage{
-						Type: "conversation.item.created",
-						Item: item,
-					})
-
-					vadState = StateSilence
-					segments = []*proto.VADSegment{}
-					// Generate a response
-					generateResponse(cfg, evaluator, session, conversation, ResponseCreate{}, c, websocket.TextMessage)
-					continue
-				}
-
-				adata := sound.BytesToInt16sLE(session.InputAudioBuffer)
-
-				// Resample from 24kHz to 16kHz
-				// adata = sound.ResampleInt16(adata, 24000, 16000)
-
-				soundIntBuffer := &audio.IntBuffer{
-					Format: &audio.Format{SampleRate: 16000, NumChannels: 1},
-				}
-				soundIntBuffer.Data = sound.ConvertInt16ToInt(adata)
-
-				/* if len(adata) < 16000 {
-					log.Debug().Msgf("audio length too small %d", len(session.InputAudioBuffer))
-					session.AudioBufferLock.Unlock()
-					continue
-				} */
-				float32Data := soundIntBuffer.AsFloat32Buffer().Data
-
-				// TODO: testing wav decoding
-				// dec := wav.NewDecoder(bytes.NewReader(session.InputAudioBuffer))
-				// buf, err := dec.FullPCMBuffer()
-				// if err != nil {
-				// 	//log.Error().Msgf("failed to process audio: %s", err.Error())
-				// 	sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
-				// 	session.AudioBufferLock.Unlock()
-				// 	continue
-				// }
-
-				//float32Data = buf.AsFloat32Buffer().Data
-
-				resp, err := session.ModelInterface.VAD(vadContext, &proto.VADRequest{
-					Audio: float32Data,
-				})
-				if err != nil {
-					log.Error().Msgf("failed to process audio: %s", err.Error())
-					sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
-					session.AudioBufferLock.Unlock()
-					continue
-				}
-
-				if len(resp.Segments) == 0 {
-					log.Debug().Msg("VAD detected no speech activity")
-					log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer))
-					if len(session.InputAudioBuffer) > 16000 {
-						session.InputAudioBuffer = nil
-						segments = []*proto.VADSegment{}
-					}
-
-					log.Debug().Msgf("audio length(after) %d", len(session.InputAudioBuffer))
-				} else if (len(resp.Segments) != len(segments)) && vadState == StateSpeaking {
-					// We have new segments, but we are still speaking
-					// We need to wait for the trailing silence
-					segments = resp.Segments
-				} else if (len(resp.Segments) == len(segments)) && vadState == StateSpeaking {
-					// We have the same number of segments, but we are still speaking
-					// We need to check if we are in this state for long enough, update the timer
-
-					// Check if we have been listening for too long
-					if time.Since(timeListening) > sendToVADDelay {
-						vadState = StateTrailingSilence
-					} else {
-						timeListening = timeListening.Add(time.Since(timeListening))
-					}
-				} else {
-					log.Debug().Msg("VAD detected speech activity")
-					vadState = StateSpeaking
-					segments = resp.Segments
-				}
-
-				session.AudioBufferLock.Unlock()
-			} else {
-				session.AudioBufferLock.Unlock()
-			}
+		case <-ticker.C:
+			// 1) Copy the entire buffer
+			session.AudioBufferLock.Lock()
+			allAudio := make([]byte, len(session.InputAudioBuffer))
+			copy(allAudio, session.InputAudioBuffer)
+			session.AudioBufferLock.Unlock()
+
+			// 2) If there's no audio at all, just continue
+			if len(allAudio) == 0 {
+				continue
+			}
+
+			// 3) Run VAD on the entire audio so far
+			segments, err := runVAD(vadContext, session, allAudio)
+			if err != nil {
+				log.Error().Msgf("failed to process audio: %s", err.Error())
+				sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
+				// handle or log error, continue
+				continue
+			}
+
+			segCount := len(segments)
+
+			if len(segments) == 0 && !speaking && time.Since(timeOfLastNewSeg) > 1*time.Second {
+				// no speech detected, and we haven't seen a new segment in > 1s
+				// clean up input
+				session.AudioBufferLock.Lock()
+				session.InputAudioBuffer = nil
+				session.AudioBufferLock.Unlock()
+				log.Debug().Msgf("Detected silence for a while, clearing audio buffer")
+				continue
+			}
+
+			// 4) If we see more segments than before => "new speech"
+			if segCount > lastSegmentCount {
+				speaking = true
+				lastSegmentCount = segCount
+				timeOfLastNewSeg = time.Now()
+				log.Debug().Msgf("Detected new speech segment")
+			}
+
+			// 5) If speaking, but we haven't seen a new segment in > 1s => finalize
+			if speaking && time.Since(timeOfLastNewSeg) > 1*time.Second {
+				log.Debug().Msgf("Detected end of speech segment")
+				// user has presumably stopped talking
+				commitUtterance(allAudio, cfg, evaluator, session, conv, c)
+				// reset state
+				speaking = false
+				lastSegmentCount = 0
+			}
 		}
 	}
 }
+
+func commitUtterance(utt []byte, cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) {
+	if len(utt) == 0 {
+		return
+	}
+	// Commit logic: create item, broadcast item.created, etc.
+	item := &Item{
+		ID:     generateItemID(),
+		Object: "realtime.item",
+		Type:   "message",
+		Status: "completed",
+		Role:   "user",
+		Content: []ConversationContent{
+			{
+				Type:  "input_audio",
+				Audio: base64.StdEncoding.EncodeToString(utt),
+			},
+		},
+	}
+	conv.Lock.Lock()
+	conv.Items = append(conv.Items, item)
+	conv.Lock.Unlock()
+
+	sendEvent(c, OutgoingMessage{
+		Type: "conversation.item.created",
+		Item: item,
+	})
+
+	// Optionally trigger the response generation
+	generateResponse(cfg, evaluator, session, conv, ResponseCreate{}, c, websocket.TextMessage)
+}
+
+// runVAD is a helper that calls the model's VAD method, returning
+// the speech segments it detected (an empty result means silence)
+func runVAD(ctx context.Context, session *Session, chunk []byte) ([]*proto.VADSegment, error) {
+	adata := sound.BytesToInt16sLE(chunk)
+
+	// Resample from 24kHz to 16kHz
+	// adata = sound.ResampleInt16(adata, 24000, 16000)
+
+	soundIntBuffer := &audio.IntBuffer{
+		Format: &audio.Format{SampleRate: 16000, NumChannels: 1},
+	}
+	soundIntBuffer.Data = sound.ConvertInt16ToInt(adata)
+
+	/* if len(adata) < 16000 {
+		log.Debug().Msgf("audio length too small %d", len(session.InputAudioBuffer))
+		session.AudioBufferLock.Unlock()
+		continue
+	} */
+	float32Data := soundIntBuffer.AsFloat32Buffer().Data
+
+	resp, err := session.ModelInterface.VAD(ctx, &proto.VADRequest{
+		Audio: float32Data,
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	// TODO: testing wav decoding
+	// dec := wav.NewDecoder(bytes.NewReader(session.InputAudioBuffer))
+	// buf, err := dec.FullPCMBuffer()
+	// if err != nil {
+	// 	//log.Error().Msgf("failed to process audio: %s", err.Error())
+	// 	sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
+	// 	session.AudioBufferLock.Unlock()
+	// 	continue
+	// }
+
+	//float32Data = buf.AsFloat32Buffer().Data
+
+	// If resp.Segments is empty => no speech
+	return resp.Segments, nil
+}
 
 // Function to generate a response based on the conversation
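The rewritten loop reduces start/end-of-speech detection to a small state machine: a growing VAD segment count over the whole buffered audio marks the start of speech, and more than a second without a new segment marks the end, at which point the utterance is committed. The standalone sketch below isolates that logic; speechTracker, observe, and the simulated segment counts are hypothetical names for illustration, not identifiers from the LocalAI codebase, while the 300ms tick and 1s silence window come from the diff above.

package main

import (
	"fmt"
	"time"
)

// speechTracker mirrors the commit's per-tick state: the last segment
// count seen, when a new segment last appeared, and whether the user
// is currently considered to be speaking.
type speechTracker struct {
	lastSegmentCount int
	timeOfLastNewSeg time.Time
	speaking         bool
	silenceWindow    time.Duration // 1s in the commit
}

// observe takes the number of VAD segments detected over the entire
// buffered audio at one tick and reports whether a finished utterance
// should be committed.
func (t *speechTracker) observe(segCount int, now time.Time) bool {
	// More segments than before => new speech was detected.
	if segCount > t.lastSegmentCount {
		t.speaking = true
		t.lastSegmentCount = segCount
		t.timeOfLastNewSeg = now
		return false
	}
	// Speaking, but no new segment within the silence window => finalize.
	if t.speaking && now.Sub(t.timeOfLastNewSeg) > t.silenceWindow {
		t.speaking = false
		t.lastSegmentCount = 0
		return true
	}
	return false
}

func main() {
	tr := &speechTracker{silenceWindow: time.Second}
	start := time.Now()
	// Simulated per-tick VAD results: the segment count grows while the
	// user talks, then stays flat once they stop.
	counts := []int{0, 1, 1, 2, 2, 2, 2, 2}
	for i, c := range counts {
		now := start.Add(time.Duration(i*300) * time.Millisecond)
		if tr.observe(c, now) {
			fmt.Printf("commit utterance at t=%v\n", now.Sub(start))
		}
	}
	// Output: commit utterance at t=2.1s
}

Because each tick re-runs VAD over the entire buffer rather than a sliding frame, the tracker needs no per-frame smoothing; the newly declared SpeechFramesThreshold and SilenceFramesThreshold constants suggest a frame-level debounce as a possible next step, but they are not yet wired in at this WIP stage.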
