refactor: consolidate channel send logic with context-safe handlers

Optimize channel operations by introducing reusable context-aware send functions (`send` and `sendErr`) across `wsrelay`, `handlers`, and `cliproxy`. Ensure graceful handling of canceled contexts during stream operations.
This commit is contained in:
Luis Pater
2026-01-28 10:58:35 +08:00
parent bbb55a8ab4
commit e93e05ae25
3 changed files with 65 additions and 10 deletions

View File

@@ -124,32 +124,47 @@ func (m *Manager) Stream(ctx context.Context, provider string, req *HTTPRequest)
out := make(chan StreamEvent) out := make(chan StreamEvent)
go func() { go func() {
defer close(out) defer close(out)
send := func(ev StreamEvent) bool {
if ctx == nil {
out <- ev
return true
}
select {
case <-ctx.Done():
return false
case out <- ev:
return true
}
}
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
out <- StreamEvent{Err: ctx.Err()}
return return
case msg, ok := <-respCh: case msg, ok := <-respCh:
if !ok { if !ok {
out <- StreamEvent{Err: errors.New("wsrelay: stream closed")} _ = send(StreamEvent{Err: errors.New("wsrelay: stream closed")})
return return
} }
switch msg.Type { switch msg.Type {
case MessageTypeStreamStart: case MessageTypeStreamStart:
resp := decodeResponse(msg.Payload) resp := decodeResponse(msg.Payload)
out <- StreamEvent{Type: MessageTypeStreamStart, Status: resp.Status, Headers: resp.Headers} if okSend := send(StreamEvent{Type: MessageTypeStreamStart, Status: resp.Status, Headers: resp.Headers}); !okSend {
return
}
case MessageTypeStreamChunk: case MessageTypeStreamChunk:
chunk := decodeChunk(msg.Payload) chunk := decodeChunk(msg.Payload)
out <- StreamEvent{Type: MessageTypeStreamChunk, Payload: chunk} if okSend := send(StreamEvent{Type: MessageTypeStreamChunk, Payload: chunk}); !okSend {
return
}
case MessageTypeStreamEnd: case MessageTypeStreamEnd:
out <- StreamEvent{Type: MessageTypeStreamEnd} _ = send(StreamEvent{Type: MessageTypeStreamEnd})
return return
case MessageTypeError: case MessageTypeError:
out <- StreamEvent{Type: MessageTypeError, Err: decodeError(msg.Payload)} _ = send(StreamEvent{Type: MessageTypeError, Err: decodeError(msg.Payload)})
return return
case MessageTypeHTTPResp: case MessageTypeHTTPResp:
resp := decodeResponse(msg.Payload) resp := decodeResponse(msg.Payload)
out <- StreamEvent{Type: MessageTypeHTTPResp, Status: resp.Status, Headers: resp.Headers, Payload: resp.Body} _ = send(StreamEvent{Type: MessageTypeHTTPResp, Status: resp.Status, Headers: resp.Headers, Payload: resp.Body})
return return
default: default:
} }

View File

@@ -506,6 +506,32 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
bootstrapRetries := 0 bootstrapRetries := 0
maxBootstrapRetries := StreamingBootstrapRetries(h.Cfg) maxBootstrapRetries := StreamingBootstrapRetries(h.Cfg)
sendErr := func(msg *interfaces.ErrorMessage) bool {
if ctx == nil {
errChan <- msg
return true
}
select {
case <-ctx.Done():
return false
case errChan <- msg:
return true
}
}
sendData := func(chunk []byte) bool {
if ctx == nil {
dataChan <- chunk
return true
}
select {
case <-ctx.Done():
return false
case dataChan <- chunk:
return true
}
}
bootstrapEligible := func(err error) bool { bootstrapEligible := func(err error) bool {
status := statusFromError(err) status := statusFromError(err)
if status == 0 { if status == 0 {
@@ -565,12 +591,14 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
addon = hdr.Clone() addon = hdr.Clone()
} }
} }
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: streamErr, Addon: addon} _ = sendErr(&interfaces.ErrorMessage{StatusCode: status, Error: streamErr, Addon: addon})
return return
} }
if len(chunk.Payload) > 0 { if len(chunk.Payload) > 0 {
sentPayload = true sentPayload = true
dataChan <- cloneBytes(chunk.Payload) if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
return
}
} }
} }
} }

View File

@@ -718,6 +718,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) { go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
defer close(out) defer close(out)
var failed bool var failed bool
forward := true
for chunk := range streamChunks { for chunk := range streamChunks {
if chunk.Err != nil && !failed { if chunk.Err != nil && !failed {
failed = true failed = true
@@ -728,7 +729,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
} }
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr}) m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
} }
out <- chunk if !forward {
continue
}
if streamCtx == nil {
out <- chunk
continue
}
select {
case <-streamCtx.Done():
forward = false
case out <- chunk:
}
} }
if !failed { if !failed {
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true}) m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})