/*
* Copyright (c) 2016 Mockito contributors
* This program is made available under the terms of the MIT License.
*/
package org.mockito.internal.creation.bytebuddy;
import net.bytebuddy.asm.
Advice;
import net.bytebuddy.description.method.
MethodDescription;
import net.bytebuddy.description.type.
TypeDescription;
import net.bytebuddy.dynamic.scaffold.
MethodGraph;
import net.bytebuddy.implementation.bind.annotation.
Argument;
import net.bytebuddy.implementation.bind.annotation.
This;
import net.bytebuddy.implementation.bytecode.assign.
Assigner;
import org.mockito.exceptions.base.
MockitoException;
import org.mockito.internal.debugging.
LocationImpl;
import org.mockito.internal.exceptions.stacktrace.
ConditionalStackTraceFilter;
import org.mockito.internal.invocation.
RealMethod;
import org.mockito.internal.invocation.
SerializableMethod;
import org.mockito.internal.invocation.mockref.
MockReference;
import org.mockito.internal.invocation.mockref.
MockWeakReference;
import org.mockito.internal.util.concurrent.
WeakConcurrentMap;
import java.io.
IOException;
import java.io.
ObjectInputStream;
import java.io.
Serializable;
import java.lang.annotation.
Retention;
import java.lang.annotation.
RetentionPolicy;
import java.lang.ref.
SoftReference;
import java.lang.reflect.
InvocationTargetException;
import java.lang.reflect.
Method;
import java.lang.reflect.
Modifier;
import java.util.
ArrayList;
import java.util.
List;
import java.util.concurrent.
Callable;
public class
MockMethodAdvice extends
MockMethodDispatcher {
/** Interceptors keyed by the object they belong to; looked up per instance when a call is handled. */
final WeakConcurrentMap<Object, MockMethodInterceptor> interceptors;

/** Identifier given to the constructor; forwarded to serializable real-method calls. */
private final String identifier;

/** Thread-local marker of the instance whose real method is about to be invoked. */
private final SelfCallInfo selfCallInfo = new SelfCallInfo();

/** Compiler that produces the method graphs consulted by {@code isOverridden}. */
private final MethodGraph.Compiler compiler = MethodGraph.Compiler.Default.forJavaHierarchy();

/** Method graphs cached per runtime class and held through soft references. */
private final WeakConcurrentMap<Class<?>, SoftReference<MethodGraph>> graphs =
        new WeakConcurrentMap.WithInlinedExpunction<Class<?>, SoftReference<MethodGraph>>();
/**
 * Creates a dispatcher backed by the given interceptor map.
 *
 * @param interceptors map from an object to the interceptor that handles its calls
 * @param identifier identifier stored for later dispatcher lookups
 */
public MockMethodAdvice(
        WeakConcurrentMap<Object, MockMethodInterceptor> interceptors, String identifier) {
    this.identifier = identifier;
    this.interceptors = interceptors;
}
/**
 * Entry advice for instrumented methods.
 *
 * <p>Resolves the dispatcher for {@code identifier} and {@code mock} and hands the call over to
 * it only when a dispatcher exists, it reports the instance as mocked, and it does not report
 * the method as overridden. The checks run in exactly that order and stop at the first failure.
 *
 * @return {@code null} to let the original method body execute, or the callable produced by
 *     the dispatcher; since the advice is declared with
 *     {@code skipOn = Advice.OnNonDefaultValue.class}, a non-null value skips the original body
 */
@SuppressWarnings("unused")
@Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class)
private static Callable<?> enter(
        @Identifier String identifier,
        @Advice.This Object mock,
        @Advice.Origin Method origin,
        @Advice.AllArguments Object[] arguments)
        throws Throwable {
    MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, mock);
    if (dispatcher != null
            && dispatcher.isMocked(mock)
            && !dispatcher.isOverridden(mock, origin)) {
        return dispatcher.handle(mock, origin, arguments);
    }
    return null;
}
/**
 * Exit advice for instrumented methods.
 *
 * <p>When the entry advice supplied a callable, its result replaces the method's return value;
 * otherwise the return value is left as it is.
 *
 * @param returned the writable return value of the instrumented method
 * @param mocked the value produced by the entry advice, or {@code null}
 * @throws Throwable whatever {@code mocked.call()} throws
 */
@SuppressWarnings({"unused", "UnusedAssignment"})
@Advice.OnMethodExit
private static void exit(
        @Advice.Return(readOnly = false, typing = Assigner.Typing.DYNAMIC) Object returned,
        @Advice.Enter Callable<?> mocked)
        throws Throwable {
    if (mocked == null) {
        return;
    }
    returned = mocked.call();
}
/**
 * Removes a contiguous run of frames from the stack trace of {@code throwable}.
 *
 * <p>The last {@code current} frames of the trace are always kept. Starting just above them and
 * walking towards the top of the trace, the first frame whose class name equals the name of
 * {@code targetType} is located; that frame and every frame between it and the kept tail are
 * dropped. Frames above the located one are kept unchanged.
 *
 * @param throwable the throwable whose stack trace is rewritten in place
 * @param current number of frames at the bottom of the trace to preserve
 * @param targetType the class whose frame marks the upper end of the removed run
 * @return the same {@code throwable} instance; its trace is left untouched if any
 *     {@link RuntimeException} occurs while rewriting (for example no matching frame is found)
 */
static Throwable hideRecursiveCall(Throwable throwable, int current, Class<?> targetType) {
    try {
        StackTraceElement[] frames = throwable.getStackTrace();
        // Index of the first frame belonging to the preserved tail.
        int tailStart = frames.length - current;
        int top = tailStart;
        do {
            top--;
        } while (!frames[top].getClassName().equals(targetType.getName()));
        int removed = tailStart - top;
        StackTraceElement[] trimmed = new StackTraceElement[frames.length - removed];
        System.arraycopy(frames, 0, trimmed, 0, top);
        System.arraycopy(frames, tailStart, trimmed, top, current);
        throwable.setStackTrace(trimmed);
    } catch (RuntimeException ignored) {
        // This should not happen unless someone instrumented or manipulated exception stack
        // traces; in that case the throwable is handed back as it was.
    }
    return throwable;
}
/**
 * Routes an intercepted call on {@code instance} to the interceptor registered for it.
 *
 * @param instance the object the method was invoked on
 * @param origin the invoked method
 * @param arguments the invocation arguments
 * @return {@code null} when no interceptor is registered for {@code instance}; otherwise a
 *     callable that yields the value the interceptor has already produced
 * @throws Throwable anything thrown by the interceptor
 */
@Override
public Callable<?> handle(Object instance, Method origin, Object[] arguments) throws Throwable {
    MockMethodInterceptor interceptor = interceptors.get(instance);
    if (interceptor == null) {
        return null;
    }
    // Serializable instances get a real-method handle that stores the identifier instead of
    // a direct reference to this advice's self-call marker.
    RealMethod realMethod =
            instance instanceof Serializable
                    ? new SerializableRealMethodCall(identifier, origin, instance, arguments)
                    : new RealMethodCall(selfCallInfo, origin, instance, arguments);
    // The throwable must be created right here so that the captured trace starts at handle().
    Throwable callSite = new Throwable();
    callSite.setStackTrace(skipInlineMethodElement(callSite.getStackTrace()));
    Object result =
            interceptor.doIntercept(
                    instance, origin, arguments, realMethod, new LocationImpl(callSite));
    return new ReturnValueWrapper(result);
}
/**
 * Tells whether {@code instance} has an entry in the interceptor map.
 *
 * @param instance the object to check
 * @return {@code true} if the map contains {@code instance} as a key and {@code instance} is
 *     not the map's own {@code target}
 */
@Override
public boolean isMock(Object instance) {
    // 'interceptors.target' has to be excluded explicitly: checking whether the map itself is a
    // mock object would require reading from the map, which would recurse.
    if (instance == interceptors.target) {
        return false;
    }
    return interceptors.containsKey(instance);
}
/**
 * Tells whether a call on {@code instance} should be treated as a call on a mock.
 *
 * <p>The self-call check runs first; only if it passes is the interceptor map consulted.
 *
 * @param instance the object the call was made on
 * @return {@code false} if the self-call check rejects the instance or it is not a mock
 */
@Override
public boolean isMocked(Object instance) {
    if (!selfCallInfo.checkSuperCall(instance)) {
        return false;
    }
    return isMock(instance);
}
/**
 * Checks {@code origin} against the method graph of the runtime class of {@code instance}.
 *
 * <p>The graph is taken from the soft-reference cache when available; otherwise it is compiled
 * and stored in the cache.
 *
 * @param instance the object whose runtime class is inspected
 * @param origin the method that was invoked
 * @return {@code true} if the graph node located for the signature of {@code origin} is not
 *     resolved, or if the node's representative is not declared by the declaring class of
 *     {@code origin}
 */
@Override
public boolean isOverridden(Object instance, Method origin) {
    Class<?> type = instance.getClass();
    SoftReference<MethodGraph> cached = graphs.get(type);
    MethodGraph methodGraph = cached == null ? null : cached.get();
    if (methodGraph == null) {
        // Cache miss, or the soft reference has been cleared: compile and store again.
        methodGraph = compiler.compile(new TypeDescription.ForLoadedType(type));
        graphs.put(type, new SoftReference<MethodGraph>(methodGraph));
    }
    MethodGraph.Node node =
            methodGraph.locate(new MethodDescription.ForLoadedMethod(origin).asSignatureToken());
    if (!node.getSort().isResolved()) {
        return true;
    }
    return !node.getRepresentative()
            .asDefined()
            .getDeclaringType()
            .represents(origin.getDeclaringClass());
}
/**
 * Real-method handle that reflectively invokes the original method on the mocked instance.
 *
 * <p>The instance is held through a {@link MockWeakReference}; the self-call marker is the one
 * passed in by the creating advice.
 */
private static class RealMethodCall implements RealMethod {

    private final SelfCallInfo selfCallInfo;

    private final Method origin;

    private final MockWeakReference<Object> instanceRef;

    private final Object[] arguments;

    private RealMethodCall(
            SelfCallInfo selfCallInfo, Method origin, Object instance, Object[] arguments) {
        this.selfCallInfo = selfCallInfo;
        this.origin = origin;
        this.instanceRef = new MockWeakReference<Object>(instance);
        this.arguments = arguments;
    }

    @Override
    public boolean isInvokable() {
        return true;
    }

    /**
     * Invokes the original method on the referenced instance.
     *
     * <p>The instance is recorded in the thread-local self-call marker for the duration of the
     * call, and the previous marker value is restored afterwards. Restoring in {@code finally}
     * keeps a stale marker from surviving when the reflective call fails before the marker is
     * consumed, and matches what {@code SerializableRealMethodCall} does.
     *
     * @return the value returned by the original method
     * @throws Throwable whatever {@code tryInvoke} throws
     */
    @Override
    public Object invoke() throws Throwable {
        // A set public bit survives the bitwise AND only if both the declaring class and the
        // method are public; otherwise access has to be forced.
        if (!Modifier.isPublic(
                origin.getDeclaringClass().getModifiers() & origin.getModifiers())) {
            origin.setAccessible(true);
        }
        Object instance = instanceRef.get();
        Object previous = selfCallInfo.replace(instance);
        try {
            return tryInvoke(origin, instance, arguments);
        } finally {
            selfCallInfo.set(previous);
        }
    }
}
/**
 * Real-method handle used for instances that implement {@code java.io.Serializable}.
 *
 * <p>Instead of a direct reference to the advice's self-call marker it stores the dispatcher
 * identifier and a {@link SerializableMethod}, and resolves the dispatcher again on each
 * invocation.
 */
private static class SerializableRealMethodCall implements RealMethod {

    private final String identifier;

    private final SerializableMethod origin;

    private final MockReference<Object> instanceRef;

    private final Object[] arguments;

    private SerializableRealMethodCall(
            String identifier, Method origin, Object instance, Object[] arguments) {
        this.identifier = identifier;
        this.origin = new SerializableMethod(origin);
        this.instanceRef = new MockWeakReference<Object>(instance);
        this.arguments = arguments;
    }

    @Override
    public boolean isInvokable() {
        return true;
    }

    /**
     * Invokes the original method on the referenced instance.
     *
     * <p>The instance is recorded in the resolved advice's self-call marker while the call
     * runs, and the previous marker value is restored afterwards.
     *
     * @return the value returned by the original method
     * @throws MockitoException if the dispatcher resolved for the identifier is not a
     *     {@code MockMethodAdvice}
     * @throws Throwable whatever {@code tryInvoke} throws
     */
    @Override
    public Object invoke() throws Throwable {
        Method method = origin.getJavaMethod();
        // A set public bit survives the AND only if both class and method are public.
        int combinedModifiers = method.getDeclaringClass().getModifiers() & method.getModifiers();
        if (!Modifier.isPublic(combinedModifiers)) {
            method.setAccessible(true);
        }
        MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, instanceRef.get());
        if (!(dispatcher instanceof MockMethodAdvice)) {
            throw new MockitoException("Unexpected dispatcher for advice-based super call");
        }
        SelfCallInfo marker = ((MockMethodAdvice) dispatcher).selfCallInfo;
        Object previous = marker.replace(instanceRef.get());
        try {
            return tryInvoke(method, instanceRef.get(), arguments);
        } finally {
            marker.set(previous);
        }
    }
}
/**
 * Reflectively invokes {@code origin} and unwraps any exception thrown by the target.
 *
 * <p>When the target throws, the cause's stack trace is first trimmed by
 * {@code hideRecursiveCall}, using the current stack depth and the method's declaring class,
 * then passed through a {@link ConditionalStackTraceFilter}, and the cause itself is rethrown.
 *
 * @param origin the method to invoke
 * @param instance the receiver
 * @param arguments the invocation arguments
 * @return the value returned by the invoked method
 * @throws Throwable the cause of an {@link InvocationTargetException}, or any other exception
 *     raised by {@link Method#invoke(Object, Object...)}
 */
private static Object tryInvoke(Method origin, Object instance, Object[] arguments)
        throws Throwable {
    try {
        return origin.invoke(instance, arguments);
    } catch (InvocationTargetException exception) {
        Throwable cause = exception.getCause();
        ConditionalStackTraceFilter filter = new ConditionalStackTraceFilter();
        // The depth must be measured in this very frame.
        int currentDepth = new Throwable().getStackTrace().length;
        Throwable trimmed = hideRecursiveCall(cause, currentDepth, origin.getDeclaringClass());
        filter.filter(trimmed);
        throw cause;
    }
}
/**
 * Copies the given stack trace, dropping the element that directly follows each
 * {@code MockMethodAdvice#handle} element.
 *
 * <p>With inline mocking, mocks for concrete classes are not subclassed, so elements of the
 * stubbing methods are not filtered out. Therefore, if the method is inlined, its element is
 * skipped: the element after {@code handle} is assumed to be the inlined method.
 *
 * @param elements the original stack trace
 * @return a new array without the skipped elements
 */
private static StackTraceElement[] skipInlineMethodElement(StackTraceElement[] elements) {
    List<StackTraceElement> kept = new ArrayList<StackTraceElement>(elements.length);
    boolean dropNext = false;
    for (StackTraceElement element : elements) {
        if (dropNext) {
            // The element right after handle() is dropped without being inspected.
            dropNext = false;
            continue;
        }
        kept.add(element);
        dropNext =
                element.getClassName().equals(MockMethodAdvice.class.getName())
                        && element.getMethodName().equals("handle");
    }
    return kept.toArray(new StackTraceElement[kept.size()]);
}
/** A {@link Callable} that simply hands back the value it was constructed with. */
private static class ReturnValueWrapper implements Callable<Object> {

    private final Object value;

    private ReturnValueWrapper(Object returned) {
        value = returned;
    }

    /** @return the value passed to the constructor */
    @Override
    public Object call() {
        return value;
    }
}
/** Thread-local holder of the instance for which a real-method call has been announced. */
private static class SelfCallInfo extends ThreadLocal<Object> {

    /**
     * Stores {@code value} for the current thread.
     *
     * @return the value that was stored before this call
     */
    Object replace(Object value) {
        Object previous = get();
        set(value);
        return previous;
    }

    /**
     * Compares {@code value} by identity with the stored instance.
     *
     * @return {@code false} if they are the same reference, in which case the stored value is
     *     also reset to {@code null}; {@code true} otherwise
     */
    boolean checkSuperCall(Object value) {
        if (value != get()) {
            return true;
        }
        set(null);
        return false;
    }
}
/**
 * Runtime-retained marker annotation. Within this file it annotates the {@code String}
 * identifier parameter of the advice and delegation methods.
 */
@Retention(RetentionPolicy.RUNTIME)
@interface Identifier {}
/**
 * Advice pair that makes a mock's instrumented method return its identity hash code.
 *
 * <p>The entry advice reports whether the receiver is a mock. If so, the original body is
 * skipped (the advice is declared with {@code skipOn = Advice.OnNonDefaultValue.class}) and the
 * exit advice overwrites the return value.
 */
static class ForHashCode {

    /**
     * @return {@code true} if a dispatcher exists for {@code id} and {@code self} and it
     *     reports {@code self} as a mock
     */
    @SuppressWarnings("unused")
    @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class)
    private static boolean enter(@Identifier String id, @Advice.This Object self) {
        MockMethodDispatcher dispatcher = MockMethodDispatcher.get(id, self);
        return dispatcher != null && dispatcher.isMock(self);
    }

    /**
     * Replaces the return value with {@link System#identityHashCode(Object)} of {@code self}
     * when the entry advice returned {@code true}.
     *
     * <p>Named {@code exit} so it no longer overloads the entry advice; the advice is declared
     * through its annotation, so the method name carries no meaning of its own.
     */
    @SuppressWarnings({"unused", "UnusedAssignment"})
    @Advice.OnMethodExit
    private static void exit(
            @Advice.This Object self,
            @Advice.Return(readOnly = false) int hashCode,
            @Advice.Enter boolean skipped) {
        if (skipped) {
            hashCode = System.identityHashCode(self);
        }
    }
}
/**
 * Advice pair that makes a mock's instrumented method return a reference-equality result.
 *
 * <p>The entry advice reports whether the receiver is a mock. If so, the original body is
 * skipped (the advice is declared with {@code skipOn = Advice.OnNonDefaultValue.class}) and the
 * exit advice overwrites the return value.
 */
static class ForEquals {

    /**
     * @return {@code true} if a dispatcher exists for {@code identifier} and {@code self} and
     *     it reports {@code self} as a mock
     */
    @SuppressWarnings("unused")
    @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class)
    private static boolean enter(@Identifier String identifier, @Advice.This Object self) {
        MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, self);
        return dispatcher != null && dispatcher.isMock(self);
    }

    /**
     * Replaces the return value with {@code self == other} when the entry advice returned
     * {@code true}.
     *
     * <p>Named {@code exit} so it no longer overloads the entry advice; the advice is declared
     * through its annotation, so the method name carries no meaning of its own.
     */
    @SuppressWarnings({"unused", "UnusedAssignment"})
    @Advice.OnMethodExit
    private static void exit(
            @Advice.This Object self,
            @Advice.Argument(0) Object other,
            @Advice.Return(readOnly = false) boolean equals,
            @Advice.Enter boolean skipped) {
        if (skipped) {
            equals = self == other;
        }
    }
}
/** Holder of the method that restores a mock's interceptor registration after reading it. */
public static class ForReadObject {

    /**
     * Performs default deserialization and then re-registers the interceptor of {@code thiz}.
     *
     * <p>After {@code defaultReadObject()}, the dispatcher for {@code identifier} and
     * {@code thiz} is looked up; if one exists, the interceptor returned by
     * {@code thiz.getMockitoInterceptor()} is stored in its interceptor map under {@code thiz}.
     *
     * @param identifier identifier used for the dispatcher lookup
     * @param thiz the object being deserialized
     * @param objectInputStream the stream to read the default fields from
     * @throws IOException if thrown by {@code defaultReadObject()}
     * @throws ClassNotFoundException if thrown by {@code defaultReadObject()}
     */
    @SuppressWarnings("unused")
    public static void doReadObject(
            @Identifier String identifier,
            @This MockAccess thiz,
            @Argument(0) ObjectInputStream objectInputStream)
            throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, thiz);
        MockMethodAdvice mockMethodAdvice = (MockMethodAdvice) dispatcher;
        if (mockMethodAdvice == null) {
            return;
        }
        mockMethodAdvice.interceptors.put(thiz, thiz.getMockitoInterceptor());
    }
}
}