mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 20:40:52 +08:00
122 lines
2.9 KiB
Go
122 lines
2.9 KiB
Go
package handlers
|
|
|
|
import (
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
|
)
|
|
|
|
type StreamForwardOptions struct {
|
|
// KeepAliveInterval overrides the configured streaming keep-alive interval.
|
|
// If nil, the configured default is used. If set to <= 0, keep-alives are disabled.
|
|
KeepAliveInterval *time.Duration
|
|
|
|
// WriteChunk writes a single data chunk to the response body. It should not flush.
|
|
WriteChunk func(chunk []byte)
|
|
|
|
// WriteTerminalError writes an error payload to the response body when streaming fails
|
|
// after headers have already been committed. It should not flush.
|
|
WriteTerminalError func(errMsg *interfaces.ErrorMessage)
|
|
|
|
// WriteDone optionally writes a terminal marker when the upstream data channel closes
|
|
// without an error (e.g. OpenAI's `[DONE]`). It should not flush.
|
|
WriteDone func()
|
|
|
|
// WriteKeepAlive optionally writes a keep-alive heartbeat. It should not flush.
|
|
// When nil, a standard SSE comment heartbeat is used.
|
|
WriteKeepAlive func()
|
|
}
|
|
|
|
func (h *BaseAPIHandler) ForwardStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, opts StreamForwardOptions) {
|
|
if c == nil {
|
|
return
|
|
}
|
|
if cancel == nil {
|
|
return
|
|
}
|
|
|
|
writeChunk := opts.WriteChunk
|
|
if writeChunk == nil {
|
|
writeChunk = func([]byte) {}
|
|
}
|
|
|
|
writeKeepAlive := opts.WriteKeepAlive
|
|
if writeKeepAlive == nil {
|
|
writeKeepAlive = func() {
|
|
_, _ = c.Writer.Write([]byte(": keep-alive\n\n"))
|
|
}
|
|
}
|
|
|
|
keepAliveInterval := StreamingKeepAliveInterval(h.Cfg)
|
|
if opts.KeepAliveInterval != nil {
|
|
keepAliveInterval = *opts.KeepAliveInterval
|
|
}
|
|
var keepAlive *time.Ticker
|
|
var keepAliveC <-chan time.Time
|
|
if keepAliveInterval > 0 {
|
|
keepAlive = time.NewTicker(keepAliveInterval)
|
|
defer keepAlive.Stop()
|
|
keepAliveC = keepAlive.C
|
|
}
|
|
|
|
var terminalErr *interfaces.ErrorMessage
|
|
for {
|
|
select {
|
|
case <-c.Request.Context().Done():
|
|
cancel(c.Request.Context().Err())
|
|
return
|
|
case chunk, ok := <-data:
|
|
if !ok {
|
|
// Prefer surfacing a terminal error if one is pending.
|
|
if terminalErr == nil {
|
|
select {
|
|
case errMsg, ok := <-errs:
|
|
if ok && errMsg != nil {
|
|
terminalErr = errMsg
|
|
}
|
|
default:
|
|
}
|
|
}
|
|
if terminalErr != nil {
|
|
if opts.WriteTerminalError != nil {
|
|
opts.WriteTerminalError(terminalErr)
|
|
}
|
|
flusher.Flush()
|
|
cancel(terminalErr.Error)
|
|
return
|
|
}
|
|
if opts.WriteDone != nil {
|
|
opts.WriteDone()
|
|
}
|
|
flusher.Flush()
|
|
cancel(nil)
|
|
return
|
|
}
|
|
writeChunk(chunk)
|
|
flusher.Flush()
|
|
case errMsg, ok := <-errs:
|
|
if !ok {
|
|
continue
|
|
}
|
|
if errMsg != nil {
|
|
terminalErr = errMsg
|
|
if opts.WriteTerminalError != nil {
|
|
opts.WriteTerminalError(errMsg)
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
var execErr error
|
|
if errMsg != nil {
|
|
execErr = errMsg.Error
|
|
}
|
|
cancel(execErr)
|
|
return
|
|
case <-keepAliveC:
|
|
writeKeepAlive()
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
}
|