16
16
17
17
package com .google .cloud .spanner .spi .v1 ;
18
18
19
+ import static com .google .common .truth .Truth .assertThat ;
19
20
import static org .hamcrest .CoreMatchers .equalTo ;
20
21
import static org .hamcrest .CoreMatchers .is ;
21
22
import static org .hamcrest .MatcherAssert .assertThat ;
22
23
23
24
import com .google .api .core .ApiFunction ;
24
- import com .google .cloud .NoCredentials ;
25
+ import com .google .auth .oauth2 .AccessToken ;
26
+ import com .google .auth .oauth2 .OAuth2Credentials ;
25
27
import com .google .cloud .spanner .DatabaseAdminClient ;
26
28
import com .google .cloud .spanner .DatabaseClient ;
27
29
import com .google .cloud .spanner .DatabaseId ;
31
33
import com .google .cloud .spanner .ResultSet ;
32
34
import com .google .cloud .spanner .Spanner ;
33
35
import com .google .cloud .spanner .SpannerOptions ;
36
+ import com .google .cloud .spanner .SpannerOptions .CallCredentialsProvider ;
34
37
import com .google .cloud .spanner .Statement ;
35
38
import com .google .cloud .spanner .admin .database .v1 .MockDatabaseAdminImpl ;
36
39
import com .google .cloud .spanner .admin .instance .v1 .MockInstanceAdminImpl ;
40
+ import com .google .cloud .spanner .spi .v1 .SpannerRpc .Option ;
37
41
import com .google .common .base .Stopwatch ;
38
42
import com .google .protobuf .ListValue ;
39
43
import com .google .spanner .admin .database .v1 .Database ;
45
49
import com .google .spanner .v1 .StructType ;
46
50
import com .google .spanner .v1 .StructType .Field ;
47
51
import com .google .spanner .v1 .TypeCode ;
52
+ import io .grpc .CallCredentials ;
53
+ import io .grpc .Context ;
54
+ import io .grpc .Contexts ;
48
55
import io .grpc .ManagedChannelBuilder ;
56
+ import io .grpc .Metadata ;
57
+ import io .grpc .Metadata .Key ;
49
58
import io .grpc .Server ;
59
+ import io .grpc .ServerCall ;
60
+ import io .grpc .ServerCallHandler ;
61
+ import io .grpc .ServerInterceptor ;
62
+ import io .grpc .auth .MoreCallCredentials ;
50
63
import io .grpc .netty .shaded .io .grpc .netty .NettyServerBuilder ;
51
64
import java .io .IOException ;
52
65
import java .net .InetSocketAddress ;
53
66
import java .util .ArrayList ;
67
+ import java .util .HashMap ;
54
68
import java .util .List ;
69
+ import java .util .Map ;
55
70
import java .util .concurrent .TimeUnit ;
56
71
import java .util .regex .Pattern ;
57
72
import org .junit .After ;
@@ -91,11 +106,27 @@ public class GapicSpannerRpcTest {
91
106
.build ())
92
107
.setMetadata (SELECT1AND2_METADATA )
93
108
.build ();
109
+ private static final String STATIC_OAUTH_TOKEN = "STATIC_TEST_OAUTH_TOKEN" ;
110
+ private static final String VARIABLE_OAUTH_TOKEN = "VARIABLE_TEST_OAUTH_TOKEN" ;
111
+ private static final OAuth2Credentials STATIC_CREDENTIALS =
112
+ OAuth2Credentials .create (
113
+ new AccessToken (
114
+ STATIC_OAUTH_TOKEN ,
115
+ new java .util .Date (
116
+ System .currentTimeMillis () + TimeUnit .MILLISECONDS .convert (1L , TimeUnit .DAYS ))));
117
+ private static final OAuth2Credentials VARIABLE_CREDENTIALS =
118
+ OAuth2Credentials .create (
119
+ new AccessToken (
120
+ VARIABLE_OAUTH_TOKEN ,
121
+ new java .util .Date (
122
+ System .currentTimeMillis () + TimeUnit .MILLISECONDS .convert (1L , TimeUnit .DAYS ))));
123
+
94
124
private MockSpannerServiceImpl mockSpanner ;
95
125
private MockInstanceAdminImpl mockInstanceAdmin ;
96
126
private MockDatabaseAdminImpl mockDatabaseAdmin ;
97
127
private Server server ;
98
128
private InetSocketAddress address ;
129
+ private final Map <SpannerRpc .Option , Object > optionsMap = new HashMap <>();
99
130
100
131
@ Before
101
132
public void startServer () throws IOException {
@@ -111,8 +142,24 @@ public void startServer() throws IOException {
111
142
.addService (mockSpanner )
112
143
.addService (mockInstanceAdmin )
113
144
.addService (mockDatabaseAdmin )
145
+ // Add a server interceptor that will check that we receive the variable OAuth token
146
+ // from the CallCredentials, and not the one set as static credentials.
147
+ .intercept (
148
+ new ServerInterceptor () {
149
+ @ Override
150
+ public <ReqT , RespT > ServerCall .Listener <ReqT > interceptCall (
151
+ ServerCall <ReqT , RespT > call ,
152
+ Metadata headers ,
153
+ ServerCallHandler <ReqT , RespT > next ) {
154
+ String auth =
155
+ headers .get (Key .of ("authorization" , Metadata .ASCII_STRING_MARSHALLER ));
156
+ assertThat (auth ).isEqualTo ("Bearer " + VARIABLE_OAUTH_TOKEN );
157
+ return Contexts .interceptCall (Context .current (), call , headers , next );
158
+ }
159
+ })
114
160
.build ()
115
161
.start ();
162
+ optionsMap .put (Option .CHANNEL_HINT , Long .valueOf (1L ));
116
163
}
117
164
118
165
@ After
@@ -229,6 +276,55 @@ && getNumberOfThreadsWithName(SPANNER_THREAD_NAME, false)
229
276
assertThat (getNumberOfThreadsWithName (SPANNER_THREAD_NAME , true ), is (equalTo (0 )));
230
277
}
231
278
279
+ @ Test
280
+ public void testCallCredentialsProviderPreferenceAboveCredentials () {
281
+ SpannerOptions options =
282
+ SpannerOptions .newBuilder ()
283
+ .setCredentials (STATIC_CREDENTIALS )
284
+ .setCallCredentialsProvider (
285
+ new CallCredentialsProvider () {
286
+ @ Override
287
+ public CallCredentials getCallCredentials () {
288
+ return MoreCallCredentials .from (VARIABLE_CREDENTIALS );
289
+ }
290
+ })
291
+ .build ();
292
+ GapicSpannerRpc rpc = new GapicSpannerRpc (options );
293
+ // GoogleAuthLibraryCallCredentials doesn't implement equals, so we can only check for the
294
+ // existence.
295
+ assertThat (rpc .newCallContext (optionsMap , "/some/resource" ).getCallOptions ().getCredentials ())
296
+ .isNotNull ();
297
+ rpc .shutdown ();
298
+ }
299
+
300
+ @ Test
301
+ public void testCallCredentialsProviderReturnsNull () {
302
+ SpannerOptions options =
303
+ SpannerOptions .newBuilder ()
304
+ .setCredentials (STATIC_CREDENTIALS )
305
+ .setCallCredentialsProvider (
306
+ new CallCredentialsProvider () {
307
+ @ Override
308
+ public CallCredentials getCallCredentials () {
309
+ return null ;
310
+ }
311
+ })
312
+ .build ();
313
+ GapicSpannerRpc rpc = new GapicSpannerRpc (options );
314
+ assertThat (rpc .newCallContext (optionsMap , "/some/resource" ).getCallOptions ().getCredentials ())
315
+ .isNull ();
316
+ rpc .shutdown ();
317
+ }
318
+
319
+ @ Test
320
+ public void testNoCallCredentials () {
321
+ SpannerOptions options = SpannerOptions .newBuilder ().setCredentials (STATIC_CREDENTIALS ).build ();
322
+ GapicSpannerRpc rpc = new GapicSpannerRpc (options );
323
+ assertThat (rpc .newCallContext (optionsMap , "/some/resource" ).getCallOptions ().getCredentials ())
324
+ .isNull ();
325
+ rpc .shutdown ();
326
+ }
327
+
232
328
@ SuppressWarnings ("rawtypes" )
233
329
private SpannerOptions createSpannerOptions () {
234
330
String endpoint = address .getHostString () + ":" + server .getPort ();
@@ -244,7 +340,17 @@ public ManagedChannelBuilder apply(ManagedChannelBuilder input) {
244
340
}
245
341
})
246
342
.setHost ("http://" + endpoint )
247
- .setCredentials (NoCredentials .getInstance ())
343
+ // Set static credentials that will return the static OAuth test token.
344
+ .setCredentials (STATIC_CREDENTIALS )
345
+ // Also set a CallCredentialsProvider. These credentials should take precedence above
346
+ // the static credentials.
347
+ .setCallCredentialsProvider (
348
+ new CallCredentialsProvider () {
349
+ @ Override
350
+ public CallCredentials getCallCredentials () {
351
+ return MoreCallCredentials .from (VARIABLE_CREDENTIALS );
352
+ }
353
+ })
248
354
.build ();
249
355
}
250
356
0 commit comments