/*
* JBoss, Home of Professional Open Source.
* Copyright 2014 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.undertow.io;
import java.io.
IOException;
import java.io.
OutputStream;
import java.nio.
ByteBuffer;
import java.nio.channels.
FileChannel;
import io.undertow.
UndertowMessages;
import io.undertow.server.
HttpServerExchange;
import io.undertow.util.
Headers;
import org.xnio.
Buffers;
import io.undertow.connector.
ByteBufferPool;
import io.undertow.connector.
PooledByteBuffer;
import org.xnio.channels.
Channels;
import org.xnio.channels.
StreamSinkChannel;
import static org.xnio.
Bits.anyAreClear;
import static org.xnio.
Bits.anyAreSet;
/**
* Buffering output stream that wraps a channel.
* <p>
* This stream delays channel creation, so if a response will fit in the buffer it is not necessary to
* set the content length header.
*
* @author Stuart Douglas
*/
public class
UndertowOutputStream extends
OutputStream implements
BufferWritableOutputStream {
private final
HttpServerExchange exchange;
private
ByteBuffer buffer;
private
PooledByteBuffer pooledBuffer;
private
StreamSinkChannel channel;
private int
state;
private long
written;
private final long
contentLength;
private static final int
FLAG_CLOSED = 1;
private static final int
FLAG_WRITE_STARTED = 1 << 1;
private static final int
MAX_BUFFERS_TO_ALLOCATE = 10;
/**
* Construct a new instance. No write timeout is configured.
*
* @param exchange The exchange
*/
public
UndertowOutputStream(
HttpServerExchange exchange) {
this.
exchange =
exchange;
this.
contentLength =
exchange.
getResponseContentLength();
}
/**
* If the response has not yet been written to the client this method will clear the streams buffer,
* invalidating any content that has already been written. If any content has already been sent to the client then
* this method will throw and IllegalStateException
*
* @throws java.lang.IllegalStateException If the response has been commited
*/
public void
resetBuffer() {
if(
anyAreSet(
state,
FLAG_WRITE_STARTED)) {
throw
UndertowMessages.
MESSAGES.
cannotResetBuffer();
}
if(
pooledBuffer != null) {
pooledBuffer.
close();
pooledBuffer = null;
}
}
public long
getBytesWritten() {
return
written;
}
/**
* {@inheritDoc}
*/
public void
write(final int
b) throws
IOException {
write(new byte[]{(byte)
b}, 0, 1);
}
/**
* {@inheritDoc}
*/
public void
write(final byte[]
b) throws
IOException {
write(
b, 0,
b.length);
}
/**
* {@inheritDoc}
*/
public void
write(final byte[]
b, final int
off, final int
len) throws
IOException {
if (
len < 1) {
return;
}
if(
Thread.
currentThread() ==
exchange.
getIoThread()) {
throw
UndertowMessages.
MESSAGES.
blockingIoFromIOThread();
}
if (
anyAreSet(
state,
FLAG_CLOSED)) {
throw
UndertowMessages.
MESSAGES.
streamIsClosed();
}
//if this is the last of the content
ByteBuffer buffer =
buffer();
if (
len ==
contentLength -
written ||
buffer.
remaining() <
len) {
if (
buffer.
remaining() <
len) {
//so what we have will not fit.
//We allocate multiple buffers up to MAX_BUFFERS_TO_ALLOCATE
//and put it in them
//if it still dopes not fit we loop, re-using these buffers
StreamSinkChannel channel = this.
channel;
if (
channel == null) {
this.
channel =
channel =
exchange.
getResponseChannel();
}
final
ByteBufferPool bufferPool =
exchange.
getConnection().
getByteBufferPool();
ByteBuffer[]
buffers = new
ByteBuffer[
MAX_BUFFERS_TO_ALLOCATE + 1];
PooledByteBuffer[]
pooledBuffers = new
PooledByteBuffer[
MAX_BUFFERS_TO_ALLOCATE];
try {
buffers[0] =
buffer;
int
bytesWritten = 0;
int
rem =
buffer.
remaining();
buffer.
put(
b,
bytesWritten +
off,
rem);
buffer.
flip();
bytesWritten +=
rem;
int
bufferCount = 1;
for (int
i = 0;
i <
MAX_BUFFERS_TO_ALLOCATE; ++
i) {
PooledByteBuffer pooled =
bufferPool.
allocate();
pooledBuffers[
bufferCount - 1] =
pooled;
buffers[
bufferCount++] =
pooled.
getBuffer();
ByteBuffer cb =
pooled.
getBuffer();
int
toWrite =
len -
bytesWritten;
if (
toWrite >
cb.
remaining()) {
rem =
cb.
remaining();
cb.
put(
b,
bytesWritten +
off,
rem);
cb.
flip();
bytesWritten +=
rem;
} else {
cb.
put(
b,
bytesWritten +
off,
len -
bytesWritten);
bytesWritten =
len;
cb.
flip();
break;
}
}
Channels.
writeBlocking(
channel,
buffers, 0,
bufferCount);
while (
bytesWritten <
len) {
//ok, it did not fit, loop and loop and loop until it is done
bufferCount = 0;
for (int
i = 0;
i <
MAX_BUFFERS_TO_ALLOCATE + 1; ++
i) {
ByteBuffer cb =
buffers[
i];
cb.
clear();
bufferCount++;
int
toWrite =
len -
bytesWritten;
if (
toWrite >
cb.
remaining()) {
rem =
cb.
remaining();
cb.
put(
b,
bytesWritten +
off,
rem);
cb.
flip();
bytesWritten +=
rem;
} else {
cb.
put(
b,
bytesWritten +
off,
len -
bytesWritten);
bytesWritten =
len;
cb.
flip();
break;
}
}
Channels.
writeBlocking(
channel,
buffers, 0,
bufferCount);
}
buffer.
clear();
} finally {
for (int
i = 0;
i <
pooledBuffers.length; ++
i) {
PooledByteBuffer p =
pooledBuffers[
i];
if (
p == null) {
break;
}
p.
close();
}
}
} else {
buffer.
put(
b,
off,
len);
if (
buffer.
remaining() == 0) {
writeBufferBlocking(false);
}
}
} else {
buffer.
put(
b,
off,
len);
if (
buffer.
remaining() == 0) {
writeBufferBlocking(false);
}
}
updateWritten(
len);
}
@
Override
public void
write(
ByteBuffer[]
buffers) throws
IOException {
if (
anyAreSet(
state,
FLAG_CLOSED)) {
throw
UndertowMessages.
MESSAGES.
streamIsClosed();
}
int
len = 0;
for (
ByteBuffer buf :
buffers) {
len +=
buf.
remaining();
}
if (
len < 1) {
return;
}
//if we have received the exact amount of content write it out in one go
//this is a common case when writing directly from a buffer cache.
if (this.
written == 0 &&
len ==
contentLength) {
if (
channel == null) {
channel =
exchange.
getResponseChannel();
}
Channels.
writeBlocking(
channel,
buffers, 0,
buffers.length);
state |=
FLAG_WRITE_STARTED;
} else {
ByteBuffer buffer =
buffer();
if (
len <
buffer.
remaining()) {
Buffers.
copy(
buffer,
buffers, 0,
buffers.length);
} else {
if (
channel == null) {
channel =
exchange.
getResponseChannel();
}
if (
buffer.
position() == 0) {
Channels.
writeBlocking(
channel,
buffers, 0,
buffers.length);
} else {
final
ByteBuffer[]
newBuffers = new
ByteBuffer[
buffers.length + 1];
buffer.
flip();
newBuffers[0] =
buffer;
System.
arraycopy(
buffers, 0,
newBuffers, 1,
buffers.length);
Channels.
writeBlocking(
channel,
newBuffers, 0,
newBuffers.length);
buffer.
clear();
}
state |=
FLAG_WRITE_STARTED;
}
}
updateWritten(
len);
}
@
Override
public void
write(
ByteBuffer byteBuffer) throws
IOException {
write(new
ByteBuffer[]{
byteBuffer});
}
void
updateWritten(final long
len) throws
IOException {
this.
written +=
len;
if (
contentLength != -1 && this.
written >=
contentLength) {
flush();
close();
}
}
/**
* {@inheritDoc}
*/
public void
flush() throws
IOException {
if (
anyAreSet(
state,
FLAG_CLOSED)) {
throw
UndertowMessages.
MESSAGES.
streamIsClosed();
}
if (
buffer != null &&
buffer.
position() != 0) {
writeBufferBlocking(false);
}
if (
channel == null) {
channel =
exchange.
getResponseChannel();
}
Channels.
flushBlocking(
channel);
}
private void
writeBufferBlocking(final boolean
writeFinal) throws
IOException {
if (
channel == null) {
channel =
exchange.
getResponseChannel();
}
buffer.
flip();
while (
buffer.
hasRemaining()) {
if(
writeFinal) {
channel.
writeFinal(
buffer);
} else {
channel.
write(
buffer);
}
if(
buffer.
hasRemaining()) {
channel.
awaitWritable();
}
}
buffer.
clear();
state |=
FLAG_WRITE_STARTED;
}
@
Override
public void
transferFrom(
FileChannel source) throws
IOException {
if (
anyAreSet(
state,
FLAG_CLOSED)) {
throw
UndertowMessages.
MESSAGES.
streamIsClosed();
}
if (
buffer != null &&
buffer.
position() != 0) {
writeBufferBlocking(false);
}
if (
channel == null) {
channel =
exchange.
getResponseChannel();
}
long
position =
source.
position();
long
size =
source.
size();
Channels.
transferBlocking(
channel,
source,
position,
size);
updateWritten(
size -
position);
}
/**
* {@inheritDoc}
*/
public void
close() throws
IOException {
if (
anyAreSet(
state,
FLAG_CLOSED)) return;
try {
state |=
FLAG_CLOSED;
if (
anyAreClear(
state,
FLAG_WRITE_STARTED) &&
channel == null) {
if (
buffer == null) {
exchange.
getResponseHeaders().
put(
Headers.
CONTENT_LENGTH, "0");
} else {
exchange.
getResponseHeaders().
put(
Headers.
CONTENT_LENGTH, "" +
buffer.
position());
}
}
if (
buffer != null) {
writeBufferBlocking(true);
}
if (
channel == null) {
channel =
exchange.
getResponseChannel();
}
if(
channel == null) {
return;
}
StreamSinkChannel channel = this.
channel;
channel.
shutdownWrites();
Channels.
flushBlocking(
channel);
} finally {
if (
pooledBuffer != null) {
pooledBuffer.
close();
buffer = null;
} else {
buffer = null;
}
}
}
private
ByteBuffer buffer() {
ByteBuffer buffer = this.
buffer;
if (
buffer != null) {
return
buffer;
}
this.
pooledBuffer =
exchange.
getConnection().
getByteBufferPool().
allocate();
this.
buffer =
pooledBuffer.
getBuffer();
return this.
buffer;
}
}