Skip to content
This repository was archived by the owner on Sep 26, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Respond to comments
  • Loading branch information
tonytanger committed Nov 19, 2019
commit 06a4db2f01f4a35865f9856882c2a3c2304fe742
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@
import io.grpc.ManagedChannel;
import java.io.IOException;

/** This interface represents a factory for creating one ManagedChannel */
/**
* This interface represents a factory for creating one ManagedChannel
*
* <p>This is public only for technical reasons, for advanced usage.
*/
@InternalApi("For internal use by google-cloud-java clients only")
public interface ChannelFactory {
Comment thread
igorbernstein2 marked this conversation as resolved.
ManagedChannel createSingleChannel() throws IOException;
Expand Down
18 changes: 9 additions & 9 deletions gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;

/**
* A {@link ManagedChannel} that will send requests round robin via a set of channels.
Expand All @@ -53,39 +54,38 @@ class ChannelPool extends ManagedChannel {
/**
* Factory method to create a non-refreshing channel pool
*
* @see ChannelPool#ChannelPool(int, ChannelFactory, ScheduledExecutorService, boolean)
* @see ChannelPool#ChannelPool(int, ChannelFactory, boolean, ScheduledExecutorService)
*/
static ChannelPool create(int poolSize, final ChannelFactory channelFactory) throws IOException {
return new ChannelPool(poolSize, channelFactory, null, false);
return new ChannelPool(poolSize, channelFactory, false, null);
}

/**
* Factory method to create a refreshing channel pool
*
* @see ChannelPool#ChannelPool(int, ChannelFactory, ScheduledExecutorService, boolean)
* @see ChannelPool#ChannelPool(int, ChannelFactory, boolean, ScheduledExecutorService)
*/
static ChannelPool createRefreshing(
int poolSize, final ChannelFactory channelFactory, ScheduledExecutorService executorService)
throws IOException {
return new ChannelPool(poolSize, channelFactory, executorService, true);
return new ChannelPool(poolSize, channelFactory, true, executorService);
}

/**
* Initializes the channel pool. Assumes that all channels have the same authority.
*
* @param poolSize number of channels in the pool
* @param channelFactory method to create the channels
* @param executorService if set, schedule periodically refresh the channels
* @param isRefreshing if true, channels will refresh themselves periodically
* @param executorService periodically refreshes the channels
*/
private ChannelPool(
int poolSize,
final ChannelFactory channelFactory,
ScheduledExecutorService executorService,
boolean isRefreshing)
boolean isRefreshing,
@Nullable ScheduledExecutorService executorService)
throws IOException {
channels = new ArrayList<>(poolSize);
// if executorService is available, create RefreshingManagedChannel that will get refreshed.
// otherwise create with regular ManagedChannel
if (isRefreshing) {
Comment thread
tonytanger marked this conversation as resolved.
Outdated
for (int i = 0; i < poolSize; i++) {
channels.add(new RefreshingManagedChannel(channelFactory, executorService));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
import com.google.api.core.InternalApi;
import io.grpc.ManagedChannel;

/** An interface to prepare a ManagedChannel for normal requests by priming the channel */
/**
* An interface to prepare a ManagedChannel for normal requests by priming the channel
*
* <p>This is public only for technical reasons, for advanced usage.
*/
@InternalApi("For internal use by google-cloud-java clients only")
Comment thread
igorbernstein2 marked this conversation as resolved.
public interface ChannelPrimer {
Comment thread
igorbernstein2 marked this conversation as resolved.
void primeChannel(ManagedChannel managedChannel);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.auth.Credentials;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -193,28 +194,22 @@ public TransportChannel getTransportChannel() throws IOException {
}

private TransportChannel createChannel() throws IOException {
ManagedChannel outerChannel;

int realPoolSize;
if (poolSize == null) {
realPoolSize = 1;
} else {
realPoolSize = poolSize;
}
int realPoolSize = MoreObjects.firstNonNull(poolSize, 1);
ChannelFactory channelFactory =
new ChannelFactory() {
public ManagedChannel createSingleChannel() throws IOException {
return InstantiatingGrpcChannelProvider.this.createSingleChannel();
Comment thread
igorbernstein2 marked this conversation as resolved.
}
};
ManagedChannel outerChannel;
if (channelPrimer != null) {
outerChannel =
ChannelPool.createRefreshing(
realPoolSize, channelFactory, executorProvider.getExecutor());
} else {
outerChannel = ChannelPool.create(realPoolSize, channelFactory);
}

return GrpcTransportChannel.create(outerChannel);
}

Expand Down Expand Up @@ -529,10 +524,9 @@ public int getPoolSize() {
*/
public Builder setPoolSize(int poolSize) {
Preconditions.checkArgument(poolSize > 0, "Pool size must be positive");
Preconditions.checkArgument(
poolSize <= MAX_POOL_SIZE, "Pool size must be less than %d", MAX_POOL_SIZE);
this.poolSize = poolSize;
if (this.poolSize > MAX_POOL_SIZE) {
this.poolSize = MAX_POOL_SIZE;
}
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@
class RefreshingManagedChannel extends ManagedChannel {
Comment thread
igorbernstein2 marked this conversation as resolved.
private static final Logger LOG = Logger.getLogger(RefreshingManagedChannel.class.getName());
// refresh every 50 minutes with 15% jitter for a range of 42.5min to 57.5min
private static Duration refreshPeriod = Duration.ofMinutes(50);
private static double jitterPercentage = 0.15;
private static final Duration refreshPeriod = Duration.ofMinutes(50);
private static final double jitterPercentage = 0.15;
private volatile SafeShutdownManagedChannel delegate;
private final ReadWriteLock lock;
private final ChannelFactory channelFactory;
private final ScheduledExecutorService scheduledExecutorService;
private ScheduledFuture<?> nextScheduledRefresh;
private volatile ScheduledFuture<?> nextScheduledRefresh;

RefreshingManagedChannel(
ChannelFactory channelFactory, ScheduledExecutorService scheduledExecutorService)
Expand Down Expand Up @@ -96,7 +96,8 @@ private void refreshChannel() {
// When shutdown completes and releases the read lock and this thread acquires the write lock.
// This thread should not continue because the channel has shutdown. This check ensures that
// this thread terminates without swapping the channel and do not schedule the next refresh.
if (Thread.interrupted()) {
if (Thread.currentThread().isInterrupted()) {
newChannel.shutdownNow();
return;
Comment thread
tonytanger marked this conversation as resolved.
}
delegate = newChannel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@
* <p>Package-private for internal use.
*/
class SafeShutdownManagedChannel extends ManagedChannel {
private ManagedChannel delegate;
private AtomicInteger outstandingCalls = new AtomicInteger(0);
private final ManagedChannel delegate;
private final AtomicInteger outstandingCalls = new AtomicInteger(0);
private volatile boolean isShutdownSafely = false;

SafeShutdownManagedChannel(ManagedChannel managedChannel) {
Expand Down Expand Up @@ -106,6 +106,16 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE
return delegate.awaitTermination(timeout, unit);
}

/**
* Decrement outstanding call counter and shutdown if there are no more outstanding calls and
* {@link SafeShutdownManagedChannel#shutdownSafely()} has been invoked
*/
private void onClientCallClose() {
if (outstandingCalls.decrementAndGet() == 0 && isShutdownSafely) {
shutdownSafely();
}
}

/** Listener that's responsible for decrementing outstandingCalls when the call closes */
private class DecrementOutstandingCalls<RespT> extends SimpleForwardingClientCallListener<RespT> {
DecrementOutstandingCalls(Listener<RespT> delegate) {
Expand All @@ -118,26 +128,21 @@ public void onClose(Status status, Metadata trailers) {
try {
super.onClose(status, trailers);
} finally {
if (outstandingCalls.decrementAndGet() == 0 && isShutdownSafely) {
shutdownSafely();
}
onClientCallClose();
}
}
}

/**
* Wrap client call to decrement outstandingCalls and shutdown channel if necessary when it
* completes
*
* @param call call to be wrapped
* @return the wrapped call
*/
private <ReqT, RespT> ClientCall<ReqT, RespT> clientCallWrapper(ClientCall<ReqT, RespT> call) {
return new SimpleForwardingClientCall<ReqT, RespT>(call) {
public void start(Listener<RespT> responseListener, Metadata headers) {
super.start(new DecrementOutstandingCalls<>(responseListener), headers);
}
};
/** To wrap around delegate to hook in {@link DecrementOutstandingCalls} */
private class ClientCallProxy<ReqT, RespT> extends SimpleForwardingClientCall<ReqT, RespT> {
ClientCallProxy(ClientCall<ReqT, RespT> delegate) {
super(delegate);
}

@Override
public void start(Listener<RespT> responseListener, Metadata headers) {
super.start(new DecrementOutstandingCalls<>(responseListener), headers);
}
}

/**
Expand All @@ -149,10 +154,10 @@ public void start(Listener<RespT> responseListener, Metadata headers) {
@Override
public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
// increment after client call in case newCall throws an exception
Preconditions.checkState(!isShutdownSafely);
// increment after client call in case newCall throws an exception
ClientCall<RequestT, ResponseT> clientCall =
clientCallWrapper(delegate.newCall(methodDescriptor, callOptions));
new ClientCallProxy<>(delegate.newCall(methodDescriptor, callOptions));
outstandingCalls.incrementAndGet();
return clientCall;
}
Expand Down
35 changes: 19 additions & 16 deletions gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public void testRoundRobin() throws IOException {
Mockito.when(sub1.authority()).thenReturn("myAuth");

ArrayList<ManagedChannel> channels = Lists.newArrayList(sub1, sub2);
ChannelPool pool = ChannelPool.create(2, new FakeChannelFactory(Arrays.asList(sub1, sub2)));
ChannelPool pool = ChannelPool.create(channels.size(), new FakeChannelFactory(channels));

verifyTargetChannel(pool, channels, sub1);
verifyTargetChannel(pool, channels, sub2);
Expand Down Expand Up @@ -167,7 +167,7 @@ public void run() {

// Test channelPrimer is called same number of times as poolSize if executorService is set to null
@Test
public void channelPrimerShouldBeCalledOnce() throws IOException {
public void channelPrimerShouldCallPoolConstruction() throws IOException {
ChannelPrimer mockChannelPrimer = Mockito.mock(ChannelPrimer.class);
ManagedChannel channel1 = Mockito.mock(ManagedChannel.class);
ManagedChannel channel2 = Mockito.mock(ManagedChannel.class);
Expand All @@ -190,21 +190,24 @@ public void channelPrimerIsCalledPeriodically() throws IOException, InterruptedE

ScheduledExecutorService scheduledExecutorService =
Mockito.mock(ScheduledExecutorService.class);
Mockito.when(
scheduledExecutorService.schedule(
Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)))
.thenAnswer(
new Answer() {
public Object answer(InvocationOnMock invocation) {
channelRefreshers.add((Runnable) invocation.getArgument(0));
return Mockito.mock(ScheduledFuture.class);
}
});

ChannelPool.createRefreshing(
1,
new FakeChannelFactory(Arrays.asList(channel1, channel2, channel3), mockChannelPrimer),
scheduledExecutorService);
Answer extractChannelRefresher =
new Answer() {
public Object answer(InvocationOnMock invocation) {
channelRefreshers.add((Runnable) invocation.getArgument(0));
return Mockito.mock(ScheduledFuture.class);
}
};

Mockito.doAnswer(extractChannelRefresher)
.when(scheduledExecutorService)
.schedule(
Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS));

FakeChannelFactory channelFactory =
new FakeChannelFactory(Arrays.asList(channel1, channel2, channel3), mockChannelPrimer);

ChannelPool.createRefreshing(1, channelFactory, scheduledExecutorService);
// 1 call during the creation
Mockito.verify(mockChannelPrimer, Mockito.times(1))
.primeChannel(Mockito.any(ManagedChannel.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,16 +377,19 @@ public void testWithPrimeChannel() throws IOException {

ScheduledExecutorService scheduledExecutorService =
Mockito.mock(ScheduledExecutorService.class);
Mockito.when(
scheduledExecutorService.schedule(
Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)))
.thenAnswer(
new Answer() {
public Object answer(InvocationOnMock invocation) {
channelRefreshers.add((Runnable) invocation.getArgument(0));
return Mockito.mock(ScheduledFuture.class);
}
});

Answer extractChannelRefresher =
new Answer() {
public Object answer(InvocationOnMock invocation) {
channelRefreshers.add((Runnable) invocation.getArgument(0));
return Mockito.mock(ScheduledFuture.class);
}
};

Mockito.doAnswer(extractChannelRefresher)
.when(scheduledExecutorService)
.schedule(
Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS));

InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
Expand Down
Loading