sift_py.grpc.transport_test
# ruff: noqa: N802
"""Tests for `use_sift_channel` against an in-process gRPC `DataService` server."""

import re
from concurrent import futures
from contextlib import contextmanager
from typing import Any, Callable, Iterator, Tuple, cast

import grpc
import pytest
from pytest_mock import MockFixture, MockType
from sift.data.v2.data_pb2 import GetDataRequest, GetDataResponse
from sift.data.v2.data_pb2_grpc import (
    DataServiceServicer,
    DataServiceStub,
    add_DataServiceServicer_to_server,
)

from sift_py._internal.test_util.server_interceptor import ServerInterceptor
from sift_py.grpc.transport import SiftChannelConfig, use_sift_channel


class DataService(DataServiceServicer):
    """Test servicer whose `GetData` always returns `next_page_token="next-page-token"`."""

    def GetData(self, request: GetDataRequest, context: grpc.ServicerContext):
        return GetDataResponse(next_page_token="next-page-token")


class AuthInterceptor(ServerInterceptor):
    """
    Server interceptor that only lets RPCs through when they carry an API key.

    A call is accepted when its `authorization` metadata matches `Bearer <token>`
    with a non-empty token; otherwise it is aborted with `UNAUTHENTICATED`.
    """

    AUTH_REGEX = re.compile(r"^Bearer (.+)$")

    def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        """
        Forward the RPC to `method` if it is authenticated, otherwise abort it.

        Only the first `authorization` metadata entry is examined.
        """
        authenticated = False
        for metadata in context.invocation_metadata():
            if metadata.key == "authorization":
                # `(.+)` in AUTH_REGEX only matches a non-empty token.
                authenticated = self.__class__.AUTH_REGEX.match(metadata.value) is not None
                break

        if authenticated:
            return method(request_or_iterator, context)

        # `abort` records the status/details and raises to end the handler. A bare
        # `raise` has no active exception here and fails with a RuntimeError instead.
        context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid or missing API key")


class ForceFailInterceptor(ServerInterceptor):
    """
    Force an RPC to fail a set number of times before letting it pass.

    `failed_attempts`: Number of RPCs this interceptor has failed so far.
    `expected_num_fails`: Number of RPCs to fail before letting calls through.
    """

    failed_attempts: int
    expected_num_fails: int

    def __init__(self, expected_num_fails: int):
        self.expected_num_fails = expected_num_fails
        self.failed_attempts = 0
        super().__init__()

    def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        """
        Fail calls with `UNKNOWN` until `expected_num_fails` of them have been
        failed, then forward every later call to `method`.
        """
        if self.failed_attempts < self.expected_num_fails:
            self.failed_attempts += 1
            # `abort` raises, so the servicer is never reached on this attempt.
            context.abort(grpc.StatusCode.UNKNOWN, "something unknown happened")

        return method(request_or_iterator, context)


def test_sift_channel(mocker: MockFixture):
    """
    Exercise `use_sift_channel` against a local gRPC server.

    Covers rejection of an empty API key, success with a non-empty key,
    transparent retries of `UNKNOWN` failures, and the error surfacing once
    the client stops retrying.
    """

    @contextmanager
    def test_server_spy(*interceptors: ServerInterceptor) -> Iterator[Tuple[MockType, int]]:
        """
        Run a `DataService` behind `interceptors` on an OS-assigned port.

        Yields a spy on `DataService.GetData` and the port the server listens on.
        """
        executor = futures.ThreadPoolExecutor(max_workers=1)
        server = grpc.server(thread_pool=executor, interceptors=list(interceptors))

        data_service = DataService()
        spy = mocker.spy(data_service, "GetData")

        add_DataServiceServicer_to_server(data_service, server)
        # Port 0 lets the OS choose a free port, so the test cannot collide with
        # another process or a parallel test run bound to a fixed port.
        port = server.add_insecure_port("[::]:0")
        server.start()
        try:
            yield spy, port
        finally:
            server.stop(None)
            server.wait_for_termination()
            # Release the handler thread pool once the server has terminated.
            executor.shutdown(wait=False)

    def channel_config(port: int, apikey: str) -> SiftChannelConfig:
        """Build an insecure channel config pointing at the local test server."""
        return {
            "uri": f"localhost:{port}",
            "apikey": apikey,
            "use_ssl": False,
        }

    # An empty API key is rejected before the servicer is reached.
    with test_server_spy(AuthInterceptor()) as (get_data_spy, port):
        with use_sift_channel(channel_config(port, "")) as channel:
            stub = DataServiceStub(channel)
            with pytest.raises(grpc.RpcError, match="UNAUTHENTICATED"):
                stub.GetData(GetDataRequest())

        get_data_spy.assert_not_called()

        # A non-empty API key is accepted.
        with use_sift_channel(channel_config(port, "some-token")) as channel:
            stub = DataServiceStub(channel)
            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
            assert res.next_page_token == "next-page-token"
            get_data_spy.assert_called_once()

    # Four failures are retried transparently; the fifth attempt succeeds.
    force_fail_interceptor = ForceFailInterceptor(4)
    with test_server_spy(AuthInterceptor(), force_fail_interceptor) as (get_data_spy, port):
        with use_sift_channel(channel_config(port, "some-token")) as channel:
            stub = DataServiceStub(channel)
            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
            assert res.next_page_token == "next-page-token"
            get_data_spy.assert_called_once()

    assert force_fail_interceptor.failed_attempts == 4

    # Failing more often than the client retries surfaces the error to the caller.
    force_fail_interceptor_max = ForceFailInterceptor(7)
    with test_server_spy(AuthInterceptor(), force_fail_interceptor_max) as (get_data_spy, port):
        with use_sift_channel(channel_config(port, "some-token")) as channel:
            stub = DataServiceStub(channel)
            with pytest.raises(grpc.RpcError):
                stub.GetData(GetDataRequest())

            get_data_spy.assert_not_called()

    # The client gave up after 5 attempts (presumably its retry limit), all failed.
    assert force_fail_interceptor_max.failed_attempts == 5
class
DataService(sift.data.v2.data_pb2_grpc.DataServiceServicer):
class DataService(DataServiceServicer):
    """
    Stub `DataServiceServicer` that answers every `GetData` call with a fixed
    `next_page_token` of `"next-page-token"`.
    """

    def GetData(self, request: GetDataRequest, context: grpc.ServicerContext):
        """Return a canned response; `request` and `context` are not inspected."""
        response = GetDataResponse(next_page_token="next-page-token")
        return response
Test servicer whose `GetData` always responds with `next_page_token` set to `"next-page-token"`.
class AuthInterceptor(ServerInterceptor):
    """
    Server interceptor that only lets RPCs through when they carry an API key.

    A call is accepted when its `authorization` metadata matches `Bearer <token>`
    with a non-empty token; otherwise it is aborted with `UNAUTHENTICATED`.
    """

    AUTH_REGEX = re.compile(r"^Bearer (.+)$")

    def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        """
        Forward the RPC to `method` if it is authenticated, otherwise abort it.

        Only the first `authorization` metadata entry is examined.
        """
        authenticated = False
        for metadata in context.invocation_metadata():
            if metadata.key == "authorization":
                # `(.+)` in AUTH_REGEX only matches a non-empty token.
                authenticated = self.__class__.AUTH_REGEX.match(metadata.value) is not None
                break

        if authenticated:
            return method(request_or_iterator, context)

        # `abort` records the status/details and raises to end the handler. A bare
        # `raise` has no active exception here and fails with a RuntimeError instead.
        context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid or missing API key")
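Affords intercepting incoming RPCs on the service-side. This interceptor lets an RPC through only when its `authorization` metadata holds a non-empty `Bearer` token; otherwise the RPC fails with `UNAUTHENTICATED`.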
def
intercept( self, method: Callable, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str) -> Any:
def intercept(
    self,
    method: Callable,
    request_or_iterator: Any,
    context: grpc.ServicerContext,
    method_name: str,
) -> Any:
    """
    Forward the RPC to `method` if it carries a `Bearer` token, otherwise abort it.

    Only the first `authorization` metadata entry is examined. Rejected calls end
    with status `UNAUTHENTICATED` and details "Invalid or missing API key".
    """
    authenticated = False
    for metadata in context.invocation_metadata():
        if metadata.key == "authorization":
            # `(.+)` in AUTH_REGEX only matches a non-empty token.
            authenticated = self.__class__.AUTH_REGEX.match(metadata.value) is not None
            break

    if authenticated:
        return method(request_or_iterator, context)

    # `abort` records the status/details and raises to end the handler. A bare
    # `raise` has no active exception here and fails with a RuntimeError instead.
    context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid or missing API key")
class ForceFailInterceptor(ServerInterceptor):
    """
    Force an RPC to fail a set number of times before letting it pass.

    `failed_attempts`: Number of RPCs this interceptor has failed so far.
    `expected_num_fails`: Number of RPCs to fail before letting calls through.
    """

    failed_attempts: int
    expected_num_fails: int

    def __init__(self, expected_num_fails: int):
        self.expected_num_fails = expected_num_fails
        self.failed_attempts = 0
        super().__init__()

    def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        """
        Fail calls with `UNKNOWN` until `expected_num_fails` of them have been
        failed, then forward every later call to `method`.
        """
        if self.failed_attempts < self.expected_num_fails:
            self.failed_attempts += 1
            # `abort` raises, so the servicer is never reached on this attempt.
            # A bare `raise` has no active exception here and fails with a RuntimeError.
            context.abort(grpc.StatusCode.UNKNOWN, "something unknown happened")

        return method(request_or_iterator, context)
Force an RPC to fail a set number of times before letting it pass.
failed_attempts
: Number of RPCs this interceptor has failed so far
expected_num_fails
: Number of RPCs to fail before letting calls through
def
intercept( self, method: Callable, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str) -> Any:
def intercept(
    self,
    method: Callable,
    request_or_iterator: Any,
    context: grpc.ServicerContext,
    method_name: str,
) -> Any:
    """
    Fail calls with `UNKNOWN` until `expected_num_fails` of them have been
    failed, then forward every later call to `method`.
    """
    if self.failed_attempts < self.expected_num_fails:
        self.failed_attempts += 1
        # `abort` records the status/details and raises, so the servicer is never
        # reached. A bare `raise` has no active exception here and fails with a
        # RuntimeError instead.
        context.abort(grpc.StatusCode.UNKNOWN, "something unknown happened")

    return method(request_or_iterator, context)
def
test_sift_channel(mocker: pytest_mock.plugin.MockerFixture):
def test_sift_channel(mocker: MockFixture):
    """
    Exercise `use_sift_channel` against a local gRPC server.

    Covers rejection of an empty API key, success with a non-empty key,
    transparent retries of `UNKNOWN` failures, and the error surfacing once
    the client stops retrying.
    """
    from typing import Tuple

    @contextmanager
    def test_server_spy(*interceptors: ServerInterceptor) -> Iterator[Tuple[MockType, int]]:
        """
        Run a `DataService` behind `interceptors` on an OS-assigned port.

        Yields a spy on `DataService.GetData` and the port the server listens on.
        """
        executor = futures.ThreadPoolExecutor(max_workers=1)
        server = grpc.server(thread_pool=executor, interceptors=list(interceptors))

        data_service = DataService()
        spy = mocker.spy(data_service, "GetData")

        add_DataServiceServicer_to_server(data_service, server)
        # Port 0 lets the OS choose a free port, so the test cannot collide with
        # another process or a parallel test run bound to a fixed port.
        port = server.add_insecure_port("[::]:0")
        server.start()
        try:
            yield spy, port
        finally:
            server.stop(None)
            server.wait_for_termination()
            # Release the handler thread pool once the server has terminated.
            executor.shutdown(wait=False)

    def channel_config(port: int, apikey: str) -> SiftChannelConfig:
        """Build an insecure channel config pointing at the local test server."""
        return {
            "uri": f"localhost:{port}",
            "apikey": apikey,
            "use_ssl": False,
        }

    # An empty API key is rejected before the servicer is reached.
    with test_server_spy(AuthInterceptor()) as (get_data_spy, port):
        with use_sift_channel(channel_config(port, "")) as channel:
            stub = DataServiceStub(channel)
            with pytest.raises(grpc.RpcError, match="UNAUTHENTICATED"):
                stub.GetData(GetDataRequest())

        get_data_spy.assert_not_called()

        # A non-empty API key is accepted.
        with use_sift_channel(channel_config(port, "some-token")) as channel:
            stub = DataServiceStub(channel)
            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
            assert res.next_page_token == "next-page-token"
            get_data_spy.assert_called_once()

    # Four failures are retried transparently; the fifth attempt succeeds.
    force_fail_interceptor = ForceFailInterceptor(4)
    with test_server_spy(AuthInterceptor(), force_fail_interceptor) as (get_data_spy, port):
        with use_sift_channel(channel_config(port, "some-token")) as channel:
            stub = DataServiceStub(channel)
            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
            assert res.next_page_token == "next-page-token"
            get_data_spy.assert_called_once()

    assert force_fail_interceptor.failed_attempts == 4

    # Failing more often than the client retries surfaces the error to the caller.
    force_fail_interceptor_max = ForceFailInterceptor(7)
    with test_server_spy(AuthInterceptor(), force_fail_interceptor_max) as (get_data_spy, port):
        with use_sift_channel(channel_config(port, "some-token")) as channel:
            stub = DataServiceStub(channel)
            with pytest.raises(grpc.RpcError):
                stub.GetData(GetDataRequest())

            get_data_spy.assert_not_called()

    # The client gave up after 5 attempts (presumably its retry limit), all failed.
    assert force_fail_interceptor_max.failed_attempts == 5