sift_py.grpc.transport_test

  1# ruff: noqa: N802
  2
  3import re
  4from concurrent import futures
  5from contextlib import contextmanager
  6from typing import Any, Callable, Iterator, cast
  7
  8import grpc
  9import pytest
 10from pytest_mock import MockFixture, MockType
 11from sift.data.v2.data_pb2 import GetDataRequest, GetDataResponse
 12from sift.data.v2.data_pb2_grpc import (
 13    DataServiceServicer,
 14    DataServiceStub,
 15    add_DataServiceServicer_to_server,
 16)
 17
 18from sift_py._internal.test_util.server_interceptor import ServerInterceptor
 19from sift_py.grpc.transport import SiftChannelConfig, use_sift_channel
 20
 21
 22class DataService(DataServiceServicer):
 23    def GetData(self, request: GetDataRequest, context: grpc.ServicerContext):
 24        return GetDataResponse(next_page_token="next-page-token")
 25
 26
 27class AuthInterceptor(ServerInterceptor):
 28    AUTH_REGEX = re.compile(r"^Bearer (.+)$")
 29
 30    def intercept(
 31        self,
 32        method: Callable,
 33        request_or_iterator: Any,
 34        context: grpc.ServicerContext,
 35        method_name: str,
 36    ) -> Any:
 37        authenticated = False
 38        for metadata in context.invocation_metadata():
 39            if metadata.key == "authorization":
 40                auth = self.__class__.AUTH_REGEX.match(metadata.value)
 41
 42                if auth is not None and len(auth.group(1)) > 0:
 43                    authenticated = True
 44
 45                break
 46
 47        if authenticated:
 48            return method(request_or_iterator, context)
 49        else:
 50            context.set_code(grpc.StatusCode.UNAUTHENTICATED)
 51            context.set_details("Invalid or missing API key")
 52            raise
 53
 54
 55class ForceFailInterceptor(ServerInterceptor):
 56    """
 57    Force RPC to fail a few times before letting it pass.
 58
 59    `failed_attempts`: Count of how many times failed
 60    `expected_num_fails`: How many times you want call to fail
 61    """
 62
 63    failed_attempts: int
 64    expected_num_fails: int
 65
 66    def __init__(self, expected_num_fails: int):
 67        self.expected_num_fails = expected_num_fails
 68        self.failed_attempts = 0
 69        super().__init__()
 70
 71    def intercept(
 72        self,
 73        method: Callable,
 74        request_or_iterator: Any,
 75        context: grpc.ServicerContext,
 76        method_name: str,
 77    ) -> Any:
 78        if self.failed_attempts < self.expected_num_fails:
 79            self.failed_attempts += 1
 80            context.set_code(grpc.StatusCode.UNKNOWN)
 81            context.set_details("something unknown happened")
 82            raise
 83
 84        return method(request_or_iterator, context)
 85
 86
 87def test_sift_channel(mocker: MockFixture):
 88    @contextmanager
 89    def test_server_spy(*interceptors: ServerInterceptor) -> Iterator[MockType]:
 90        server = grpc.server(
 91            thread_pool=futures.ThreadPoolExecutor(max_workers=1), interceptors=list(interceptors)
 92        )
 93
 94        data_service = DataService()
 95        spy = mocker.spy(data_service, "GetData")
 96
 97        add_DataServiceServicer_to_server(data_service, server)
 98        server.add_insecure_port("[::]:50052")
 99        server.start()
100        try:
101            yield spy
102        finally:
103            server.stop(None)
104            server.wait_for_termination()
105
106    with test_server_spy(AuthInterceptor()) as get_data_spy:
107        sift_channel_config_a: SiftChannelConfig = {
108            "uri": "localhost:50052",
109            "apikey": "",
110            "use_ssl": False,
111        }
112
113        with use_sift_channel(sift_channel_config_a) as channel:
114            with pytest.raises(grpc.RpcError, match="UNAUTHENTICATED"):
115                stub = DataServiceStub(channel)
116                _ = cast(GetDataResponse, stub.GetData(GetDataRequest()))
117
118            get_data_spy.assert_not_called()
119
120        sift_channel_config_b: SiftChannelConfig = {
121            "uri": "localhost:50052",
122            "apikey": "some-token",
123            "use_ssl": False,
124        }
125
126        with use_sift_channel(sift_channel_config_b) as channel:
127            stub = DataServiceStub(channel)
128            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
129            assert res.next_page_token == "next-page-token"
130            get_data_spy.assert_called_once()
131
132    force_fail_interceptor = ForceFailInterceptor(4)
133    with test_server_spy(AuthInterceptor(), force_fail_interceptor) as get_data_spy:
134        sift_channel_config_c: SiftChannelConfig = {
135            "uri": "localhost:50052",
136            "apikey": "some-token",
137            "use_ssl": False,
138        }
139
140        with use_sift_channel(sift_channel_config_c) as channel:
141            stub = DataServiceStub(channel)
142            # This will attempt 5 times: fail 4 times, succeed on 5th
143            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
144            assert res.next_page_token == "next-page-token"
145            get_data_spy.assert_called_once()
146
147    # fail 4 times, pass the 5th attempt
148    assert force_fail_interceptor.failed_attempts == 4
149
150    # Now we're going to fail beyond the max retry attempts
151
152    force_fail_interceptor_max = ForceFailInterceptor(7)
153    with test_server_spy(AuthInterceptor(), force_fail_interceptor_max) as get_data_spy:
154        sift_channel_config_d: SiftChannelConfig = {
155            "uri": "localhost:50052",
156            "apikey": "some-token",
157            "use_ssl": False,
158        }
159
160        with use_sift_channel(sift_channel_config_d) as channel:
161            stub = DataServiceStub(channel)
162
163            # This will go beyond the max number of attempts
164            with pytest.raises(Exception):
165                res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
166
167            get_data_spy.assert_not_called()
168
169    # All attempts failed
170    assert force_fail_interceptor_max.failed_attempts == 5
class DataService(sift.data.v2.data_pb2_grpc.DataServiceServicer):
class DataService(DataServiceServicer):
    """Stand-in data service whose `GetData` always returns the same canned response."""

    def GetData(self, request: GetDataRequest, context: grpc.ServicerContext):
        """Ignore `request` and `context`; reply with a fixed `next_page_token`."""
        canned_response = GetDataResponse(next_page_token="next-page-token")
        return canned_response

Missing associated documentation comment in .proto file.

def GetData( self, request: sift.data.v2.data_pb2.GetDataRequest, context: grpc.ServicerContext):
24    def GetData(self, request: GetDataRequest, context: grpc.ServicerContext):
25        return GetDataResponse(next_page_token="next-page-token")

Query data

class AuthInterceptor(ServerInterceptor):
    """
    Server-side interceptor that rejects calls which do not carry a bearer token.

    Only the first `authorization` metadata entry is inspected. It must have the
    form `Bearer <token>` with a non-empty token for the call to reach the handler.
    """

    AUTH_REGEX = re.compile(r"^Bearer (.+)$")

    def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        """
        Forward the call to `method` when a bearer token is present; otherwise
        terminate the RPC with `UNAUTHENTICATED`.
        """
        # Take the first `authorization` entry only; later duplicates are ignored.
        auth_header = next(
            (
                metadata.value
                for metadata in context.invocation_metadata()
                if metadata.key == "authorization"
            ),
            None,
        )

        # `(.+)` cannot match an empty token, so a successful match is sufficient.
        if auth_header is not None and self.AUTH_REGEX.match(auth_header) is not None:
            return method(request_or_iterator, context)

        # `abort` sets the status code/details and raises to end the RPC. The
        # previous bare `raise` only worked by tripping a RuntimeError
        # ("No active exception to reraise").
        context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid or missing API key")

Affords intercepting incoming RPCs on the service-side.

AUTH_REGEX = re.compile('^Bearer (.+)$')
def intercept( self, method: Callable, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str) -> Any:
31    def intercept(
32        self,
33        method: Callable,
34        request_or_iterator: Any,
35        context: grpc.ServicerContext,
36        method_name: str,
37    ) -> Any:
38        authenticated = False
39        for metadata in context.invocation_metadata():
40            if metadata.key == "authorization":
41                auth = self.__class__.AUTH_REGEX.match(metadata.value)
42
43                if auth is not None and len(auth.group(1)) > 0:
44                    authenticated = True
45
46                break
47
48        if authenticated:
49            return method(request_or_iterator, context)
50        else:
51            context.set_code(grpc.StatusCode.UNAUTHENTICATED)
52            context.set_details("Invalid or missing API key")
53            raise
class ForceFailInterceptor(sift_py._internal.test_util.server_interceptor.ServerInterceptor):
class ForceFailInterceptor(ServerInterceptor):
    """
    Force RPC to fail a few times before letting it pass.

    `failed_attempts`: Count of how many times the call has been failed so far
    `expected_num_fails`: How many times you want the call to fail
    """

    failed_attempts: int
    expected_num_fails: int

    def __init__(self, expected_num_fails: int):
        """Start with zero recorded failures and the requested failure budget."""
        self.expected_num_fails = expected_num_fails
        self.failed_attempts = 0
        super().__init__()

    def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        """
        Fail the call with `UNKNOWN` until `expected_num_fails` failures have been
        recorded, then forward every subsequent call to `method`.
        """
        if self.failed_attempts < self.expected_num_fails:
            self.failed_attempts += 1
            # `abort` sets the status code/details and raises to end the RPC. The
            # previous bare `raise` only worked by tripping a RuntimeError
            # ("No active exception to reraise").
            context.abort(grpc.StatusCode.UNKNOWN, "something unknown happened")

        return method(request_or_iterator, context)

Force RPC to fail a few times before letting it pass.

`failed_attempts`: count of how many times the call has failed so far. `expected_num_fails`: how many times you want the call to fail.

ForceFailInterceptor(expected_num_fails: int)
67    def __init__(self, expected_num_fails: int):
68        self.expected_num_fails = expected_num_fails
69        self.failed_attempts = 0
70        super().__init__()
failed_attempts: int
expected_num_fails: int
def intercept( self, method: Callable, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str) -> Any:
72    def intercept(
73        self,
74        method: Callable,
75        request_or_iterator: Any,
76        context: grpc.ServicerContext,
77        method_name: str,
78    ) -> Any:
79        if self.failed_attempts < self.expected_num_fails:
80            self.failed_attempts += 1
81            context.set_code(grpc.StatusCode.UNKNOWN)
82            context.set_details("something unknown happened")
83            raise
84
85        return method(request_or_iterator, context)
def test_sift_channel(mocker: pytest_mock.plugin.MockerFixture):
 88def test_sift_channel(mocker: MockFixture):
 89    @contextmanager
 90    def test_server_spy(*interceptors: ServerInterceptor) -> Iterator[MockType]:
 91        server = grpc.server(
 92            thread_pool=futures.ThreadPoolExecutor(max_workers=1), interceptors=list(interceptors)
 93        )
 94
 95        data_service = DataService()
 96        spy = mocker.spy(data_service, "GetData")
 97
 98        add_DataServiceServicer_to_server(data_service, server)
 99        server.add_insecure_port("[::]:50052")
100        server.start()
101        try:
102            yield spy
103        finally:
104            server.stop(None)
105            server.wait_for_termination()
106
107    with test_server_spy(AuthInterceptor()) as get_data_spy:
108        sift_channel_config_a: SiftChannelConfig = {
109            "uri": "localhost:50052",
110            "apikey": "",
111            "use_ssl": False,
112        }
113
114        with use_sift_channel(sift_channel_config_a) as channel:
115            with pytest.raises(grpc.RpcError, match="UNAUTHENTICATED"):
116                stub = DataServiceStub(channel)
117                _ = cast(GetDataResponse, stub.GetData(GetDataRequest()))
118
119            get_data_spy.assert_not_called()
120
121        sift_channel_config_b: SiftChannelConfig = {
122            "uri": "localhost:50052",
123            "apikey": "some-token",
124            "use_ssl": False,
125        }
126
127        with use_sift_channel(sift_channel_config_b) as channel:
128            stub = DataServiceStub(channel)
129            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
130            assert res.next_page_token == "next-page-token"
131            get_data_spy.assert_called_once()
132
133    force_fail_interceptor = ForceFailInterceptor(4)
134    with test_server_spy(AuthInterceptor(), force_fail_interceptor) as get_data_spy:
135        sift_channel_config_c: SiftChannelConfig = {
136            "uri": "localhost:50052",
137            "apikey": "some-token",
138            "use_ssl": False,
139        }
140
141        with use_sift_channel(sift_channel_config_c) as channel:
142            stub = DataServiceStub(channel)
143            # This will attempt 5 times: fail 4 times, succeed on 5th
144            res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
145            assert res.next_page_token == "next-page-token"
146            get_data_spy.assert_called_once()
147
148    # fail 4 times, pass the 5th attempt
149    assert force_fail_interceptor.failed_attempts == 4
150
151    # Now we're going to fail beyond the max retry attempts
152
153    force_fail_interceptor_max = ForceFailInterceptor(7)
154    with test_server_spy(AuthInterceptor(), force_fail_interceptor_max) as get_data_spy:
155        sift_channel_config_d: SiftChannelConfig = {
156            "uri": "localhost:50052",
157            "apikey": "some-token",
158            "use_ssl": False,
159        }
160
161        with use_sift_channel(sift_channel_config_d) as channel:
162            stub = DataServiceStub(channel)
163
164            # This will go beyond the max number of attempts
165            with pytest.raises(Exception):
166                res = cast(GetDataResponse, stub.GetData(GetDataRequest()))
167
168            get_data_spy.assert_not_called()
169
170    # All attempts failed
171    assert force_fail_interceptor_max.failed_attempts == 5