/*
* JBoss, Home of Professional Open Source.
* Copyright 2014 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.undertow.conduits;
import org.xnio.channels.
StreamSourceChannel;
import org.xnio.conduits.
AbstractStreamSinkConduit;
import org.xnio.conduits.
StreamSinkConduit;
import java.io.
IOException;
import java.io.
InterruptedIOException;
import java.nio.
ByteBuffer;
import java.nio.channels.
FileChannel;
import java.util.concurrent.
TimeUnit;
import io.undertow.util.
WorkerUtils;
/**
 * Class that implements the token bucket algorithm.
 * <p>
 * Allows send speed to be throttled: at most {@code bytes} bytes are handed to the next conduit per
 * time frame. Every write is truncated to the budget that is still left in the current frame; once
 * the budget is used up, writes return 0 until the frame has elapsed. If writes are resumed while
 * the conduit is throttled, writes on the next conduit are suspended and a task is scheduled on the
 * write thread to wake them up again when the frame ends.
 * <p>
 * Note that throttling is applied after an initial write, so if a big write is performed initially
 * it may be a while before it can write again.
 *
 * @author Stuart Douglas
 */
public class RateLimitingStreamSinkConduit extends AbstractStreamSinkConduit<StreamSinkConduit> {

    /** Length of one time frame, in milliseconds. */
    private final long time;
    /** Number of bytes that may be written per time frame. */
    private final int bytes;

    /** Whether the user of this conduit currently wants writes resumed. */
    private boolean writesResumed = false;
    /** Bytes written so far in the current time frame. */
    private int byteCount = 0;
    /** Start of the current time frame in epoch millis, or 0 if no frame is open. */
    private long startTime = 0;
    /** Time (epoch millis) at which the current time frame ends, or 0 if no frame is open. */
    private long nextSendTime = 0;
    /** True while a wake-up task is pending, so that only one is scheduled at a time. */
    private boolean scheduled = false;

    /**
     * @param next     The next conduit
     * @param bytes    The number of bytes that are allowed per time frame
     * @param time     The time frame
     * @param timeUnit The time unit
     */
    public RateLimitingStreamSinkConduit(StreamSinkConduit next, int bytes, long time, TimeUnit timeUnit) {
        super(next);
        writesResumed = next.isWriteResumed();
        this.time = timeUnit.toMillis(time);
        this.bytes = bytes;
    }

    /**
     * Writes as much of {@code src} as the remaining budget of the current time frame allows.
     *
     * @param src the buffer to write from; its limit is restored before returning
     * @return the number of bytes written, or 0 if the conduit is currently throttled
     * @throws IOException if the next conduit fails
     */
    @Override
    public int write(ByteBuffer src) throws IOException {
        return throttledWrite(src, false);
    }

    /**
     * Transfers bytes from a file, capped at the remaining budget of the current time frame.
     *
     * @param src      the file to read from
     * @param position the position in the file to start at
     * @param count    the maximum number of bytes requested
     * @return the number of bytes transferred, or 0 if the conduit is currently throttled
     * @throws IOException if the next conduit fails
     */
    @Override
    public long transferFrom(FileChannel src, long position, long count) throws IOException {
        if (!canSend()) {
            return 0;
        }
        int quota = this.bytes - this.byteCount;
        long written = super.transferFrom(src, position, Math.min(count, quota));
        handleWritten(written);
        return written;
    }

    /**
     * Transfers bytes from a stream source channel, capped at the remaining budget of the current
     * time frame.
     *
     * @param source        the channel to read from
     * @param count         the maximum number of bytes requested
     * @param throughBuffer the buffer handed to the next conduit for the transfer
     * @return whatever the next conduit reports, or 0 if the conduit is currently throttled
     * @throws IOException if the next conduit fails
     */
    @Override
    public long transferFrom(StreamSourceChannel source, long count, ByteBuffer throughBuffer) throws IOException {
        if (!canSend()) {
            return 0;
        }
        int quota = this.bytes - this.byteCount;
        long written = super.transferFrom(source, Math.min(count, quota), throughBuffer);
        handleWritten(written);
        return written;
    }

    /**
     * Gathering write, truncated to the remaining budget of the current time frame.
     *
     * @param srcs the buffers to write from; any limit that is adjusted is restored before returning
     * @param offs index of the first buffer to use
     * @param len  number of buffers to use
     * @return the number of bytes written, or 0 if the conduit is currently throttled
     * @throws IOException if the next conduit fails
     */
    @Override
    public long write(ByteBuffer[] srcs, int offs, int len) throws IOException {
        return throttledWrite(srcs, offs, len, false);
    }

    /**
     * Same as {@link #write(ByteBuffer)}, but delegates to {@code writeFinal} on the next conduit.
     *
     * @param src the buffer to write from; its limit is restored before returning
     * @return the number of bytes written, or 0 if the conduit is currently throttled
     * @throws IOException if the next conduit fails
     */
    @Override
    public int writeFinal(ByteBuffer src) throws IOException {
        return throttledWrite(src, true);
    }

    /**
     * Same as {@link #write(ByteBuffer[], int, int)}, but delegates to {@code writeFinal} on the
     * next conduit.
     *
     * @param srcs the buffers to write from; any limit that is adjusted is restored before returning
     * @param offs index of the first buffer to use
     * @param len  number of buffers to use
     * @return the number of bytes written, or 0 if the conduit is currently throttled
     * @throws IOException if the next conduit fails
     */
    @Override
    public long writeFinal(ByteBuffer[] srcs, int offs, int len) throws IOException {
        return throttledWrite(srcs, offs, len, true);
    }

    /**
     * Records that the user wants writes and resumes the next conduit only if sending is
     * currently allowed; otherwise {@link #canSend()} has arranged a delayed wake-up.
     */
    @Override
    public void resumeWrites() {
        writesResumed = true;
        if (canSend()) {
            super.resumeWrites();
        }
    }

    /** Records that the user no longer wants writes and suspends the next conduit. */
    @Override
    public void suspendWrites() {
        writesResumed = false;
        super.suspendWrites();
    }

    /**
     * Records that the user wants writes and wakes up the next conduit only if sending is
     * currently allowed; otherwise {@link #canSend()} has arranged a delayed wake-up.
     */
    @Override
    public void wakeupWrites() {
        writesResumed = true;
        if (canSend()) {
            super.wakeupWrites();
        }
    }

    /**
     * @return the resumed state as requested by the user of this conduit, which may differ from the
     *         state of the next conduit while throttled
     */
    @Override
    public boolean isWriteResumed() {
        return writesResumed;
    }

    /**
     * Sleeps until the current time frame has ended (if it has not already), then waits for the
     * next conduit to become writable.
     *
     * @throws InterruptedIOException if the thread is interrupted while sleeping; the interrupt
     *                                flag is left set
     * @throws IOException            if the next conduit fails
     */
    @Override
    public void awaitWritable() throws IOException {
        long toGo = nextSendTime - System.currentTimeMillis();
        if (toGo > 0) {
            sleepForThrottle(toGo);
        }
        super.awaitWritable();
    }

    /**
     * Timed variant of {@link #awaitWritable()}. If the current time frame has not ended yet, this
     * sleeps for the smaller of the remaining frame time and the given timeout and then returns
     * without consulting the next conduit; otherwise it delegates to the next conduit.
     *
     * @param time     the maximum time to wait
     * @param timeUnit the unit of {@code time}
     * @throws InterruptedIOException if the thread is interrupted while sleeping; the interrupt
     *                                flag is left set
     * @throws IOException            if the next conduit fails
     */
    @Override
    public void awaitWritable(long time, TimeUnit timeUnit) throws IOException {
        long toGo = nextSendTime - System.currentTimeMillis();
        if (toGo > 0) {
            sleepForThrottle(Math.min(toGo, timeUnit.toMillis(time)));
            return;
        }
        super.awaitWritable(time, timeUnit);
    }

    /**
     * Sleeps for the given number of milliseconds, translating an interruption into an
     * {@link InterruptedIOException}.
     */
    private static void sleepForThrottle(long millis) throws InterruptedIOException {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            // Thread.sleep cleared the flag; restore it so code further up can still observe the interrupt
            Thread.currentThread().interrupt();
            throw new InterruptedIOException();
        }
    }

    /**
     * Shared implementation of the single-buffer {@code write} and {@code writeFinal}.
     */
    private int throttledWrite(ByteBuffer src, boolean writeFinal) throws IOException {
        if (!canSend()) {
            return 0;
        }
        int quota = this.bytes - this.byteCount;
        int oldLimit = src.limit();
        if (src.remaining() > quota) {
            // temporarily shrink the buffer so that the next conduit cannot exceed the budget
            src.limit(src.position() + quota);
        }
        try {
            int written = writeFinal ? super.writeFinal(src) : super.write(src);
            handleWritten(written);
            return written;
        } finally {
            src.limit(oldLimit);
        }
    }

    /**
     * Shared implementation of the gathering {@code write} and {@code writeFinal}.
     */
    private long throttledWrite(ByteBuffer[] srcs, int offs, int len, boolean writeFinal) throws IOException {
        if (!canSend()) {
            return 0;
        }
        int oldLimit = 0;
        int adjPos = -1; // index of the buffer whose limit was shrunk, or -1 if everything fits
        long rem = bytes - byteCount;
        for (int i = offs; i < offs + len; ++i) {
            ByteBuffer buf = srcs[i];
            rem -= buf.remaining();
            if (rem < 0) {
                // this buffer crosses the budget: cut off the excess (rem is negative here)
                adjPos = i;
                oldLimit = buf.limit();
                buf.limit((int) (oldLimit + rem));
                break;
            }
        }
        try {
            // buffers after the truncated one are not passed on at all
            int effectiveLen = adjPos == -1 ? len : adjPos - offs + 1;
            long written = writeFinal
                    ? super.writeFinal(srcs, offs, effectiveLen)
                    : super.write(srcs, offs, effectiveLen);
            handleWritten(written);
            return written;
        } finally {
            if (adjPos != -1) {
                srcs[adjPos].limit(oldLimit);
            }
        }
    }

    /**
     * Decides whether data may be sent right now. Resets the accounting when the time frame has
     * elapsed; when sending is not allowed and writes are resumed, arranges a delayed wake-up.
     *
     * @return {@code true} if there is budget left or a new time frame can be started
     */
    private boolean canSend() {
        if (byteCount < bytes) {
            return true;
        }
        if (System.currentTimeMillis() > nextSendTime) {
            // the frame is over: start from scratch, the next write opens a new frame
            byteCount = 0;
            startTime = 0;
            nextSendTime = 0;
            return true;
        }
        if (writesResumed) {
            handleWritesResumedWhenBlocked();
        }
        return false;
    }

    /**
     * Accounts for bytes that were passed to the next conduit and, once the budget is exhausted,
     * arranges a delayed wake-up if writes are resumed.
     *
     * @param written the result reported by the next conduit
     */
    private void handleWritten(long written) {
        // Nothing to account for. Negative values (an end-of-input indicator passed through from a
        // channel transfer) must not be added, as that would enlarge the remaining budget.
        if (written <= 0) {
            return;
        }
        byteCount += written;
        if (byteCount < bytes) {
            //we are still allowed to send
            if (startTime == 0) {
                // first write of a new frame; read the clock once so both values agree
                long now = System.currentTimeMillis();
                startTime = now;
                nextSendTime = now + time;
            }
        } else {
            //we have gone over, we need to wait till we are allowed to send again
            if (startTime == 0) {
                startTime = System.currentTimeMillis();
            }
            nextSendTime = startTime + time;
            if (writesResumed) {
                handleWritesResumedWhenBlocked();
            }
        }
    }

    /**
     * Suspends writes on the next conduit and schedules a task on the write thread that wakes them
     * up again at {@code nextSendTime}, provided writes are still resumed by then. Does nothing if
     * such a task is already pending.
     */
    private void handleWritesResumedWhenBlocked() {
        if (scheduled) {
            return;
        }
        scheduled = true;
        next.suspendWrites();
        long millis = nextSendTime - System.currentTimeMillis();
        WorkerUtils.executeAfter(getWriteThread(), new Runnable() {
            @Override
            public void run() {
                scheduled = false;
                if (writesResumed) {
                    next.wakeupWrites();
                }
            }
        }, millis, TimeUnit.MILLISECONDS);
    }
}