/*
* Copyright 2012 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.AttributeKey;

import java.nio.charset.Charset;
import java.util.List;

import static io.netty.handler.codec.http.HttpVersion.*;
/**
* This handler does all the heavy lifting for you to run a websocket server.
*
* It takes care of websocket handshaking as well as processing of control frames (Close, Ping, Pong). Text and Binary
* data frames are passed to the next handler in the pipeline (implemented by you) for processing.
*
* See <tt>io.netty.example.http.websocketx.html5.WebSocketServer</tt> for usage.
*
* The implementation of this handler assumes that you just want to run a websocket server and not process other types
* of HTTP requests (like GET and POST). If you wish to support both HTTP requests and websockets in one server, refer
* to the <tt>io.netty.example.http.websocketx.server.WebSocketServer</tt> example.
*
* To be notified when a handshake completes, override
* {@link ChannelInboundHandler#userEventTriggered(ChannelHandlerContext, Object)} and check whether the event is an
* instance of {@link HandshakeComplete}; that event carries extra information about the handshake, such as the request
* URI, the request headers and the selected subprotocol.
*/
public class
WebSocketServerProtocolHandler extends
WebSocketProtocolHandler {
/**
 * Events fired to notify about the status of the handshake.
 */
public enum ServerHandshakeStateEvent {
    /**
     * The handshake completed successfully and the channel was upgraded to websockets.
     *
     * @deprecated use the {@link HandshakeComplete} class instead; it provides extra information about the
     *             handshake
     */
    @Deprecated
    HANDSHAKE_COMPLETE
}
/**
 * Event object indicating that the handshake completed successfully and the channel was upgraded to websockets.
 * It is an immutable holder for the request URI, the request headers and the selected subprotocol that were
 * supplied when the event was created.
 */
public static final class HandshakeComplete {
    private final String requestUri;
    private final HttpHeaders requestHeaders;
    private final String selectedSubprotocol;

    /**
     * Package-private: instances are only created from within this package.
     *
     * @param requestUri          value returned by {@link #requestUri()}
     * @param requestHeaders      value returned by {@link #requestHeaders()}
     * @param selectedSubprotocol value returned by {@link #selectedSubprotocol()}
     */
    HandshakeComplete(String requestUri, HttpHeaders requestHeaders, String selectedSubprotocol) {
        this.requestUri = requestUri;
        this.requestHeaders = requestHeaders;
        this.selectedSubprotocol = selectedSubprotocol;
    }

    /**
     * @return the request URI this event was created with
     */
    public String requestUri() {
        return this.requestUri;
    }

    /**
     * @return the request headers this event was created with (the same instance, not a copy)
     */
    public HttpHeaders requestHeaders() {
        return this.requestHeaders;
    }

    /**
     * @return the selected subprotocol this event was created with
     */
    public String selectedSubprotocol() {
        return this.selectedSubprotocol;
    }
}
/**
 * Key of the channel attribute in which the {@link WebSocketServerHandshaker} of a channel is kept.
 * Read by {@code getHandshaker(Channel)} and written by {@code setHandshaker(Channel, WebSocketServerHandshaker)}.
 */
private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =
        AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");
// Handshake configuration. Every value is fixed at construction time and is handed, unchanged, to the
// WebSocketServerProtocolHandshakeHandler that handlerAdded(...) installs in the pipeline.
private final String websocketPath;
private final String subprotocols;
private final boolean allowExtensions;
// Populated from the constructor parameter named "maxFrameSize".
private final int maxFramePayloadLength;
private final boolean allowMaskMismatch;
private final boolean checkStartsWith;
/**
 * Creates a handler for the given websocket path, delegating to the three-argument constructor with
 * {@code subprotocols} set to {@code null} and {@code allowExtensions} set to {@code false}.
 *
 * @param websocketPath the websocket path this handler is configured with
 */
public WebSocketServerProtocolHandler(String websocketPath) {
    this(websocketPath, null, false);
}
/**
 * Creates a handler for the given websocket path, delegating to the six-argument constructor with
 * {@code subprotocols} set to {@code null}, {@code allowExtensions} set to {@code false}, {@code maxFrameSize}
 * set to {@code 65536} and {@code allowMaskMismatch} set to {@code false}.
 *
 * @param websocketPath   the websocket path this handler is configured with
 * @param checkStartsWith passed through to the handshake handler this handler installs
 */
public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith) {
    this(websocketPath, null, false, 65536, false, checkStartsWith);
}
/**
 * Creates a handler for the given websocket path and subprotocols, delegating to the three-argument
 * constructor with {@code allowExtensions} set to {@code false}.
 *
 * @param websocketPath the websocket path this handler is configured with
 * @param subprotocols  the subprotocols value this handler is configured with
 */
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols) {
    this(websocketPath, subprotocols, false);
}
/**
 * Creates a handler, delegating to the four-argument constructor with {@code maxFrameSize} set to
 * {@code 65536}.
 *
 * @param websocketPath   the websocket path this handler is configured with
 * @param subprotocols    the subprotocols value this handler is configured with
 * @param allowExtensions passed through to the handshake handler this handler installs
 */
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions) {
    this(websocketPath, subprotocols, allowExtensions, 65536);
}
/**
 * Creates a handler, delegating to the five-argument constructor with {@code allowMaskMismatch} set to
 * {@code false}.
 *
 * @param websocketPath   the websocket path this handler is configured with
 * @param subprotocols    the subprotocols value this handler is configured with
 * @param allowExtensions passed through to the handshake handler this handler installs
 * @param maxFrameSize    stored as the maximum frame payload length
 */
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
                                      boolean allowExtensions, int maxFrameSize) {
    this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false);
}
/**
 * Creates a handler, delegating to the six-argument constructor with {@code checkStartsWith} set to
 * {@code false}.
 *
 * @param websocketPath     the websocket path this handler is configured with
 * @param subprotocols      the subprotocols value this handler is configured with
 * @param allowExtensions   passed through to the handshake handler this handler installs
 * @param maxFrameSize      stored as the maximum frame payload length
 * @param allowMaskMismatch passed through to the handshake handler this handler installs
 */
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
                                      boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
    this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false);
}
/**
 * Creates a handler, delegating to the seven-argument constructor with {@code dropPongFrames} set to
 * {@code true}.
 *
 * @param websocketPath     the websocket path this handler is configured with
 * @param subprotocols      the subprotocols value this handler is configured with
 * @param allowExtensions   passed through to the handshake handler this handler installs
 * @param maxFrameSize      stored as the maximum frame payload length
 * @param allowMaskMismatch passed through to the handshake handler this handler installs
 * @param checkStartsWith   passed through to the handshake handler this handler installs
 */
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
                                      boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
                                      boolean checkStartsWith) {
    this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, true);
}
/**
 * Primary constructor: every other constructor of this class ends up here. Passes {@code dropPongFrames} to
 * the superclass and stores the remaining arguments for later use when the handshake handler is installed.
 *
 * @param websocketPath     the websocket path this handler is configured with
 * @param subprotocols      the subprotocols value this handler is configured with
 * @param allowExtensions   passed through to the handshake handler this handler installs
 * @param maxFrameSize      stored as the maximum frame payload length
 * @param allowMaskMismatch passed through to the handshake handler this handler installs
 * @param checkStartsWith   passed through to the handshake handler this handler installs
 * @param dropPongFrames    forwarded to the {@link WebSocketProtocolHandler} constructor
 */
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
                                      boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
                                      boolean checkStartsWith, boolean dropPongFrames) {
    super(dropPongFrames);

    // Request matching.
    this.websocketPath = websocketPath;
    this.checkStartsWith = checkStartsWith;
    this.subprotocols = subprotocols;

    // Frame handling.
    this.allowExtensions = allowExtensions;
    this.maxFramePayloadLength = maxFrameSize;
    this.allowMaskMismatch = allowMaskMismatch;
}
/**
 * Makes sure the pipeline contains a {@link WebSocketServerProtocolHandshakeHandler} and a
 * {@link Utf8FrameValidator}. Whichever of the two is not yet present is inserted immediately before this
 * handler, registered under its own class name.
 *
 * @param ctx the context of this handler in the pipeline
 */
@Override
public void handlerAdded(ChannelHandlerContext ctx) {
    ChannelPipeline pipeline = ctx.pipeline();

    // Handshake handler first, so that it ends up ahead of the UTF-8 validator added below.
    if (pipeline.get(WebSocketServerProtocolHandshakeHandler.class) == null) {
        pipeline.addBefore(
                ctx.name(),
                WebSocketServerProtocolHandshakeHandler.class.getName(),
                new WebSocketServerProtocolHandshakeHandler(
                        websocketPath, subprotocols, allowExtensions,
                        maxFramePayloadLength, allowMaskMismatch, checkStartsWith));
    }

    // Then the UTF-8 checking handler, again directly in front of this handler.
    if (pipeline.get(Utf8FrameValidator.class) == null) {
        pipeline.addBefore(ctx.name(), Utf8FrameValidator.class.getName(), new Utf8FrameValidator());
    }
}
/**
 * Handles close frames itself and leaves every other frame type to the superclass.
 * <p>
 * For a {@link CloseWebSocketFrame}: if a handshaker is stored on the channel, the frame is retained and
 * passed to {@code handshaker.close(...)}; otherwise an empty buffer is written and flushed and the channel
 * is closed once that write completes. Nothing is added to {@code out} for close frames.
 *
 * @param ctx   the context of this handler
 * @param frame the frame to process
 * @param out   the output list, only touched by the superclass for non-close frames
 * @throws Exception if the superclass decoding fails
 */
@Override
protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception {
    if (!(frame instanceof CloseWebSocketFrame)) {
        super.decode(ctx, frame, out);
        return;
    }

    Channel channel = ctx.channel();
    WebSocketServerHandshaker handshaker = getHandshaker(channel);
    if (handshaker == null) {
        // No handshaker stored on this channel: flush an empty write and close when it is done.
        ctx.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
        return;
    }

    // NOTE(review): the extra retain presumably balances a release done by the decoder base class after
    // decode() returns, because the frame is handed on to the handshaker here - confirm.
    frame.retain();
    handshaker.close(channel, (CloseWebSocketFrame) frame);
}
/**
 * Turns a {@link WebSocketHandshakeException} into an {@code HTTP/1.1 400 Bad Request} response whose body is
 * the exception message, and closes the channel once that response has been written. Any other exception is
 * forwarded to the next handler and the channel is closed.
 * <p>
 * The message is encoded as UTF-8 rather than with the JVM default charset, so the bytes sent do not depend on
 * the platform. An exception without a message produces an empty body instead of a
 * {@link NullPointerException}.
 *
 * @param ctx   the context of this handler
 * @param cause the exception that was raised
 * @throws Exception declared by the overridden method; not thrown explicitly here
 */
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
    if (cause instanceof WebSocketHandshakeException) {
        String message = cause.getMessage();
        // Charset.forName("UTF-8") is used instead of a newer constant to stay compatible with old JDKs.
        byte[] body = message == null ? new byte[0] : message.getBytes(Charset.forName("UTF-8"));
        FullHttpResponse response = new DefaultFullHttpResponse(
                HTTP_1_1, HttpResponseStatus.BAD_REQUEST, Unpooled.wrappedBuffer(body));
        ctx.channel().writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
    } else {
        ctx.fireExceptionCaught(cause);
        ctx.close();
    }
}
/**
 * Reads the handshaker stored on the given channel under {@code HANDSHAKER_ATTR_KEY}.
 *
 * @param channel the channel whose attribute is read
 * @return the current value of the attribute; callers in this class check it for {@code null}
 */
static WebSocketServerHandshaker getHandshaker(Channel channel) {
    return channel.attr(HANDSHAKER_ATTR_KEY).get();
}
/**
 * Stores the given handshaker on the channel under {@code HANDSHAKER_ATTR_KEY}, replacing whatever value the
 * attribute held before.
 *
 * @param channel    the channel whose attribute is written
 * @param handshaker the handshaker to store
 */
static void setHandshaker(Channel channel, WebSocketServerHandshaker handshaker) {
    channel.attr(HANDSHAKER_ATTR_KEY).set(handshaker);
}
/**
 * Creates an inbound handler that rejects HTTP requests: every {@link FullHttpRequest} it reads is released
 * and answered with an {@code HTTP/1.1 403 Forbidden} response, while any other message is passed on to the
 * next handler untouched.
 *
 * @return a new handler instance on every call
 */
static ChannelHandler forbiddenHttpRequestResponder() {
    return new ChannelInboundHandlerAdapter() {
        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            if (!(msg instanceof FullHttpRequest)) {
                // Not a full HTTP request: not ours to handle.
                ctx.fireChannelRead(msg);
                return;
            }

            // The request is consumed here and never forwarded, so release it before replying.
            FullHttpRequest rejected = (FullHttpRequest) msg;
            rejected.release();

            FullHttpResponse forbidden = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.FORBIDDEN);
            ctx.channel().writeAndFlush(forbidden);
        }
    };
}
}