sift_py.grpc.transport_test
# ruff: noqa: N802

import re
from concurrent import futures
from contextlib import contextmanager
from typing import Any, Callable, Iterator, cast

import grpc
import pytest
from pytest_mock import MockFixture, MockType
from sift.data.v2.data_pb2 import GetDataRequest, GetDataResponse
from sift.data.v2.data_pb2_grpc import (
    DataServiceServicer,
    DataServiceStub,
    add_DataServiceServicer_to_server,
)

from sift_py._internal.test_util.server_interceptor import ServerInterceptor
from sift_py.grpc.transport import SiftChannelConfig, use_sift_channel


class DataService(DataServiceServicer):
    """Test double whose `GetData` always replies with a fixed `next_page_token`."""

    def GetData(self, request: GetDataRequest, context: grpc.ServicerContext):
        return GetDataResponse(next_page_token="next-page-token")


class AuthInterceptor(ServerInterceptor):
    """
    Let an RPC through only if its metadata carries `authorization: Bearer <token>`
    with a non-empty token; otherwise fail it with UNAUTHENTICATED.
    """

    AUTH_REGEX = re.compile(r"^Bearer (.+)$")

    def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        authenticated = False
        for metadata in context.invocation_metadata():
            if metadata.key == "authorization":
                auth = self.__class__.AUTH_REGEX.match(metadata.value)

                if auth is not None and len(auth.group(1)) > 0:
                    authenticated = True

                # Only the first `authorization` entry is considered.
                break

        if not authenticated:
            # `abort` records the status/details on the context and raises to end the
            # RPC. A bare `raise` here (no active exception) only worked by accident,
            # via a RuntimeError("No active exception to reraise").
            context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid or missing API key")

        return method(request_or_iterator, context)


class ForceFailInterceptor(ServerInterceptor):
    """
    Force an RPC to fail a fixed number of times before letting it pass.

    `failed_attempts`: How many times the RPC has been failed so far
    `expected_num_fails`: How many times the RPC should fail before it is let through
    `failure_code`: Status code used for every forced failure
    """

    failed_attempts: int
    expected_num_fails: int
    failure_code: grpc.StatusCode

    def __init__(
        self, expected_num_fails: int, failure_code: grpc.StatusCode = grpc.StatusCode.UNKNOWN
    ):
        self.expected_num_fails = expected_num_fails
        self.failed_attempts = 0
        self.failure_code = failure_code
        super().__init__()

    def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        if self.failed_attempts < self.expected_num_fails:
            self.failed_attempts += 1
            # `abort` sets the status and raises; a bare `raise` with no active
            # exception would instead raise an unrelated RuntimeError.
            context.abort(self.failure_code, "something unknown happened")

        return method(request_or_iterator, context)


@contextmanager
def server_spy(mocker: MockFixture, *interceptors: ServerInterceptor) -> Iterator[MockType]:
    """
    Serve a `DataService` on port 50052 with `interceptors` installed, yielding a spy
    on its `GetData`. The server is stopped (`grace=None`) and awaited on exit.
    """
    server = grpc.server(
        thread_pool=futures.ThreadPoolExecutor(max_workers=1), interceptors=list(interceptors)
    )

    data_service = DataService()
    spy = mocker.spy(data_service, "GetData")

    add_DataServiceServicer_to_server(data_service, server)
    server.add_insecure_port("[::]:50052")
    server.start()
    try:
        yield spy
    finally:
        server.stop(None)
        server.wait_for_termination()


def test_sift_channel(mocker: MockFixture):
    """Check API-key auth, retries of failed calls, and giving up after max attempts."""
    with server_spy(mocker, AuthInterceptor()) as get_data_spy:
        sift_channel_config_a: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "",
            "use_ssl": False,
        }

        # An empty API key must be rejected before the servicer is reached.
        with use_sift_channel(sift_channel_config_a) as channel:
            stub = DataServiceStub(channel)
            with pytest.raises(grpc.RpcError, match="UNAUTHENTICATED"):
                stub.GetData(GetDataRequest())

        get_data_spy.assert_not_called()

        sift_channel_config_b: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_b) as channel:
            stub = DataServiceStub(channel)
            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
            assert res.next_page_token == "next-page-token"
            get_data_spy.assert_called_once()

    force_fail_interceptor = ForceFailInterceptor(4)
    with server_spy(mocker, AuthInterceptor(), force_fail_interceptor) as get_data_spy:
        sift_channel_config_c: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_c) as channel:
            stub = DataServiceStub(channel)
            # This will attempt 5 times: fail 4 times, succeed on 5th
            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
            assert res.next_page_token == "next-page-token"
            get_data_spy.assert_called_once()

        # fail 4 times, pass the 5th attempt
        assert force_fail_interceptor.failed_attempts == 4

    # Now we're going to fail beyond the max retry attempts
    force_fail_interceptor_max = ForceFailInterceptor(7)
    with server_spy(mocker, AuthInterceptor(), force_fail_interceptor_max) as get_data_spy:
        sift_channel_config_d: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_d) as channel:
            stub = DataServiceStub(channel)

            # This will go beyond the max number of attempts
            with pytest.raises(grpc.RpcError):
                stub.GetData(GetDataRequest())

            get_data_spy.assert_not_called()

        # All attempts failed
        assert force_fail_interceptor_max.failed_attempts == 5


def test_internal_error_retry(mocker: MockFixture):
    """Check that INTERNAL failures are retried and that retries stop at max attempts."""
    force_fail_interceptor = ForceFailInterceptor(4, failure_code=grpc.StatusCode.INTERNAL)
    with server_spy(mocker, AuthInterceptor(), force_fail_interceptor) as get_data_spy:
        sift_channel_config_c: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_c) as channel:
            stub = DataServiceStub(channel)
            # This will attempt 5 times: fail 4 times, succeed on 5th
            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
            assert res.next_page_token == "next-page-token"
            get_data_spy.assert_called_once()

        # fail 4 times, pass the 5th attempt
        assert force_fail_interceptor.failed_attempts == 4

    # Now we're going to fail beyond the max retry attempts, still with INTERNAL so this
    # half of the test exercises the same status code as its name promises.
    force_fail_interceptor_max = ForceFailInterceptor(7, failure_code=grpc.StatusCode.INTERNAL)
    with server_spy(mocker, AuthInterceptor(), force_fail_interceptor_max) as get_data_spy:
        sift_channel_config_d: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_d) as channel:
            stub = DataServiceStub(channel)

            # This will go beyond the max number of attempts
            with pytest.raises(grpc.RpcError):
                stub.GetData(GetDataRequest())

            get_data_spy.assert_not_called()

        # All attempts failed
        assert force_fail_interceptor_max.failed_attempts == 5
class
DataService(sift.data.v2.data_pb2_grpc.DataServiceServicer):
class DataService(DataServiceServicer):
    """In-process DataService used by the transport tests."""

    def GetData(self, request: GetDataRequest, context: grpc.ServicerContext):
        """Ignore the request and reply with a fixed `next_page_token`."""
        response = GetDataResponse(next_page_token="next-page-token")
        return response
Test double for the DataService gRPC servicer: `GetData` ignores its request and always returns a response whose `next_page_token` is "next-page-token".
class AuthInterceptor(ServerInterceptor):
    """
    Let an RPC through only if its metadata carries `authorization: Bearer <token>`
    with a non-empty token; otherwise fail it with UNAUTHENTICATED.
    """

    AUTH_REGEX = re.compile(r"^Bearer (.+)$")

    def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        authenticated = False
        for metadata in context.invocation_metadata():
            if metadata.key == "authorization":
                auth = self.__class__.AUTH_REGEX.match(metadata.value)

                if auth is not None and len(auth.group(1)) > 0:
                    authenticated = True

                # Only the first `authorization` entry is considered.
                break

        if not authenticated:
            # `abort` records the status/details on the context and raises to end the
            # RPC. A bare `raise` with no active exception only worked by accident.
            context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid or missing API key")

        return method(request_or_iterator, context)
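Server interceptor that lets an RPC through only when its metadata carries an `authorization` entry of the form `Bearer <non-empty token>`; otherwise the RPC fails with `UNAUTHENTICATED` ("Invalid or missing API key").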
def
intercept( self, method: Callable, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str) -> Any:
def intercept(
    self,
    method: Callable,
    request_or_iterator: Any,
    context: grpc.ServicerContext,
    method_name: str,
) -> Any:
    """
    Invoke `method` only if the first `authorization` metadata entry matches
    `Bearer <non-empty token>`; otherwise abort with UNAUTHENTICATED.
    """
    authenticated = False
    for metadata in context.invocation_metadata():
        if metadata.key == "authorization":
            auth = self.__class__.AUTH_REGEX.match(metadata.value)

            if auth is not None and len(auth.group(1)) > 0:
                authenticated = True

            # Only the first `authorization` entry is considered.
            break

    if not authenticated:
        # `abort` sets the status and raises; a bare `raise` with no active exception
        # would raise an unrelated RuntimeError instead.
        context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid or missing API key")

    return method(request_or_iterator, context)
class ForceFailInterceptor(ServerInterceptor):
    """
    Force an RPC to fail a fixed number of times before letting it pass.

    `failed_attempts`: How many times the RPC has been failed so far
    `expected_num_fails`: How many times the RPC should fail before it is let through
    `failure_code`: Status code used for every forced failure
    """

    failed_attempts: int
    expected_num_fails: int
    failure_code: grpc.StatusCode

    def __init__(
        self, expected_num_fails: int, failure_code: grpc.StatusCode = grpc.StatusCode.UNKNOWN
    ):
        self.expected_num_fails = expected_num_fails
        self.failed_attempts = 0
        self.failure_code = failure_code
        super().__init__()

    def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        if self.failed_attempts < self.expected_num_fails:
            self.failed_attempts += 1
            # `abort` sets the status and raises; a bare `raise` with no active
            # exception would instead raise an unrelated RuntimeError.
            context.abort(self.failure_code, "something unknown happened")

        return method(request_or_iterator, context)
Force RPC to fail a few times before letting it pass.
failed_attempts
: Number of times the interceptor has failed the RPC so far
expected_num_fails
: Number of times the RPC should fail before it is allowed through to the servicer
ForceFailInterceptor( expected_num_fails: int, failure_code: grpc.StatusCode = <StatusCode.UNKNOWN: (2, 'unknown')>)
def
intercept( self, method: Callable, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str) -> Any:
def intercept(
    self,
    method: Callable,
    request_or_iterator: Any,
    context: grpc.ServicerContext,
    method_name: str,
) -> Any:
    """
    Fail the RPC with `self.failure_code` until `expected_num_fails` failures have
    been produced, then invoke `method` normally.
    """
    if self.failed_attempts < self.expected_num_fails:
        self.failed_attempts += 1
        # `abort` sets the status and raises; a bare `raise` with no active exception
        # would instead raise an unrelated RuntimeError.
        context.abort(self.failure_code, "something unknown happened")

    return method(request_or_iterator, context)
@contextmanager
def
server_spy( mocker: pytest_mock.plugin.MockerFixture, *interceptors: sift_py._internal.test_util.server_interceptor.ServerInterceptor) -> Iterator[Union[unittest.mock.MagicMock, unittest.mock.AsyncMock, unittest.mock.NonCallableMagicMock]]:
@contextmanager
def server_spy(mocker: MockFixture, *interceptors: ServerInterceptor) -> Iterator[MockType]:
    """
    Serve a `DataService` on port 50052 for the duration of the `with` block.

    `interceptors` are passed to the server in the order given. Yields the spy that
    wraps the servicer's `GetData`. On exit the server is stopped with `grace=None`
    and awaited.
    """
    servicer = DataService()
    get_data_spy = mocker.spy(servicer, "GetData")

    grpc_server = grpc.server(
        thread_pool=futures.ThreadPoolExecutor(max_workers=1),
        interceptors=list(interceptors),
    )
    add_DataServiceServicer_to_server(servicer, grpc_server)
    grpc_server.add_insecure_port("[::]:50052")
    grpc_server.start()

    try:
        yield get_data_spy
    finally:
        grpc_server.stop(None)
        grpc_server.wait_for_termination()
def
test_sift_channel(mocker: pytest_mock.plugin.MockerFixture):
def test_sift_channel(mocker: MockFixture):
    """Check API-key auth, retries of failed calls, and giving up after max attempts."""
    with server_spy(mocker, AuthInterceptor()) as get_data_spy:
        sift_channel_config_a: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "",
            "use_ssl": False,
        }

        # An empty API key must be rejected before the servicer is reached.
        with use_sift_channel(sift_channel_config_a) as channel:
            stub = DataServiceStub(channel)
            with pytest.raises(grpc.RpcError, match="UNAUTHENTICATED"):
                stub.GetData(GetDataRequest())

        get_data_spy.assert_not_called()

        sift_channel_config_b: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_b) as channel:
            stub = DataServiceStub(channel)
            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
            assert res.next_page_token == "next-page-token"
            get_data_spy.assert_called_once()

    force_fail_interceptor = ForceFailInterceptor(4)
    with server_spy(mocker, AuthInterceptor(), force_fail_interceptor) as get_data_spy:
        sift_channel_config_c: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_c) as channel:
            stub = DataServiceStub(channel)
            # This will attempt 5 times: fail 4 times, succeed on 5th
            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
            assert res.next_page_token == "next-page-token"
            get_data_spy.assert_called_once()

        # fail 4 times, pass the 5th attempt
        assert force_fail_interceptor.failed_attempts == 4

    # Now we're going to fail beyond the max retry attempts
    force_fail_interceptor_max = ForceFailInterceptor(7)
    with server_spy(mocker, AuthInterceptor(), force_fail_interceptor_max) as get_data_spy:
        sift_channel_config_d: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_d) as channel:
            stub = DataServiceStub(channel)

            # This will go beyond the max number of attempts; expect the gRPC error
            # itself rather than any Exception, so unrelated failures are not masked.
            with pytest.raises(grpc.RpcError):
                stub.GetData(GetDataRequest())

            get_data_spy.assert_not_called()

        # All attempts failed
        assert force_fail_interceptor_max.failed_attempts == 5
def
test_internal_error_retry(mocker: pytest_mock.plugin.MockerFixture):
def test_internal_error_retry(mocker: MockFixture):
    """Check that INTERNAL failures are retried and that retries stop at max attempts."""
    force_fail_interceptor = ForceFailInterceptor(4, failure_code=grpc.StatusCode.INTERNAL)
    with server_spy(mocker, AuthInterceptor(), force_fail_interceptor) as get_data_spy:
        sift_channel_config_c: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_c) as channel:
            stub = DataServiceStub(channel)
            # This will attempt 5 times: fail 4 times, succeed on 5th
            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
            assert res.next_page_token == "next-page-token"
            get_data_spy.assert_called_once()

        # fail 4 times, pass the 5th attempt
        assert force_fail_interceptor.failed_attempts == 4

    # Now we're going to fail beyond the max retry attempts, still with INTERNAL so this
    # half of the test exercises the same status code as its name promises.
    force_fail_interceptor_max = ForceFailInterceptor(7, failure_code=grpc.StatusCode.INTERNAL)
    with server_spy(mocker, AuthInterceptor(), force_fail_interceptor_max) as get_data_spy:
        sift_channel_config_d: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_d) as channel:
            stub = DataServiceStub(channel)

            # This will go beyond the max number of attempts
            with pytest.raises(grpc.RpcError):
                stub.GetData(GetDataRequest())

            get_data_spy.assert_not_called()

        # All attempts failed
        assert force_fail_interceptor_max.failed_attempts == 5