/*
* JBoss, Home of Professional Open Source.
* Copyright 2014 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.undertow.conduits;
import static org.xnio.
Bits.allAreClear;
import static org.xnio.
Bits.allAreSet;
import static org.xnio.
Bits.anyAreSet;
import java.io.
IOException;
import java.nio.
ByteBuffer;
import java.nio.channels.
ClosedChannelException;
import java.nio.channels.
FileChannel;
import java.util.concurrent.
TimeUnit;
import java.util.zip.
Deflater;
import io.undertow.server.
Connectors;
import org.xnio.
IoUtils;
import io.undertow.connector.
PooledByteBuffer;
import org.xnio.
XnioIoThread;
import org.xnio.
XnioWorker;
import org.xnio.channels.
StreamSourceChannel;
import org.xnio.conduits.
ConduitWritableByteChannel;
import org.xnio.conduits.
Conduits;
import org.xnio.conduits.
StreamSinkConduit;
import org.xnio.conduits.
WriteReadyHandler;
import io.undertow.
UndertowLogger;
import io.undertow.server.
HttpServerExchange;
import io.undertow.util.
ConduitFactory;
import io.undertow.util.
NewInstanceObjectPool;
import io.undertow.util.
ObjectPool;
import io.undertow.util.
Headers;
import io.undertow.util.
PooledObject;
import io.undertow.util.
SimpleObjectPool;
/**
* Channel that handles deflate compression
*
* @author Stuart Douglas
*/
public class
DeflatingStreamSinkConduit implements
StreamSinkConduit {
protected volatile
Deflater deflater;
protected final
PooledObject<
Deflater>
pooledObject;
private final
ConduitFactory<
StreamSinkConduit>
conduitFactory;
private final
HttpServerExchange exchange;
private
StreamSinkConduit next;
private
WriteReadyHandler writeReadyHandler;
/**
* The streams buffer. This is freed when the next is shutdown
*/
protected
PooledByteBuffer currentBuffer;
/**
* there may have been some additional data that did not fit into the first buffer
*/
private
ByteBuffer additionalBuffer;
private int
state = 0;
private static final int
SHUTDOWN = 1;
private static final int
NEXT_SHUTDOWN = 1 << 1;
private static final int
FLUSHING_BUFFER = 1 << 2;
private static final int
WRITES_RESUMED = 1 << 3;
private static final int
CLOSED = 1 << 4;
private static final int
WRITTEN_TRAILER = 1 << 5;
public
DeflatingStreamSinkConduit(final
ConduitFactory<
StreamSinkConduit>
conduitFactory, final
HttpServerExchange exchange) {
this(
conduitFactory,
exchange,
Deflater.
DEFLATED);
}
public
DeflatingStreamSinkConduit(final
ConduitFactory<
StreamSinkConduit>
conduitFactory, final
HttpServerExchange exchange, int
deflateLevel) {
this(
conduitFactory,
exchange,
newInstanceDeflaterPool(
deflateLevel));
}
public
DeflatingStreamSinkConduit(final
ConduitFactory<
StreamSinkConduit>
conduitFactory, final
HttpServerExchange exchange,
ObjectPool<
Deflater>
deflaterPool) {
this.
pooledObject =
deflaterPool.
allocate();
this.
deflater =
pooledObject.
getObject();
this.
currentBuffer =
exchange.
getConnection().
getByteBufferPool().
allocate();
this.
exchange =
exchange;
this.
conduitFactory =
conduitFactory;
setWriteReadyHandler(new
WriteReadyHandler.
ChannelListenerHandler<>(
Connectors.
getConduitSinkChannel(
exchange)));
}
public static
ObjectPool<
Deflater>
newInstanceDeflaterPool(int
deflateLevel) {
return new
NewInstanceObjectPool<
Deflater>(() -> new
Deflater(
deflateLevel, true),
Deflater::end);
}
public static
ObjectPool<
Deflater>
simpleDeflaterPool(int
poolSize, int
deflateLevel) {
return new
SimpleObjectPool<
Deflater>(
poolSize, () -> new
Deflater(
deflateLevel, true),
Deflater::reset,
Deflater::end);
}
@
Override
public int
write(final
ByteBuffer src) throws
IOException {
if (
anyAreSet(
state,
SHUTDOWN |
CLOSED) ||
currentBuffer == null) {
throw new
ClosedChannelException();
}
try {
if (!
performFlushIfRequired()) {
return 0;
}
if (
src.
remaining() == 0) {
return 0;
}
//we may already have some input, if so compress it
if (!
deflater.
needsInput()) {
deflateData(false);
if (!
deflater.
needsInput()) {
return 0;
}
}
byte[]
data = new byte[
src.
remaining()];
src.
get(
data);
preDeflate(
data);
deflater.
setInput(
data);
Connectors.
updateResponseBytesSent(
exchange, 0 -
data.length);
deflateData(false);
return
data.length;
} catch (
IOException |
RuntimeException |
Error e) {
freeBuffer();
throw
e;
}
}
protected void
preDeflate(byte[]
data) {
}
@
Override
public long
write(final
ByteBuffer[]
srcs, final int
offset, final int
length) throws
IOException {
if (
anyAreSet(
state,
SHUTDOWN |
CLOSED) ||
currentBuffer == null) {
throw new
ClosedChannelException();
}
try {
int
total = 0;
for (int
i =
offset;
i <
offset +
length; ++
i) {
if (
srcs[
i].
hasRemaining()) {
int
ret =
write(
srcs[
i]);
total +=
ret;
if (
ret == 0) {
return
total;
}
}
}
return
total;
} catch (
IOException |
RuntimeException |
Error e) {
freeBuffer();
throw
e;
}
}
@
Override
public int
writeFinal(
ByteBuffer src) throws
IOException {
return
Conduits.
writeFinalBasic(this,
src);
}
@
Override
public long
writeFinal(
ByteBuffer[]
srcs, int
offset, int
length) throws
IOException {
return
Conduits.
writeFinalBasic(this,
srcs,
offset,
length);
}
@
Override
public long
transferFrom(final
FileChannel src, final long
position, final long
count) throws
IOException {
if (
anyAreSet(
state,
SHUTDOWN |
CLOSED)) {
throw new
ClosedChannelException();
}
if (!
performFlushIfRequired()) {
return 0;
}
return
src.
transferTo(
position,
count, new
ConduitWritableByteChannel(this));
}
@
Override
public long
transferFrom(final
StreamSourceChannel source, final long
count, final
ByteBuffer throughBuffer) throws
IOException {
if (
anyAreSet(
state,
SHUTDOWN |
CLOSED)) {
throw new
ClosedChannelException();
}
if (!
performFlushIfRequired()) {
return 0;
}
return
IoUtils.
transfer(
source,
count,
throughBuffer, new
ConduitWritableByteChannel(this));
}
@
Override
public
XnioWorker getWorker() {
return
exchange.
getConnection().
getWorker();
}
@
Override
public void
suspendWrites() {
if (
next == null) {
state =
state & ~
WRITES_RESUMED;
} else {
next.
suspendWrites();
}
}
@
Override
public boolean
isWriteResumed() {
if (
next == null) {
return
anyAreSet(
state,
WRITES_RESUMED);
} else {
return
next.
isWriteResumed();
}
}
@
Override
public void
wakeupWrites() {
if (
next == null) {
resumeWrites();
} else {
next.
wakeupWrites();
}
}
@
Override
public void
resumeWrites() {
if (
next == null) {
state |=
WRITES_RESUMED;
queueWriteListener();
} else {
next.
resumeWrites();
}
}
private void
queueWriteListener() {
exchange.
getConnection().
getIoThread().
execute(new
Runnable() {
@
Override
public void
run() {
if (
writeReadyHandler != null) {
try {
writeReadyHandler.
writeReady();
} finally {
//if writes are still resumed queue up another one
if (
next == null &&
isWriteResumed()) {
queueWriteListener();
}
}
}
}
});
}
@
Override
public void
terminateWrites() throws
IOException {
if (
deflater != null) {
deflater.
finish();
}
state |=
SHUTDOWN;
}
@
Override
public boolean
isWriteShutdown() {
return
anyAreSet(
state,
SHUTDOWN);
}
@
Override
public void
awaitWritable() throws
IOException {
if (
next == null) {
return;
} else {
next.
awaitWritable();
}
}
@
Override
public void
awaitWritable(final long
time, final
TimeUnit timeUnit) throws
IOException {
if (
next == null) {
return;
} else {
next.
awaitWritable(
time,
timeUnit);
}
}
@
Override
public
XnioIoThread getWriteThread() {
return
exchange.
getConnection().
getIoThread();
}
@
Override
public void
setWriteReadyHandler(final
WriteReadyHandler handler) {
this.
writeReadyHandler =
handler;
}
@
Override
public boolean
flush() throws
IOException {
if (
currentBuffer == null) {
if (
anyAreSet(
state,
NEXT_SHUTDOWN)) {
return
next.
flush();
} else {
return true;
}
}
try {
boolean
nextCreated = false;
try {
if (
anyAreSet(
state,
SHUTDOWN)) {
if (
anyAreSet(
state,
NEXT_SHUTDOWN)) {
return
next.
flush();
} else {
if (!
performFlushIfRequired()) {
return false;
}
//if the deflater has not been fully flushed we need to flush it
if (!
deflater.
finished()) {
deflateData(false);
//if could not fully flush
if (!
deflater.
finished()) {
return false;
}
}
final
ByteBuffer buffer =
currentBuffer.
getBuffer();
if (
allAreClear(
state,
WRITTEN_TRAILER)) {
state |=
WRITTEN_TRAILER;
byte[]
data =
getTrailer();
if (
data != null) {
Connectors.
updateResponseBytesSent(
exchange,
data.length);
if(
additionalBuffer != null) {
byte[]
newData = new byte[
additionalBuffer.
remaining() +
data.length];
int
pos = 0;
while (
additionalBuffer.
hasRemaining()) {
newData[
pos++] =
additionalBuffer.
get();
}
for (byte
aData :
data) {
newData[
pos++] =
aData;
}
this.
additionalBuffer =
ByteBuffer.
wrap(
newData);
} else if(
anyAreSet(
state,
FLUSHING_BUFFER) &&
buffer.
capacity() -
buffer.
remaining() >=
data.length) {
buffer.
compact();
buffer.
put(
data);
buffer.
flip();
} else if (
data.length <=
buffer.
remaining() && !
anyAreSet(
state,
FLUSHING_BUFFER)) {
buffer.
put(
data);
} else {
additionalBuffer =
ByteBuffer.
wrap(
data);
}
}
}
//ok the deflater is flushed, now we need to flush the buffer
if (!
anyAreSet(
state,
FLUSHING_BUFFER)) {
buffer.
flip();
state |=
FLUSHING_BUFFER;
if (
next == null) {
nextCreated = true;
this.
next =
createNextChannel();
}
}
if (
performFlushIfRequired()) {
state |=
NEXT_SHUTDOWN;
freeBuffer();
next.
terminateWrites();
return
next.
flush();
} else {
return false;
}
}
} else {
if(
allAreClear(
state,
FLUSHING_BUFFER)) {
if (
next == null) {
nextCreated = true;
this.
next =
createNextChannel();
}
deflateData(true);
if(
allAreClear(
state,
FLUSHING_BUFFER)) {
//deflateData can cause this to be change
currentBuffer.
getBuffer().
flip();
this.
state |=
FLUSHING_BUFFER;
}
}
if(!
performFlushIfRequired()) {
return false;
}
return
next.
flush();
}
} finally {
if (
nextCreated) {
if (
anyAreSet(
state,
WRITES_RESUMED) && !
anyAreSet(
state ,
NEXT_SHUTDOWN)) {
try {
next.
resumeWrites();
} catch (
Throwable e) {
UndertowLogger.
REQUEST_LOGGER.
debug("Failed to resume",
e);
}
}
}
}
} catch (
IOException |
RuntimeException |
Error e) {
freeBuffer();
throw
e;
}
}
/**
* called before the stream is finally flushed.
*/
protected byte[]
getTrailer() {
return null;
}
/**
* The we are in the flushing state then we flush to the underlying stream, otherwise just return true
*
* @return false if there is still more to flush
*/
private boolean
performFlushIfRequired() throws
IOException {
if (
anyAreSet(
state,
FLUSHING_BUFFER)) {
final
ByteBuffer[]
bufs = new
ByteBuffer[
additionalBuffer == null ? 1 : 2];
long
totalLength = 0;
bufs[0] =
currentBuffer.
getBuffer();
totalLength +=
bufs[0].
remaining();
if (
additionalBuffer != null) {
bufs[1] =
additionalBuffer;
totalLength +=
bufs[1].
remaining();
}
if (
totalLength > 0) {
long
total = 0;
long
res = 0;
do {
res =
next.
write(
bufs, 0,
bufs.length);
total +=
res;
if (
res == 0) {
return false;
}
} while (
total <
totalLength);
}
additionalBuffer = null;
currentBuffer.
getBuffer().
clear();
state =
state & ~
FLUSHING_BUFFER;
}
return true;
}
private
StreamSinkConduit createNextChannel() {
if (
deflater.
finished() &&
allAreSet(
state,
WRITTEN_TRAILER)) {
//the deflater was fully flushed before we created the channel. This means that what is in the buffer is
//all there is
int
remaining =
currentBuffer.
getBuffer().
remaining();
if (
additionalBuffer != null) {
remaining +=
additionalBuffer.
remaining();
}
if(!
exchange.
getResponseHeaders().
contains(
Headers.
TRANSFER_ENCODING)) {
exchange.
getResponseHeaders().
put(
Headers.
CONTENT_LENGTH,
Integer.
toString(
remaining));
}
} else {
exchange.
getResponseHeaders().
remove(
Headers.
CONTENT_LENGTH);
}
return
conduitFactory.
create();
}
/**
* Runs the current data through the deflater. As much as possible this will be buffered in the current output
* stream.
*
* @throws IOException
*/
private void
deflateData(boolean
force) throws
IOException {
//we don't need to flush here, as this should have been called already by the time we get to
//this point
boolean
nextCreated = false;
try (
PooledByteBuffer arrayPooled = this.
exchange.
getConnection().
getByteBufferPool().
getArrayBackedPool().
allocate()) {
PooledByteBuffer pooled = this.
currentBuffer;
final
ByteBuffer outputBuffer =
pooled.
getBuffer();
final boolean
shutdown =
anyAreSet(
state,
SHUTDOWN);
ByteBuffer buf =
arrayPooled.
getBuffer();
while (
force || !
deflater.
needsInput() || (
shutdown && !
deflater.
finished())) {
int
count =
deflater.
deflate(
buf.
array(),
buf.
arrayOffset(),
buf.
remaining(),
force ?
Deflater.
SYNC_FLUSH:
Deflater.
NO_FLUSH);
Connectors.
updateResponseBytesSent(
exchange,
count);
if (
count != 0) {
int
remaining =
outputBuffer.
remaining();
if (
remaining >
count) {
outputBuffer.
put(
buf.
array(),
buf.
arrayOffset(),
count);
} else {
if (
remaining ==
count) {
outputBuffer.
put(
buf.
array(),
buf.
arrayOffset(),
count);
} else {
outputBuffer.
put(
buf.
array(),
buf.
arrayOffset(),
remaining);
additionalBuffer =
ByteBuffer.
allocate(
count -
remaining);
additionalBuffer.
put(
buf.
array(),
buf.
arrayOffset() +
remaining,
count -
remaining);
additionalBuffer.
flip();
}
outputBuffer.
flip();
this.
state |=
FLUSHING_BUFFER;
if (
next == null) {
nextCreated = true;
this.
next =
createNextChannel();
}
if (!
performFlushIfRequired()) {
return;
}
}
} else {
force = false;
}
}
} finally {
if (
nextCreated) {
if (
anyAreSet(
state,
WRITES_RESUMED)) {
next.
resumeWrites();
}
}
}
}
@
Override
public void
truncateWrites() throws
IOException {
freeBuffer();
state |=
CLOSED;
next.
truncateWrites();
}
private void
freeBuffer() {
if (
currentBuffer != null) {
currentBuffer.
close();
currentBuffer = null;
state =
state & ~
FLUSHING_BUFFER;
}
if (
deflater != null) {
pooledObject.
close();
deflater = null;
}
}
}