/*
* Copyright 2012 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.channel.
Channel;
import io.netty.channel.
ChannelFuture;
import io.netty.channel.
ChannelFutureListener;
import io.netty.channel.
ChannelHandlerContext;
import io.netty.channel.
ChannelPipeline;
import io.netty.channel.
ChannelPromise;
import io.netty.channel.
SimpleChannelInboundHandler;
import io.netty.handler.codec.http.
FullHttpRequest;
import io.netty.handler.codec.http.
FullHttpResponse;
import io.netty.handler.codec.http.
HttpClientCodec;
import io.netty.handler.codec.http.
HttpContentDecompressor;
import io.netty.handler.codec.http.
HttpHeaderNames;
import io.netty.handler.codec.http.
HttpHeaders;
import io.netty.handler.codec.http.
HttpObjectAggregator;
import io.netty.handler.codec.http.
HttpRequestEncoder;
import io.netty.handler.codec.http.
HttpResponse;
import io.netty.handler.codec.http.
HttpResponseDecoder;
import io.netty.handler.codec.http.
HttpScheme;
import io.netty.util.
NetUtil;
import io.netty.util.
ReferenceCountUtil;
import io.netty.util.internal.
ThrowableUtil;
import java.net.
URI;
import java.nio.channels.
ClosedChannelException;
import java.util.
Locale;
/**
* Base class for web socket client handshake implementations
*/
public abstract class
WebSocketClientHandshaker {
private static final
ClosedChannelException CLOSED_CHANNEL_EXCEPTION =
ThrowableUtil.
unknownStackTrace(
new
ClosedChannelException(),
WebSocketClientHandshaker.class, "processHandshake(...)");
private static final
String HTTP_SCHEME_PREFIX =
HttpScheme.
HTTP + "://";
private static final
String HTTPS_SCHEME_PREFIX =
HttpScheme.
HTTPS + "://";
private final
URI uri;
private final
WebSocketVersion version;
private volatile boolean
handshakeComplete;
private final
String expectedSubprotocol;
private volatile
String actualSubprotocol;
protected final
HttpHeaders customHeaders;
private final int
maxFramePayloadLength;
/**
* Base constructor
*
* @param uri
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param customHeaders
* Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
*/
protected
WebSocketClientHandshaker(
URI uri,
WebSocketVersion version,
String subprotocol,
HttpHeaders customHeaders, int
maxFramePayloadLength) {
this.
uri =
uri;
this.
version =
version;
expectedSubprotocol =
subprotocol;
this.
customHeaders =
customHeaders;
this.
maxFramePayloadLength =
maxFramePayloadLength;
}
/**
* Returns the URI to the web socket. e.g. "ws://myhost.com/path"
*/
public
URI uri() {
return
uri;
}
/**
* Version of the web socket specification that is being used
*/
public
WebSocketVersion version() {
return
version;
}
/**
* Returns the max length for any frame's payload
*/
public int
maxFramePayloadLength() {
return
maxFramePayloadLength;
}
/**
* Flag to indicate if the opening handshake is complete
*/
public boolean
isHandshakeComplete() {
return
handshakeComplete;
}
private void
setHandshakeComplete() {
handshakeComplete = true;
}
/**
* Returns the CSV of requested subprotocol(s) sent to the server as specified in the constructor
*/
public
String expectedSubprotocol() {
return
expectedSubprotocol;
}
/**
* Returns the subprotocol response sent by the server. Only available after end of handshake.
* Null if no subprotocol was requested or confirmed by the server.
*/
public
String actualSubprotocol() {
return
actualSubprotocol;
}
private void
setActualSubprotocol(
String actualSubprotocol) {
this.
actualSubprotocol =
actualSubprotocol;
}
/**
* Begins the opening handshake
*
* @param channel
* Channel
*/
public
ChannelFuture handshake(
Channel channel) {
if (
channel == null) {
throw new
NullPointerException("channel");
}
return
handshake(
channel,
channel.
newPromise());
}
/**
* Begins the opening handshake
*
* @param channel
* Channel
* @param promise
* the {@link ChannelPromise} to be notified when the opening handshake is sent
*/
public final
ChannelFuture handshake(
Channel channel, final
ChannelPromise promise) {
FullHttpRequest request =
newHandshakeRequest();
HttpResponseDecoder decoder =
channel.
pipeline().
get(
HttpResponseDecoder.class);
if (
decoder == null) {
HttpClientCodec codec =
channel.
pipeline().
get(
HttpClientCodec.class);
if (
codec == null) {
promise.
setFailure(new
IllegalStateException("ChannelPipeline does not contain " +
"a HttpResponseDecoder or HttpClientCodec"));
return
promise;
}
}
channel.
writeAndFlush(
request).
addListener(new
ChannelFutureListener() {
@
Override
public void
operationComplete(
ChannelFuture future) {
if (
future.
isSuccess()) {
ChannelPipeline p =
future.
channel().
pipeline();
ChannelHandlerContext ctx =
p.
context(
HttpRequestEncoder.class);
if (
ctx == null) {
ctx =
p.
context(
HttpClientCodec.class);
}
if (
ctx == null) {
promise.
setFailure(new
IllegalStateException("ChannelPipeline does not contain " +
"a HttpRequestEncoder or HttpClientCodec"));
return;
}
p.
addAfter(
ctx.
name(), "ws-encoder",
newWebSocketEncoder());
promise.
setSuccess();
} else {
promise.
setFailure(
future.
cause());
}
}
});
return
promise;
}
/**
* Returns a new {@link FullHttpRequest) which will be used for the handshake.
*/
protected abstract
FullHttpRequest newHandshakeRequest();
/**
* Validates and finishes the opening handshake initiated by {@link #handshake}}.
*
* @param channel
* Channel
* @param response
* HTTP response containing the closing handshake details
*/
public final void
finishHandshake(
Channel channel,
FullHttpResponse response) {
verify(
response);
// Verify the subprotocol that we received from the server.
// This must be one of our expected subprotocols - or null/empty if we didn't want to speak a subprotocol
String receivedProtocol =
response.
headers().
get(
HttpHeaderNames.
SEC_WEBSOCKET_PROTOCOL);
receivedProtocol =
receivedProtocol != null ?
receivedProtocol.
trim() : null;
String expectedProtocol =
expectedSubprotocol != null ?
expectedSubprotocol : "";
boolean
protocolValid = false;
if (
expectedProtocol.
isEmpty() &&
receivedProtocol == null) {
// No subprotocol required and none received
protocolValid = true;
setActualSubprotocol(
expectedSubprotocol); // null or "" - we echo what the user requested
} else if (!
expectedProtocol.
isEmpty() &&
receivedProtocol != null && !
receivedProtocol.
isEmpty()) {
// We require a subprotocol and received one -> verify it
for (
String protocol :
expectedProtocol.
split(",")) {
if (
protocol.
trim().
equals(
receivedProtocol)) {
protocolValid = true;
setActualSubprotocol(
receivedProtocol);
break;
}
}
} // else mixed cases - which are all errors
if (!
protocolValid) {
throw new
WebSocketHandshakeException(
String.
format(
"Invalid subprotocol. Actual: %s. Expected one of: %s",
receivedProtocol,
expectedSubprotocol));
}
setHandshakeComplete();
final
ChannelPipeline p =
channel.
pipeline();
// Remove decompressor from pipeline if its in use
HttpContentDecompressor decompressor =
p.
get(
HttpContentDecompressor.class);
if (
decompressor != null) {
p.
remove(
decompressor);
}
// Remove aggregator if present before
HttpObjectAggregator aggregator =
p.
get(
HttpObjectAggregator.class);
if (
aggregator != null) {
p.
remove(
aggregator);
}
ChannelHandlerContext ctx =
p.
context(
HttpResponseDecoder.class);
if (
ctx == null) {
ctx =
p.
context(
HttpClientCodec.class);
if (
ctx == null) {
throw new
IllegalStateException("ChannelPipeline does not contain " +
"a HttpRequestEncoder or HttpClientCodec");
}
final
HttpClientCodec codec = (
HttpClientCodec)
ctx.
handler();
// Remove the encoder part of the codec as the user may start writing frames after this method returns.
codec.
removeOutboundHandler();
p.
addAfter(
ctx.
name(), "ws-decoder",
newWebsocketDecoder());
// Delay the removal of the decoder so the user can setup the pipeline if needed to handle
// WebSocketFrame messages.
// See https://github.com/netty/netty/issues/4533
channel.
eventLoop().
execute(new
Runnable() {
@
Override
public void
run() {
p.
remove(
codec);
}
});
} else {
if (
p.
get(
HttpRequestEncoder.class) != null) {
// Remove the encoder part of the codec as the user may start writing frames after this method returns.
p.
remove(
HttpRequestEncoder.class);
}
final
ChannelHandlerContext context =
ctx;
p.
addAfter(
context.
name(), "ws-decoder",
newWebsocketDecoder());
// Delay the removal of the decoder so the user can setup the pipeline if needed to handle
// WebSocketFrame messages.
// See https://github.com/netty/netty/issues/4533
channel.
eventLoop().
execute(new
Runnable() {
@
Override
public void
run() {
p.
remove(
context.
handler());
}
});
}
}
/**
* Process the opening handshake initiated by {@link #handshake}}.
*
* @param channel
* Channel
* @param response
* HTTP response containing the closing handshake details
* @return future
* the {@link ChannelFuture} which is notified once the handshake completes.
*/
public final
ChannelFuture processHandshake(final
Channel channel,
HttpResponse response) {
return
processHandshake(
channel,
response,
channel.
newPromise());
}
/**
* Process the opening handshake initiated by {@link #handshake}}.
*
* @param channel
* Channel
* @param response
* HTTP response containing the closing handshake details
* @param promise
* the {@link ChannelPromise} to notify once the handshake completes.
* @return future
* the {@link ChannelFuture} which is notified once the handshake completes.
*/
public final
ChannelFuture processHandshake(final
Channel channel,
HttpResponse response,
final
ChannelPromise promise) {
if (
response instanceof
FullHttpResponse) {
try {
finishHandshake(
channel, (
FullHttpResponse)
response);
promise.
setSuccess();
} catch (
Throwable cause) {
promise.
setFailure(
cause);
}
} else {
ChannelPipeline p =
channel.
pipeline();
ChannelHandlerContext ctx =
p.
context(
HttpResponseDecoder.class);
if (
ctx == null) {
ctx =
p.
context(
HttpClientCodec.class);
if (
ctx == null) {
return
promise.
setFailure(new
IllegalStateException("ChannelPipeline does not contain " +
"a HttpResponseDecoder or HttpClientCodec"));
}
}
// Add aggregator and ensure we feed the HttpResponse so it is aggregated. A limit of 8192 should be more
// then enough for the websockets handshake payload.
//
// TODO: Make handshake work without HttpObjectAggregator at all.
String aggregatorName = "httpAggregator";
p.
addAfter(
ctx.
name(),
aggregatorName, new
HttpObjectAggregator(8192));
p.
addAfter(
aggregatorName, "handshaker", new
SimpleChannelInboundHandler<
FullHttpResponse>() {
@
Override
protected void
channelRead0(
ChannelHandlerContext ctx,
FullHttpResponse msg) throws
Exception {
// Remove ourself and do the actual handshake
ctx.
pipeline().
remove(this);
try {
finishHandshake(
channel,
msg);
promise.
setSuccess();
} catch (
Throwable cause) {
promise.
setFailure(
cause);
}
}
@
Override
public void
exceptionCaught(
ChannelHandlerContext ctx,
Throwable cause) throws
Exception {
// Remove ourself and fail the handshake promise.
ctx.
pipeline().
remove(this);
promise.
setFailure(
cause);
}
@
Override
public void
channelInactive(
ChannelHandlerContext ctx) throws
Exception {
// Fail promise if Channel was closed
promise.
tryFailure(
CLOSED_CHANNEL_EXCEPTION);
ctx.
fireChannelInactive();
}
});
try {
ctx.
fireChannelRead(
ReferenceCountUtil.
retain(
response));
} catch (
Throwable cause) {
promise.
setFailure(
cause);
}
}
return
promise;
}
/**
* Verify the {@link FullHttpResponse} and throws a {@link WebSocketHandshakeException} if something is wrong.
*/
protected abstract void
verify(
FullHttpResponse response);
/**
* Returns the decoder to use after handshake is complete.
*/
protected abstract
WebSocketFrameDecoder newWebsocketDecoder();
/**
* Returns the encoder to use after the handshake is complete.
*/
protected abstract
WebSocketFrameEncoder newWebSocketEncoder();
/**
* Performs the closing handshake
*
* @param channel
* Channel
* @param frame
* Closing Frame that was received
*/
public
ChannelFuture close(
Channel channel,
CloseWebSocketFrame frame) {
if (
channel == null) {
throw new
NullPointerException("channel");
}
return
close(
channel,
frame,
channel.
newPromise());
}
/**
* Performs the closing handshake
*
* @param channel
* Channel
* @param frame
* Closing Frame that was received
* @param promise
* the {@link ChannelPromise} to be notified when the closing handshake is done
*/
public
ChannelFuture close(
Channel channel,
CloseWebSocketFrame frame,
ChannelPromise promise) {
if (
channel == null) {
throw new
NullPointerException("channel");
}
return
channel.
writeAndFlush(
frame,
promise);
}
/**
* Return the constructed raw path for the give {@link URI}.
*/
static
String rawPath(
URI wsURL) {
String path =
wsURL.
getRawPath();
String query =
wsURL.
getRawQuery();
if (
query != null && !
query.
isEmpty()) {
path =
path + '?' +
query;
}
return
path == null ||
path.
isEmpty() ? "/" :
path;
}
static
CharSequence websocketHostValue(
URI wsURL) {
int
port =
wsURL.
getPort();
if (
port == -1) {
return
wsURL.
getHost();
}
String host =
wsURL.
getHost();
if (
port ==
HttpScheme.
HTTP.
port()) {
return
HttpScheme.
HTTP.
name().
contentEquals(
wsURL.
getScheme())
||
WebSocketScheme.
WS.
name().
contentEquals(
wsURL.
getScheme()) ?
host :
NetUtil.
toSocketAddressString(
host,
port);
}
if (
port ==
HttpScheme.
HTTPS.
port()) {
return
HttpScheme.
HTTPS.
name().
contentEquals(
wsURL.
getScheme())
||
WebSocketScheme.
WSS.
name().
contentEquals(
wsURL.
getScheme()) ?
host :
NetUtil.
toSocketAddressString(
host,
port);
}
// if the port is not standard (80/443) its needed to add the port to the header.
// See http://tools.ietf.org/html/rfc6454#section-6.2
return
NetUtil.
toSocketAddressString(
host,
port);
}
static
CharSequence websocketOriginValue(
URI wsURL) {
String scheme =
wsURL.
getScheme();
final
String schemePrefix;
int
port =
wsURL.
getPort();
final int
defaultPort;
if (
WebSocketScheme.
WSS.
name().
contentEquals(
scheme)
||
HttpScheme.
HTTPS.
name().
contentEquals(
scheme)
|| (
scheme == null &&
port ==
WebSocketScheme.
WSS.
port())) {
schemePrefix =
HTTPS_SCHEME_PREFIX;
defaultPort =
WebSocketScheme.
WSS.
port();
} else {
schemePrefix =
HTTP_SCHEME_PREFIX;
defaultPort =
WebSocketScheme.
WS.
port();
}
// Convert uri-host to lower case (by RFC 6454, chapter 4 "Origin of a URI")
String host =
wsURL.
getHost().
toLowerCase(
Locale.
US);
if (
port !=
defaultPort &&
port != -1) {
// if the port is not standard (80/443) its needed to add the port to the header.
// See http://tools.ietf.org/html/rfc6454#section-6.2
return
schemePrefix +
NetUtil.
toSocketAddressString(
host,
port);
}
return
schemePrefix +
host;
}
}