sift_py.grpc.transport_test

  1# ruff: noqa: N802
  2
  3import re
  4from concurrent import futures
  5from contextlib import contextmanager
  6from typing import Any, Callable, Iterator, cast
  7
  8import grpc
  9import pytest
 10from pytest_mock import MockFixture, MockType
 11from sift.data.v2.data_pb2 import GetDataRequest, GetDataResponse
 12from sift.data.v2.data_pb2_grpc import (
 13    DataServiceServicer,
 14    DataServiceStub,
 15    add_DataServiceServicer_to_server,
 16)
 17
 18from sift_py._internal.test_util.server_interceptor import ServerInterceptor
 19from sift_py.grpc.transport import SiftChannelConfig, use_sift_channel
 20
 21
 22class DataService(DataServiceServicer):
 23    def GetData(self, request: GetDataRequest, context: grpc.ServicerContext):
 24        return GetDataResponse(next_page_token="next-page-token")
 25
 26
 27class AuthInterceptor(ServerInterceptor):
 28    AUTH_REGEX = re.compile(r"^Bearer (.+)$")
 29
 30    def intercept(
 31        self,
 32        method: Callable,
 33        request_or_iterator: Any,
 34        context: grpc.ServicerContext,
 35        method_name: str,
 36    ) -> Any:
 37        authenticated = False
 38        for metadata in context.invocation_metadata():
 39            if metadata.key == "authorization":
 40                auth = self.__class__.AUTH_REGEX.match(metadata.value)
 41
 42                if auth is not None and len(auth.group(1)) > 0:
 43                    authenticated = True
 44
 45                break
 46
 47        if authenticated:
 48            return method(request_or_iterator, context)
 49        else:
 50            context.set_code(grpc.StatusCode.UNAUTHENTICATED)
 51            context.set_details("Invalid or missing API key")
 52            raise
 53
 54
 55class ForceFailInterceptor(ServerInterceptor):
 56    """
 57    Force RPC to fail a few times before letting it pass.
 58
 59    `failed_attempts`: Count of how many times failed
 60    `expected_num_fails`: How many times you want call to fail
 61    """
 62
 63    failed_attempts: int
 64    expected_num_fails: int
 65    failure_code: grpc.StatusCode
 66
 67    def __init__(
 68        self, expected_num_fails: int, failure_code: grpc.StatusCode = grpc.StatusCode.UNKNOWN
 69    ):
 70        self.expected_num_fails = expected_num_fails
 71        self.failed_attempts = 0
 72        self.failure_code = failure_code
 73        super().__init__()
 74
 75    def intercept(
 76        self,
 77        method: Callable,
 78        request_or_iterator: Any,
 79        context: grpc.ServicerContext,
 80        method_name: str,
 81    ) -> Any:
 82        if self.failed_attempts < self.expected_num_fails:
 83            self.failed_attempts += 1
 84            context.set_code(self.failure_code)
 85            context.set_details("something unknown happened")
 86            raise
 87
 88        return method(request_or_iterator, context)
 89
 90
 91@contextmanager
 92def server_spy(mocker: MockFixture, *interceptors: ServerInterceptor) -> Iterator[MockType]:
 93    server = grpc.server(
 94        thread_pool=futures.ThreadPoolExecutor(max_workers=1), interceptors=list(interceptors)
 95    )
 96
 97    data_service = DataService()
 98    spy = mocker.spy(data_service, "GetData")
 99
100    add_DataServiceServicer_to_server(data_service, server)
101    server.add_insecure_port("[::]:50052")
102    server.start()
103    try:
104        yield spy
105    finally:
106        server.stop(None)
107        server.wait_for_termination()
108
109
def test_sift_channel(mocker: MockFixture):
    """
    Exercise `use_sift_channel` against a local `DataService` server.

    Covers:
    - a call without an API key being rejected as UNAUTHENTICATED;
    - a call with an API key succeeding;
    - a call that succeeds after 4 forced failures;
    - a call that still fails once the channel stops retrying.
    """
    # An empty API key must be rejected before the request reaches the servicer.
    with server_spy(mocker, AuthInterceptor()) as get_data_spy:
        sift_channel_config_a: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_a) as channel:
            stub = DataServiceStub(channel)
            with pytest.raises(grpc.RpcError, match="UNAUTHENTICATED"):
                _ = cast(GetDataResponse, stub.GetData(GetDataRequest()))

            get_data_spy.assert_not_called()

        # A non-empty API key is accepted and the servicer's response comes back.
        sift_channel_config_b: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_b) as channel:
            stub = DataServiceStub(channel)
            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
            assert res.next_page_token == "next-page-token"
            get_data_spy.assert_called_once()

    force_fail_interceptor = ForceFailInterceptor(4)
    with server_spy(mocker, AuthInterceptor(), force_fail_interceptor) as get_data_spy:
        sift_channel_config_c: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_c) as channel:
            stub = DataServiceStub(channel)
            # This will attempt 5 times: fail 4 times, succeed on 5th
            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
            assert res.next_page_token == "next-page-token"
            get_data_spy.assert_called_once()

    # fail 4 times, pass the 5th attempt
    assert force_fail_interceptor.failed_attempts == 4

    # Now we're going to fail beyond the max retry attempts
    force_fail_interceptor_max = ForceFailInterceptor(7)
    with server_spy(mocker, AuthInterceptor(), force_fail_interceptor_max) as get_data_spy:
        sift_channel_config_d: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_d) as channel:
            stub = DataServiceStub(channel)

            # Every attempt fails, so the final failure surfaces as an RpcError.
            with pytest.raises(grpc.RpcError):
                _ = cast(GetDataResponse, stub.GetData(GetDataRequest()))

            get_data_spy.assert_not_called()

    # All attempts failed
    assert force_fail_interceptor_max.failed_attempts == 5
176
177
def test_internal_error_retry(mocker: MockFixture):
    """
    Check that calls failing with INTERNAL are retried by the channel.

    Both halves force INTERNAL failures: first a call that succeeds after
    4 failures, then a call that fails on every attempt the channel makes.
    """
    force_fail_interceptor = ForceFailInterceptor(4, failure_code=grpc.StatusCode.INTERNAL)
    with server_spy(mocker, AuthInterceptor(), force_fail_interceptor) as get_data_spy:
        sift_channel_config_c: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_c) as channel:
            stub = DataServiceStub(channel)
            # This will attempt 5 times: fail 4 times, succeed on 5th
            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
            assert res.next_page_token == "next-page-token"
            get_data_spy.assert_called_once()

    # fail 4 times, pass the 5th attempt
    assert force_fail_interceptor.failed_attempts == 4

    # Now we're going to fail beyond the max retry attempts, still with INTERNAL
    # so this half also exercises the INTERNAL retry path.
    force_fail_interceptor_max = ForceFailInterceptor(7, failure_code=grpc.StatusCode.INTERNAL)
    with server_spy(mocker, AuthInterceptor(), force_fail_interceptor_max) as get_data_spy:
        sift_channel_config_d: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_d) as channel:
            stub = DataServiceStub(channel)

            # Every attempt fails, so the final failure surfaces as an RpcError.
            with pytest.raises(grpc.RpcError):
                _ = cast(GetDataResponse, stub.GetData(GetDataRequest()))

            get_data_spy.assert_not_called()

    # All attempts failed
    assert force_fail_interceptor_max.failed_attempts == 5
class DataService(sift.data.v2.data_pb2_grpc.DataServiceServicer):
class DataService(DataServiceServicer):
    """In-process `DataService` used as the servicer of the local test server."""

    def GetData(self, request: GetDataRequest, context: grpc.ServicerContext):
        """Ignore `request` and return a response carrying a fixed page token."""
        canned_token = "next-page-token"
        return GetDataResponse(next_page_token=canned_token)

Missing associated documentation comment in .proto file.

def GetData( self, request: sift.data.v2.data_pb2.GetDataRequest, context: grpc.ServicerContext):
24    def GetData(self, request: GetDataRequest, context: grpc.ServicerContext):
25        return GetDataResponse(next_page_token="next-page-token")

Query data

class AuthInterceptor(ServerInterceptor):
    """
    Server interceptor that only lets through RPCs carrying a bearer token.

    The first `authorization` metadata entry must match `Bearer <token>` with a
    non-empty token; otherwise the RPC is aborted with `UNAUTHENTICATED` before
    it reaches the servicer.
    """

    # Captures the token part of an `authorization: Bearer <token>` entry.
    AUTH_REGEX = re.compile(r"^Bearer (.+)$")

    def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        """
        Forward the call to `method` if it carries a bearer token, else abort it.

        `method`: Next handler in the chain, called as `method(request_or_iterator, context)`.
        `request_or_iterator`: Incoming request (or request stream), passed through unchanged.
        `context`: Server context; its metadata is inspected and it is used to abort the RPC.
        `method_name`: Name of the RPC being invoked (not used here).

        Returns whatever `method` returns.
        """
        authenticated = False
        for metadata in context.invocation_metadata():
            if metadata.key == "authorization":
                auth = self.__class__.AUTH_REGEX.match(metadata.value)
                authenticated = auth is not None and len(auth.group(1)) > 0
                # Only the first `authorization` entry is considered.
                break

        if not authenticated:
            # `abort` records the status/details and raises, terminating the RPC
            # before it reaches `method`.
            context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid or missing API key")

        return method(request_or_iterator, context)

Affords intercepting incoming RPCs on the service-side.

AUTH_REGEX = re.compile('^Bearer (.+)$')
def intercept( self, method: Callable, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str) -> Any:
31    def intercept(
32        self,
33        method: Callable,
34        request_or_iterator: Any,
35        context: grpc.ServicerContext,
36        method_name: str,
37    ) -> Any:
38        authenticated = False
39        for metadata in context.invocation_metadata():
40            if metadata.key == "authorization":
41                auth = self.__class__.AUTH_REGEX.match(metadata.value)
42
43                if auth is not None and len(auth.group(1)) > 0:
44                    authenticated = True
45
46                break
47
48        if authenticated:
49            return method(request_or_iterator, context)
50        else:
51            context.set_code(grpc.StatusCode.UNAUTHENTICATED)
52            context.set_details("Invalid or missing API key")
53            raise
class ForceFailInterceptor(sift_py._internal.test_util.server_interceptor.ServerInterceptor):
class ForceFailInterceptor(ServerInterceptor):
    """
    Force RPC to fail a few times before letting it pass.

    `failed_attempts`: Number of calls this interceptor has failed so far.
    `expected_num_fails`: Number of calls to fail before letting calls through.
    `failure_code`: Status code reported for each forced failure.
    """

    failed_attempts: int
    expected_num_fails: int
    failure_code: grpc.StatusCode

    def __init__(
        self, expected_num_fails: int, failure_code: grpc.StatusCode = grpc.StatusCode.UNKNOWN
    ):
        self.expected_num_fails = expected_num_fails
        self.failed_attempts = 0
        self.failure_code = failure_code
        super().__init__()

    def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        """
        Abort the call with `failure_code` until `expected_num_fails` calls have
        been failed, then forward every later call to `method`.

        Returns whatever `method` returns once failures are exhausted.
        """
        if self.failed_attempts < self.expected_num_fails:
            self.failed_attempts += 1
            # `abort` records the status/details and raises, terminating the RPC.
            context.abort(self.failure_code, "something unknown happened")

        return method(request_or_iterator, context)

Force RPC to fail a few times before letting it pass.

failed_attempts: Number of times the call has been forced to fail so far. expected_num_fails: Number of times the call should fail before it is allowed through.

ForceFailInterceptor( expected_num_fails: int, failure_code: grpc.StatusCode = <StatusCode.UNKNOWN: (2, 'unknown')>)
68    def __init__(
69        self, expected_num_fails: int, failure_code: grpc.StatusCode = grpc.StatusCode.UNKNOWN
70    ):
71        self.expected_num_fails = expected_num_fails
72        self.failed_attempts = 0
73        self.failure_code = failure_code
74        super().__init__()
failed_attempts: int
expected_num_fails: int
failure_code: grpc.StatusCode
def intercept( self, method: Callable, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str) -> Any:
76    def intercept(
77        self,
78        method: Callable,
79        request_or_iterator: Any,
80        context: grpc.ServicerContext,
81        method_name: str,
82    ) -> Any:
83        if self.failed_attempts < self.expected_num_fails:
84            self.failed_attempts += 1
85            context.set_code(self.failure_code)
86            context.set_details("something unknown happened")
87            raise
88
89        return method(request_or_iterator, context)
@contextmanager
def server_spy( mocker: pytest_mock.plugin.MockerFixture, *interceptors: sift_py._internal.test_util.server_interceptor.ServerInterceptor) -> Iterator[Union[unittest.mock.MagicMock, unittest.mock.AsyncMock, unittest.mock.NonCallableMagicMock]]:
 92@contextmanager
 93def server_spy(mocker: MockFixture, *interceptors: ServerInterceptor) -> Iterator[MockType]:
 94    server = grpc.server(
 95        thread_pool=futures.ThreadPoolExecutor(max_workers=1), interceptors=list(interceptors)
 96    )
 97
 98    data_service = DataService()
 99    spy = mocker.spy(data_service, "GetData")
100
101    add_DataServiceServicer_to_server(data_service, server)
102    server.add_insecure_port("[::]:50052")
103    server.start()
104    try:
105        yield spy
106    finally:
107        server.stop(None)
108        server.wait_for_termination()
def test_sift_channel(mocker: pytest_mock.plugin.MockerFixture):
def test_sift_channel(mocker: MockFixture):
    """
    Exercise `use_sift_channel` against a local `DataService` server.

    Covers:
    - a call without an API key being rejected as UNAUTHENTICATED;
    - a call with an API key succeeding;
    - a call that succeeds after 4 forced failures;
    - a call that still fails once the channel stops retrying.
    """
    # An empty API key must be rejected before the request reaches the servicer.
    with server_spy(mocker, AuthInterceptor()) as get_data_spy:
        sift_channel_config_a: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_a) as channel:
            stub = DataServiceStub(channel)
            with pytest.raises(grpc.RpcError, match="UNAUTHENTICATED"):
                _ = cast(GetDataResponse, stub.GetData(GetDataRequest()))

            get_data_spy.assert_not_called()

        # A non-empty API key is accepted and the servicer's response comes back.
        sift_channel_config_b: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_b) as channel:
            stub = DataServiceStub(channel)
            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
            assert res.next_page_token == "next-page-token"
            get_data_spy.assert_called_once()

    force_fail_interceptor = ForceFailInterceptor(4)
    with server_spy(mocker, AuthInterceptor(), force_fail_interceptor) as get_data_spy:
        sift_channel_config_c: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_c) as channel:
            stub = DataServiceStub(channel)
            # This will attempt 5 times: fail 4 times, succeed on 5th
            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
            assert res.next_page_token == "next-page-token"
            get_data_spy.assert_called_once()

    # fail 4 times, pass the 5th attempt
    assert force_fail_interceptor.failed_attempts == 4

    # Now we're going to fail beyond the max retry attempts
    force_fail_interceptor_max = ForceFailInterceptor(7)
    with server_spy(mocker, AuthInterceptor(), force_fail_interceptor_max) as get_data_spy:
        sift_channel_config_d: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_d) as channel:
            stub = DataServiceStub(channel)

            # Every attempt fails, so the final failure surfaces as an RpcError.
            with pytest.raises(grpc.RpcError):
                _ = cast(GetDataResponse, stub.GetData(GetDataRequest()))

            get_data_spy.assert_not_called()

    # All attempts failed
    assert force_fail_interceptor_max.failed_attempts == 5
def test_internal_error_retry(mocker: pytest_mock.plugin.MockerFixture):
def test_internal_error_retry(mocker: MockFixture):
    """
    Check that calls failing with INTERNAL are retried by the channel.

    Both halves force INTERNAL failures: first a call that succeeds after
    4 failures, then a call that fails on every attempt the channel makes.
    """
    force_fail_interceptor = ForceFailInterceptor(4, failure_code=grpc.StatusCode.INTERNAL)
    with server_spy(mocker, AuthInterceptor(), force_fail_interceptor) as get_data_spy:
        sift_channel_config_c: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_c) as channel:
            stub = DataServiceStub(channel)
            # This will attempt 5 times: fail 4 times, succeed on 5th
            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
            assert res.next_page_token == "next-page-token"
            get_data_spy.assert_called_once()

    # fail 4 times, pass the 5th attempt
    assert force_fail_interceptor.failed_attempts == 4

    # Now we're going to fail beyond the max retry attempts, still with INTERNAL
    # so this half also exercises the INTERNAL retry path.
    force_fail_interceptor_max = ForceFailInterceptor(7, failure_code=grpc.StatusCode.INTERNAL)
    with server_spy(mocker, AuthInterceptor(), force_fail_interceptor_max) as get_data_spy:
        sift_channel_config_d: SiftChannelConfig = {
            "uri": "localhost:50052",
            "apikey": "some-token",
            "use_ssl": False,
        }

        with use_sift_channel(sift_channel_config_d) as channel:
            stub = DataServiceStub(channel)

            # Every attempt fails, so the final failure surfaces as an RpcError.
            with pytest.raises(grpc.RpcError):
                _ = cast(GetDataResponse, stub.GetData(GetDataRequest()))

            get_data_spy.assert_not_called()

    # All attempts failed
    assert force_fail_interceptor_max.failed_attempts == 5