diff --git a/context.go b/context.go index 67e83181c..e7d9199b7 100644 --- a/context.go +++ b/context.go @@ -269,7 +269,8 @@ func (c *context) IsTLS() bool { func (c *context) IsWebSocket() bool { upgrade := c.request.Header.Get(HeaderUpgrade) - return strings.EqualFold(upgrade, "websocket") + connection := c.request.Header.Get(HeaderConnection) + return strings.EqualFold(upgrade, "websocket") && strings.Contains(strings.ToLower(connection), "upgrade") } func (c *context) Scheme() string { diff --git a/context_test.go b/context_test.go index 1fd89edb4..d5b4bb35d 100644 --- a/context_test.go +++ b/context_test.go @@ -969,7 +969,10 @@ func TestContext_IsWebSocket(t *testing.T) { { &context{ request: &http.Request{ - Header: http.Header{HeaderUpgrade: []string{"websocket"}}, + Header: http.Header{ + HeaderUpgrade: []string{"websocket"}, + HeaderConnection: []string{"upgrade"}, + }, }, }, assert.True, @@ -977,7 +980,10 @@ func TestContext_IsWebSocket(t *testing.T) { { &context{ request: &http.Request{ - Header: http.Header{HeaderUpgrade: []string{"Websocket"}}, + Header: http.Header{ + HeaderUpgrade: []string{"Websocket"}, + HeaderConnection: []string{"Upgrade"}, + }, }, }, assert.True, @@ -996,6 +1002,17 @@ func TestContext_IsWebSocket(t *testing.T) { }, assert.False, }, + { + &context{ + request: &http.Request{ + Header: http.Header{ + HeaderUpgrade: []string{"websocket"}, + HeaderConnection: []string{"close"}, + }, + }, + }, + assert.False, + }, } for i, tt := range tests {