From 234a6d4c00cb77af9852aca0b8289745d5529b4b Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Fri, 26 Sep 2025 08:13:39 +0800 Subject: [PATCH] fix(response): refine hijack behavior for response lifecycle (#4373) * feat: refine hijack behavior for response lifecycle and add tests - Clarify the error message for attempted hijack after response body data is written - Modify hijack behavior: allow hijacking after headers are written (for better websocket compatibility), but block hijacking after any body data is sent - Add comprehensive tests to validate allowed hijack after header write and disallowed hijack after body write fix https://github.com/gin-gonic/gin/issues/4372 Signed-off-by: appleboy * test: use require for immediate test failure on errors - Replace assert with require for error checks to ensure test failures immediately halt execution Signed-off-by: appleboy * Update response_writer.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Signed-off-by: appleboy Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- response_writer.go | 6 +++-- response_writer_test.go | 58 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/response_writer.go b/response_writer.go index ab2f5fec..6907e514 100644 --- a/response_writer.go +++ b/response_writer.go @@ -17,7 +17,7 @@ const ( defaultStatus = http.StatusOK ) -var errHijackAlreadyWritten = errors.New("gin: response already written") +var errHijackAlreadyWritten = errors.New("gin: response body already written") // ResponseWriter ... type ResponseWriter interface { @@ -109,7 +109,9 @@ func (w *responseWriter) Written() bool { // Hijack implements the http.Hijacker interface. func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - if w.Written() { + // Allow hijacking before any data is written (size == -1) or after headers are written (size == 0), + // but not after body data is written (size > 0). For compatibility with websocket libraries (e.g., github.com/coder/websocket) + if w.size > 0 { return nil, nil, errHijackAlreadyWritten } if w.size < 0 { diff --git a/response_writer_test.go b/response_writer_test.go index ef198418..dfc1d2c6 100644 --- a/response_writer_test.go +++ b/response_writer_test.go @@ -194,6 +194,64 @@ func TestResponseWriterHijackAfterWrite(t *testing.T) { } } +// Test: WebSocket compatibility - allow hijack after WriteHeaderNow(), but block after body data. +func TestResponseWriterHijackAfterWriteHeaderNow(t *testing.T) { + tests := []struct { + name string + action func(w ResponseWriter) error + expectWrittenBeforeHijack bool + expectHijackSuccess bool + expectWrittenAfterHijack bool + expectError error + }{ + { + name: "hijack after WriteHeaderNow only should succeed (websocket pattern)", + action: func(w ResponseWriter) error { + w.WriteHeaderNow() // Simulate websocket.Accept() behavior + return nil + }, + expectWrittenBeforeHijack: true, + expectHijackSuccess: true, // NEW BEHAVIOR: allow hijack after just header write + expectWrittenAfterHijack: true, + expectError: nil, + }, + { + name: "hijack after WriteHeaderNow + Write should fail", + action: func(w ResponseWriter) error { + w.WriteHeaderNow() + _, err := w.Write([]byte("test")) + return err + }, + expectWrittenBeforeHijack: true, + expectHijackSuccess: false, + expectWrittenAfterHijack: true, + expectError: errHijackAlreadyWritten, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + hijacker := &mockHijacker{ResponseRecorder: httptest.NewRecorder()} + writer := &responseWriter{} + writer.reset(hijacker) + w := ResponseWriter(writer) + + require.NoError(t, tc.action(w), "unexpected error during pre-hijack action") + + assert.Equal(t, tc.expectWrittenBeforeHijack, w.Written(), "unexpected w.Written() state before hijack") + + _, _, hijackErr := w.Hijack() + + if tc.expectError == nil { + require.NoError(t, hijackErr, "expected hijack to succeed") + } else { + require.ErrorIs(t, hijackErr, tc.expectError, "unexpected error from Hijack()") + } + assert.Equal(t, tc.expectHijackSuccess, hijacker.hijacked, "unexpected hijacker.hijacked state") + assert.Equal(t, tc.expectWrittenAfterHijack, w.Written(), "unexpected w.Written() state after hijack") + }) + } +} + func TestResponseWriterFlush(t *testing.T) { testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { writer := &responseWriter{}