/*
* Copyright 2012 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.channel.
ChannelFuture;
import io.netty.channel.
ChannelFutureListener;
import io.netty.channel.
ChannelHandlerContext;
import io.netty.channel.
ChannelInboundHandlerAdapter;
import io.netty.channel.
ChannelPipeline;
import io.netty.handler.codec.http.
DefaultFullHttpResponse;
import io.netty.handler.codec.http.
FullHttpRequest;
import io.netty.handler.codec.http.
HttpHeaderNames;
import io.netty.handler.codec.http.
HttpRequest;
import io.netty.handler.codec.http.
HttpResponse;
import io.netty.handler.ssl.
SslHandler;
import static io.netty.handler.codec.http.
HttpUtil.*;
import static io.netty.handler.codec.http.
HttpMethod.*;
import static io.netty.handler.codec.http.
HttpResponseStatus.*;
import static io.netty.handler.codec.http.
HttpVersion.*;
/**
 * Handles the HTTP handshake (the HTTP Upgrade request) for {@link WebSocketServerProtocolHandler}.
 *
 * <p>Inbound messages that are not a {@link FullHttpRequest}, or whose URI does not match the
 * configured WebSocket path, are forwarded unchanged to the next handler. A matching request that
 * is not a {@code GET} is answered with {@code 403 Forbidden}. Otherwise a
 * {@link WebSocketServerHandshaker} is created for it. If no handshaker is available for the
 * requested version, an "unsupported version" response is sent. If one is available, the handshake
 * is started, the handshaker is registered on the channel, and this handler replaces itself in the
 * pipeline (under the name {@code "WS403Responder"}) with
 * {@link WebSocketServerProtocolHandler#forbiddenHttpRequestResponder()}.
 */
class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapter {

    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxFramePayloadSize;
    private final boolean allowMaskMismatch;
    private final boolean checkStartsWith;

    /**
     * Creates a handler that only accepts requests whose URI equals {@code websocketPath} exactly.
     *
     * @param websocketPath     the URI path that is upgraded to a WebSocket
     * @param subprotocols      subprotocols passed to the {@link WebSocketServerHandshakerFactory}
     * @param allowExtensions   passed to the {@link WebSocketServerHandshakerFactory}
     * @param maxFrameSize      maximum frame payload size passed to the handshaker factory
     * @param allowMaskMismatch passed to the {@link WebSocketServerHandshakerFactory}
     */
    WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
            boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false);
    }

    /**
     * Creates a handshake handler.
     *
     * @param websocketPath     the URI path that is upgraded to a WebSocket
     * @param subprotocols      subprotocols passed to the {@link WebSocketServerHandshakerFactory}
     * @param allowExtensions   passed to the {@link WebSocketServerHandshakerFactory}
     * @param maxFrameSize      maximum frame payload size passed to the handshaker factory
     * @param allowMaskMismatch passed to the {@link WebSocketServerHandshakerFactory}
     * @param checkStartsWith   {@code true} to accept any URI that starts with {@code websocketPath}
     *                          at a path or query boundary; {@code false} to require an exact match
     */
    WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
            boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
            boolean checkStartsWith) {
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        this.maxFramePayloadSize = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
        this.checkStartsWith = checkStartsWith;
    }

    @Override
    public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
        // Only aggregated requests can be handshaken here; anything else is not ours to handle,
        // so pass it along instead of failing with a ClassCastException.
        if (!(msg instanceof FullHttpRequest)) {
            ctx.fireChannelRead(msg);
            return;
        }
        final FullHttpRequest req = (FullHttpRequest) msg;
        if (!isWebSocketPath(req)) {
            // Not addressed to the WebSocket endpoint: ownership passes to the next handler.
            ctx.fireChannelRead(msg);
            return;
        }

        try {
            // Compare with equals(): HttpMethod is an object, and an equal but non-cached
            // instance must not be rejected.
            if (!GET.equals(req.method())) {
                sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
                return;
            }

            final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                    getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,
                    allowExtensions, maxFramePayloadSize, allowMaskMismatch);
            final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
            if (handshaker == null) {
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            } else {
                final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
                handshakeFuture.addListener(new ChannelFutureListener() {
                    @Override
                    public void operationComplete(ChannelFuture future) throws Exception {
                        if (!future.isSuccess()) {
                            ctx.fireExceptionCaught(future.cause());
                        } else {
                            // Kept for compatibility
                            ctx.fireUserEventTriggered(
                                    WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
                            // NOTE(review): this may run after req.release() in the finally below;
                            // it only reads uri() and headers(), presumably not backed by the
                            // released content buffer - confirm.
                            ctx.fireUserEventTriggered(
                                    new WebSocketServerProtocolHandler.HandshakeComplete(
                                            req.uri(), req.headers(), handshaker.selectedSubprotocol()));
                        }
                    }
                });
                WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
                // The handshake only happens once per channel; later HTTP requests go to the
                // responder installed in this handler's place.
                ctx.pipeline().replace(this, "WS403Responder",
                        WebSocketServerProtocolHandler.forbiddenHttpRequestResponder());
            }
        } finally {
            req.release();
        }
    }

    /**
     * Returns whether the request targets the configured WebSocket path.
     *
     * <p>In exact mode the URI must equal {@code websocketPath}. In {@code checkStartsWith} mode the
     * URI must start with {@code websocketPath}, and the prefix must end at a boundary: the end of
     * the URI, a {@code '/'} or {@code '?'} right after it, or a path that itself ends in
     * {@code '/'}. This prevents {@code /websocket} from matching {@code /websocketfoo}.
     */
    private boolean isWebSocketPath(HttpRequest req) {
        String uri = req.uri();
        if (!checkStartsWith) {
            return uri.equals(websocketPath);
        }
        if (!uri.startsWith(websocketPath)) {
            return false;
        }
        int prefixLength = websocketPath.length();
        if (uri.length() == prefixLength || websocketPath.endsWith("/")) {
            return true;
        }
        char next = uri.charAt(prefixLength);
        return next == '/' || next == '?';
    }

    /**
     * Writes and flushes {@code res} on the channel. The channel is closed once the write completes
     * if the request is not keep-alive or the response status is not {@code 200}.
     */
    private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!isKeepAlive(req) || res.status().code() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * Builds the WebSocket URL for the handshake: {@code wss} if an {@link SslHandler} is in the
     * pipeline and {@code ws} otherwise, followed by the request's {@code Host} header and
     * {@code path}. If the request has no {@code Host} header, the literal text {@code "null"}
     * appears in the host position.
     */
    private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
        String protocol = "ws";
        if (cp.get(SslHandler.class) != null) {
            // SSL in use so use Secure WebSockets
            protocol = "wss";
        }
        String host = req.headers().get(HttpHeaderNames.HOST);
        return protocol + "://" + host + path;
    }
}