/*
* JBoss, Home of Professional Open Source.
* Copyright 2014 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.undertow.server.protocol.http;
import java.io.
IOException;
import java.nio.
ByteBuffer;
import java.util.
ArrayList;
import java.util.
Collections;
import java.util.
HashMap;
import java.util.
List;
import java.util.
Map;
import java.util.concurrent.
CompletableFuture;
import java.util.function.
Function;
import javax.net.ssl.
SSLEngine;
import org.xnio.
ChannelListener;
import org.xnio.
IoUtils;
import org.xnio.
OptionMap;
import org.xnio.
Pool;
import org.xnio.
StreamConnection;
import org.xnio.channels.
StreamSourceChannel;
import org.xnio.ssl.
SslConnection;
import io.undertow.
UndertowLogger;
import io.undertow.
UndertowMessages;
import io.undertow.
UndertowOptions;
import io.undertow.connector.
ByteBufferPool;
import io.undertow.connector.
PooledByteBuffer;
import io.undertow.protocols.alpn.
ALPNManager;
import io.undertow.protocols.alpn.
ALPNProvider;
import io.undertow.protocols.ssl.
SslConduit;
import io.undertow.protocols.ssl.
UndertowXnioSsl;
import io.undertow.server.
AggregateConnectorStatistics;
import io.undertow.server.
ConnectorStatistics;
import io.undertow.server.
DelegateOpenListener;
import io.undertow.server.
HttpHandler;
import io.undertow.server.
OpenListener;
import io.undertow.server.
XnioByteBufferPool;
/**
* Open listener adaptor for ALPN connections
* <p>
* Not a proper open listener as such, but more a mechanism for selecting between them.
*
* @author Stuart Douglas
*/
public class
AlpnOpenListener implements
ChannelListener<
StreamConnection>,
OpenListener {
/**
* HTTP/2 required cipher. Not strictly part of ALPN but it can live here for now till we have a better solution.
*/
public static final
String REQUIRED_CIPHER = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256";
public static final
String REQUIRED_PROTOCOL = "TLSv1.2";
private final
ALPNManager alpnManager =
ALPNManager.
INSTANCE; //todo: configurable
private final
ByteBufferPool bufferPool;
private final
Map<
String,
ListenerEntry>
listeners = new
HashMap<>();
private
String[]
protocols;
private final
String fallbackProtocol;
private volatile
HttpHandler rootHandler;
private volatile
OptionMap undertowOptions;
private volatile boolean
statisticsEnabled;
private volatile boolean
providerLogged;
private volatile boolean
alpnFailLogged;
public
AlpnOpenListener(
Pool<
ByteBuffer>
bufferPool,
OptionMap undertowOptions,
DelegateOpenListener httpListener) {
this(
bufferPool,
undertowOptions, "http/1.1",
httpListener);
}
public
AlpnOpenListener(
Pool<
ByteBuffer>
bufferPool,
OptionMap undertowOptions) {
this(
bufferPool,
undertowOptions, null, null);
}
public
AlpnOpenListener(
Pool<
ByteBuffer>
bufferPool,
OptionMap undertowOptions,
String fallbackProtocol,
DelegateOpenListener fallbackListener) {
this(new
XnioByteBufferPool(
bufferPool),
undertowOptions,
fallbackProtocol,
fallbackListener);
}
public
AlpnOpenListener(
ByteBufferPool bufferPool,
OptionMap undertowOptions,
DelegateOpenListener httpListener) {
this(
bufferPool,
undertowOptions, "http/1.1",
httpListener);
}
public
AlpnOpenListener(
ByteBufferPool bufferPool) {
this(
bufferPool,
OptionMap.
EMPTY, null, null);
}
public
AlpnOpenListener(
ByteBufferPool bufferPool,
OptionMap undertowOptions) {
this(
bufferPool,
undertowOptions, null, null);
}
public
AlpnOpenListener(
ByteBufferPool bufferPool,
OptionMap undertowOptions,
String fallbackProtocol,
DelegateOpenListener fallbackListener) {
this.
bufferPool =
bufferPool;
this.
undertowOptions =
undertowOptions;
this.
fallbackProtocol =
fallbackProtocol;
statisticsEnabled =
undertowOptions.
get(
UndertowOptions.
ENABLE_CONNECTOR_STATISTICS, false);
if (
fallbackProtocol != null &&
fallbackListener != null) {
addProtocol(
fallbackProtocol,
fallbackListener, 0);
}
}
@
Override
public
HttpHandler getRootHandler() {
return
rootHandler;
}
@
Override
public void
setRootHandler(
HttpHandler rootHandler) {
this.
rootHandler =
rootHandler;
for (
Map.
Entry<
String,
ListenerEntry>
delegate :
listeners.
entrySet()) {
delegate.
getValue().
listener.
setRootHandler(
rootHandler);
}
}
@
Override
public
OptionMap getUndertowOptions() {
return
undertowOptions;
}
@
Override
public void
setUndertowOptions(
OptionMap undertowOptions) {
if (
undertowOptions == null) {
throw
UndertowMessages.
MESSAGES.
argumentCannotBeNull("undertowOptions");
}
this.
undertowOptions =
undertowOptions;
for (
Map.
Entry<
String,
ListenerEntry>
delegate :
listeners.
entrySet()) {
delegate.
getValue().
listener.
setRootHandler(
rootHandler);
}
statisticsEnabled =
undertowOptions.
get(
UndertowOptions.
ENABLE_CONNECTOR_STATISTICS, false);
}
@
Override
public
ByteBufferPool getBufferPool() {
return
bufferPool;
}
@
Override
public
ConnectorStatistics getConnectorStatistics() {
if (
statisticsEnabled) {
List<
ConnectorStatistics>
stats = new
ArrayList<>();
for (
Map.
Entry<
String,
ListenerEntry>
l :
listeners.
entrySet()) {
ConnectorStatistics c =
l.
getValue().
listener.
getConnectorStatistics();
if (
c != null) {
stats.
add(
c);
}
}
return new
AggregateConnectorStatistics(
stats.
toArray(new
ConnectorStatistics[
stats.
size()]));
}
return null;
}
private static class
ListenerEntry implements
Comparable<
ListenerEntry> {
final
DelegateOpenListener listener;
final int
weight;
final
String protocol;
ListenerEntry(
DelegateOpenListener listener, int
weight,
String protocol) {
this.
listener =
listener;
this.
weight =
weight;
this.
protocol =
protocol;
}
@
Override
public boolean
equals(
Object o) {
if (this ==
o) return true;
if (!(
o instanceof
ListenerEntry)) return false;
ListenerEntry that = (
ListenerEntry)
o;
if (
weight !=
that.
weight) return false;
if (!
listener.
equals(
that.
listener)) return false;
return
protocol.
equals(
that.
protocol);
}
@
Override
public int
hashCode() {
int
result =
listener.
hashCode();
result = 31 *
result +
weight;
result = 31 *
result +
protocol.
hashCode();
return
result;
}
@
Override
public int
compareTo(
ListenerEntry o) {
return -
Integer.
compare(this.
weight,
o.
weight);
}
}
public
AlpnOpenListener addProtocol(
String name,
DelegateOpenListener listener, int
weight) {
listeners.
put(
name, new
ListenerEntry(
listener,
weight,
name));
List<
ListenerEntry>
list = new
ArrayList<>(
listeners.
values());
Collections.
sort(
list);
protocols = new
String[
list.
size()];
for (int
i = 0;
i <
list.
size(); ++
i) {
protocols[
i] =
list.
get(
i).
protocol;
}
return this;
}
public void
handleEvent(final
StreamConnection channel) {
if (
UndertowLogger.
REQUEST_LOGGER.
isTraceEnabled()) {
UndertowLogger.
REQUEST_LOGGER.
tracef("Opened connection with %s",
channel.
getPeerAddress());
}
final
SslConduit sslConduit =
UndertowXnioSsl.
getSslConduit((
SslConnection)
channel);
final
SSLEngine originalSSlEngine =
sslConduit.
getSSLEngine();
//this will end up with the ALPN engine, or null if the engine did not support ALPN
final
CompletableFuture<
SelectedAlpn>
selectedALPNEngine = new
CompletableFuture<>();
alpnManager.
registerEngineCallback(
originalSSlEngine, new
SSLConduitUpdater(
sslConduit, new
Function<
SSLEngine,
SSLEngine>() {
@
Override
public
SSLEngine apply(
SSLEngine engine) {
if (!
engineSupportsHTTP2(
engine)) {
if (!
alpnFailLogged) {
synchronized (this) {
if (!
alpnFailLogged) {
UndertowLogger.
REQUEST_LOGGER.
debugf("ALPN has been configured however %s is not present or TLS1.2 is not enabled, falling back to default protocol",
REQUIRED_CIPHER);
alpnFailLogged = true;
}
}
}
if (
fallbackProtocol != null) {
ListenerEntry listener =
listeners.
get(
fallbackProtocol);
if (
listener != null) {
listener.
listener.
handleEvent(
channel);
selectedALPNEngine.
complete(null);
return
engine;
}
}
}
final
ALPNProvider provider =
alpnManager.
getProvider(
engine);
if (
provider == null) {
if (!
providerLogged) {
synchronized (this) {
if (!
providerLogged) {
UndertowLogger.
REQUEST_LOGGER.
debugf("ALPN has been configured however no provider could be found for engine %s for connector at %s",
engine,
channel.
getLocalAddress());
providerLogged = true;
}
}
}
if (
fallbackProtocol != null) {
ListenerEntry listener =
listeners.
get(
fallbackProtocol);
if (
listener != null) {
listener.
listener.
handleEvent(
channel);
selectedALPNEngine.
complete(null);
return
engine;
}
}
UndertowLogger.
REQUEST_LOGGER.
debugf("No ALPN provider available and no fallback defined");
IoUtils.
safeClose(
channel);
selectedALPNEngine.
complete(null);
return
engine;
}
if (!
providerLogged) {
synchronized (this) {
if (!
providerLogged) {
UndertowLogger.
REQUEST_LOGGER.
debugf("Using ALPN provider %s for connector at %s",
provider,
channel.
getLocalAddress());
providerLogged = true;
}
}
}
final
SSLEngine newEngine =
provider.
setProtocols(
engine,
protocols);
ALPNLimitingSSLEngine alpnLimitingSSLEngine = new
ALPNLimitingSSLEngine(
newEngine, new
Runnable() {
@
Override
public void
run() {
provider.
setProtocols(
newEngine, new
String[]{
fallbackProtocol});
}
});
selectedALPNEngine.
complete(new
SelectedAlpn(
newEngine,
provider)); //we don't want the limiting engine, but the actual one we can use with a provider
return
alpnLimitingSSLEngine;
}
}));
final
AlpnConnectionListener potentialConnection = new
AlpnConnectionListener(
channel,
selectedALPNEngine);
channel.
getSourceChannel().
setReadListener(
potentialConnection);
potentialConnection.
handleEvent(
channel.
getSourceChannel());
}
public static boolean
engineSupportsHTTP2(
SSLEngine engine) {
//check to make sure the engine meets the minimum requirements for HTTP/2
//if not then ALPN will not be attempted
String[]
protcols =
engine.
getEnabledProtocols();
boolean
found = false;
for (
String proto :
protcols) {
if (
proto.
equals(
REQUIRED_PROTOCOL)) {
found = true;
break;
}
}
if (!
found) {
return false;
}
String[]
ciphers =
engine.
getEnabledCipherSuites();
for (
String i :
ciphers) {
if (
i.
equals(
REQUIRED_CIPHER)) {
return true;
}
}
return false;
}
private class
AlpnConnectionListener implements
ChannelListener<
StreamSourceChannel> {
private final
StreamConnection channel;
private final
CompletableFuture<
SelectedAlpn>
selectedAlpn;
private
AlpnConnectionListener(
StreamConnection channel,
CompletableFuture<
SelectedAlpn>
selectedAlpn) {
this.
channel =
channel;
this.
selectedAlpn =
selectedAlpn;
}
@
Override
public void
handleEvent(
StreamSourceChannel source) {
PooledByteBuffer buffer =
bufferPool.
allocate();
boolean
free = true;
try {
while (true) {
int
res =
channel.
getSourceChannel().
read(
buffer.
getBuffer());
if (
res == -1) {
IoUtils.
safeClose(
channel);
return;
}
buffer.
getBuffer().
flip();
SelectedAlpn selectedAlpn = this.
selectedAlpn.
getNow(null);
final
String selected;
if (
selectedAlpn != null) {
selected =
selectedAlpn.
provider.
getSelectedProtocol(
selectedAlpn.
engine);
} else {
selected = null;
}
if (
selected != null) {
DelegateOpenListener listener;
if (
selected.
isEmpty()) {
//alpn not in use
if (
fallbackProtocol == null) {
UndertowLogger.
REQUEST_IO_LOGGER.
noALPNFallback(
channel.
getPeerAddress());
IoUtils.
safeClose(
channel);
return;
}
listener =
listeners.
get(
fallbackProtocol).
listener;
} else {
listener =
listeners.
get(
selected).
listener;
}
source.
getReadSetter().
set(null);
listener.
handleEvent(
channel,
buffer);
free = false;
return;
} else if (
res > 0) {
if (
fallbackProtocol == null) {
UndertowLogger.
REQUEST_IO_LOGGER.
noALPNFallback(
channel.
getPeerAddress());
IoUtils.
safeClose(
channel);
return;
}
DelegateOpenListener listener =
listeners.
get(
fallbackProtocol).
listener;
source.
getReadSetter().
set(null);
listener.
handleEvent(
channel,
buffer);
free = false;
return;
} else if (
res == 0) {
channel.
getSourceChannel().
resumeReads();
return;
}
}
} catch (
IOException e) {
UndertowLogger.
REQUEST_IO_LOGGER.
ioException(
e);
IoUtils.
safeClose(
channel);
} catch (
Throwable t) {
UndertowLogger.
REQUEST_IO_LOGGER.
handleUnexpectedFailure(
t);
IoUtils.
safeClose(
channel);
} finally {
if (
free) {
buffer.
close();
}
}
}
}
static final class
SelectedAlpn {
final
SSLEngine engine;
final
ALPNProvider provider;
SelectedAlpn(
SSLEngine engine,
ALPNProvider provider) {
this.
engine =
engine;
this.
provider =
provider;
}
}
static final class
SSLConduitUpdater implements
Function<
SSLEngine,
SSLEngine> {
final
SslConduit conduit;
final
Function<
SSLEngine,
SSLEngine>
underlying;
SSLConduitUpdater(
SslConduit conduit,
Function<
SSLEngine,
SSLEngine>
underlying) {
this.
conduit =
conduit;
this.
underlying =
underlying;
}
@
Override
public
SSLEngine apply(
SSLEngine engine) {
SSLEngine res =
underlying.
apply(
engine);
conduit.
setSslEngine(
res);
return
res;
}
}
}