diff --git a/docs/source/guides/grpc_streaming.md b/docs/source/guides/grpc_streaming.md
new file mode 100644
index 00000000000..311cbc1111f
--- /dev/null
+++ b/docs/source/guides/grpc_streaming.md
@@ -0,0 +1,153 @@
+# gRPC Streaming with BentoML (v1alpha1)
+
+BentoML supports gRPC streaming, allowing for efficient, long-lived communication channels between clients and servers. This guide demonstrates how to define, implement, and use gRPC streaming services with BentoML's `v1alpha1` gRPC protocol.
+
+This `v1alpha1` protocol is an initial version focused on server-side streaming: the client sends a single message and the server responds with a stream of messages.
+
+## 1. Defining the Service (.proto)
+
+First, define your service and messages using Protocol Buffers. For the `v1alpha1` streaming interface, BentoML provides a specific service definition. If you were building custom services beyond the default `BentoService`, you'd create your own `.proto` similar to this.
+
+The core `v1alpha1` service used internally by BentoML is defined in `src/bentoml/grpc/v1alpha1/bentoml_service_v1alpha1.proto`:
+
+```protobuf
+syntax = "proto3";
+
+package bentoml.grpc.v1alpha1;
+
+// The BentoService service definition.
+service BentoService {
+  // A streaming RPC method that accepts a Request message
+  // and returns a stream of Response messages.
+  rpc CallStream(Request) returns (stream Response) {}
+}
+
+// The request message containing the input data.
+message Request {
+  string data = 1;
+}
+
+// The response message containing the output data.
+message Response {
+  string data = 1;
+}
+```
+
+Key aspects:
+- `service BentoService`: Defines the service name.
+- `rpc CallStream(Request) returns (stream Response) {}`: Declares a server-streaming RPC method. The client sends a single `Request`, and the server replies with a stream of `Response` messages.
+
+After defining your `.proto` file, generate the Python gRPC stubs:
+```bash
+pip install grpcio-tools
+python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. your_service.proto
+```
+For BentoML's internal `v1alpha1` service, these stubs (`bentoml_service_v1alpha1_pb2.py` and `bentoml_service_v1alpha1_pb2_grpc.py`) are already generated and included.
+
+## 2. Implementing the Server-Side Streaming Logic
+
+You implement the server-side logic by creating a class that inherits from the generated `YourServiceServicer` (e.g., `BentoServiceServicer` for the internal service) and overriding the streaming methods.
+
+Here's how the internal `BentoServiceImpl` for `v1alpha1` is structured (simplified from `src/bentoml/grpc/v1alpha1/server.py`):
+
+```python
+import asyncio
+from typing import AsyncIterator
+
+import grpc
+
+# Assuming stubs are generated in 'generated' directory or available in path
+from bentoml.grpc.v1alpha1 import bentoml_service_v1alpha1_pb2 as pb
+from bentoml.grpc.v1alpha1 import bentoml_service_v1alpha1_pb2_grpc as services
+
+class BentoServiceImpl(services.BentoServiceServicer):
+    async def CallStream(
+        self, request: pb.Request, context: grpc.aio.ServicerContext
+    ) -> AsyncIterator[pb.Response]:
+        """
+        Example CallStream implementation.
+        Receives a Request and yields a stream of Response messages.
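+        Implemented as an async generator: each `yield` sends one message
+        on the stream, and returning from the method ends the stream.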
+ """ + print(f"CallStream received: {request.data}") + for i in range(5): # Example: send 5 messages + response_data = f"Response {i+1} for '{request.data}'" + print(f"Sending: {response_data}") + await asyncio.sleep(0.5) # Simulate work + yield pb.Response(data=response_data) + print("CallStream finished.") + +# To run this service (example standalone server): +async def run_server(port=50051): + server = grpc.aio.server() + services.add_BentoServiceServicer_to_server(BentoServiceImpl(), server) + server.add_insecure_port(f"[::]:{port}") + await server.start() + print(f"gRPC server started on port {port}") + await server.wait_for_termination() + +if __name__ == "__main__": + asyncio.run(run_server()) +``` + +When integrating with `bentoml serve-grpc`, BentoML handles running the gRPC server. You need to ensure your service implementation is correctly picked up, which is done by modifying `Service.get_grpc_servicer` if you are customizing the main BentoService, or by mounting your own servicer for custom services. For the `v1alpha1` protocol, BentoML's `Service` class is already configured to use this `BentoServiceImpl`. + +## 3. Using the BentoMlGrpcClient (v1alpha1) + +BentoML provides a client SDK to interact with the `v1alpha1` gRPC streaming service. + +Example usage (from `src/bentoml/grpc/v1alpha1/client.py`): +```python +import asyncio +from bentoml.grpc.v1alpha1.client import BentoMlGrpcClient + +async def main(): + client = BentoMlGrpcClient(host="localhost", port=50051) + + input_data = "Hello Streaming World" + print(f"Calling CallStream with data: '{input_data}'") + + try: + idx = 0 + async for response in client.call_stream(data=input_data): + print(f"Received from stream (message {idx}): {response.data}") + idx += 1 + except Exception as e: + print(f"An error occurred: {e}") + finally: + await client.close() + print("Client connection closed.") + +if __name__ == "__main__": + asyncio.run(main()) +``` +The `client.call_stream(data=...)` method returns an asynchronous iterator that yields `Response` messages from the server. + +## 4. Using the `call-grpc-stream` CLI Command + +BentoML provides a CLI command to easily test and interact with `v1alpha1` gRPC streaming services. + +**Command Syntax:** +```bash +bentoml call-grpc-stream --host --port --data "" +``` + +**Example:** +Assuming your BentoML gRPC server (with `v1alpha1` protocol) is running on `localhost:50051`: +```bash +bentoml call-grpc-stream --host localhost --port 50051 --data "Test Message from CLI" +``` + +Output will be similar to: +``` +Connecting to gRPC server at localhost:50051... +Sending data: 'Test Message from CLI' to CallStream... +--- Streamed Responses --- +Response 1 for 'Test Message from CLI' +Response 2 for 'Test Message from CLI' +Response 3 for 'Test Message from CLI' +... (based on server implementation) ... +------------------------ +Connection closed. +``` + +This CLI command uses the `BentoMlGrpcClient` internally. + +## Summary + +The `v1alpha1` gRPC streaming support in BentoML provides a foundation for building services that require persistent, streamed communication. By defining services in `.proto` files, implementing the server-side logic, and using the provided client SDK or CLI, you can leverage gRPC streaming in your BentoML applications. Remember that this `v1alpha1` version is specific to a client-sends-one, server-streams-many interaction pattern for the main `BentoService`. 
+For more complex gRPC patterns like these (client-streaming, bidirectional streaming for custom services), you would define them in your own `.proto` files and implement corresponding servicers and clients.
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 948bb3d1539..e1538283c42
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -176,6 +176,7 @@ For release notes and detailed changelogs, see the `Releases <https://github.com/bentoml/BentoML/releases>`_
+   guides/grpc_streaming
diff --git a/examples/grpc_streaming/README.md b/examples/grpc_streaming/README.md
new file mode 100644
--- /dev/null
+++ b/examples/grpc_streaming/README.md
+# Custom gRPC Streaming Example
+
+This example mounts a custom server-streaming gRPC service (`SimpleStreamingService`) on a BentoML `Service`, and includes a standalone client that consumes the stream.
+
+## Running the Example
+
+1. **Start the Server**:
+   From the `examples/grpc_streaming` directory, start the BentoML service with gRPC enabled (the gRPC port can be set via `bentoml serve --grpc-port <port>`). For this example, we assume the gRPC server runs on the default port that `bentoml serve` would expose if not configured, which might require checking BentoML's default configuration or explicitly setting it.
+
+   *Note: `bentoml serve` starts an HTTP server by default. For gRPC, you usually use `bentoml serve-grpc`; however, BentoML services can expose both. We use `mount_grpc_servicer`, which makes the custom service available via the main gRPC server that `serve-grpc` would typically manage.*
+
+   To ensure it uses a known gRPC port (e.g., 50051), you might run:
+   ```bash
+   bentoml serve service:svc --reload --grpc-port 50051
+   # Or more explicitly for gRPC focus:
+   # bentoml serve-grpc service:svc --reload --port 50051
+   ```
+   Check the output from `bentoml serve` for the actual gRPC port if you don't specify one. For this example, `client_example.py` assumes `localhost:50051`.
+
+2. **Run the Client**:
+   In a new terminal, from the `examples/grpc_streaming` directory:
+   ```bash
+   python client_example.py
+   ```
+
+## Expected Output (Client)
+
+```
+Client sending: Hello, stream!
+Server says: Response 1 to 'Hello, stream!'
+Server says: Response 2 to 'Hello, stream!'
+Server says: Response 3 to 'Hello, stream!'
+Server says: Response 4 to 'Hello, stream!'
+Server says: Response 5 to 'Hello, stream!'
+Stream finished.
+```
+
+## How it Works
+
+- **`example_service.proto`**: Defines a `SimpleStreamingService` with a server-streaming RPC method `Chat`.
+- **`service.py`**:
+  - Implements `SimpleStreamingServicerImpl`, which provides the logic for the `Chat` method.
+  - Creates a BentoML `Service` named `custom_grpc_stream_example_service`.
+  - Mounts the `SimpleStreamingServicerImpl` onto the BentoML service instance. When this BentoML service is run with gRPC enabled, the custom gRPC service will be available.
+- **`client_example.py`**:
+  - Uses `grpc.aio.insecure_channel` to connect to the server.
+  - Creates a stub for `SimpleStreamingService`.
+  - Calls the `Chat` method and iterates over the streamed responses.
+
+This example showcases how to integrate custom gRPC services with streaming capabilities within the BentoML framework.
diff --git a/examples/grpc_streaming/bentofile.yaml b/examples/grpc_streaming/bentofile.yaml
new file mode 100644
index 00000000000..83dc116e879
--- /dev/null
+++ b/examples/grpc_streaming/bentofile.yaml
@@ -0,0 +1,21 @@
+service: "service:svc"
+name: "custom_grpc_stream_example"
+version: "0.1.0"
+
+description: "A BentoML example showcasing custom gRPC streaming services."
+
+# Ensure generated stubs are included if building a bento.
+# For local development (bentoml serve), Python's import system will find them
+# if they are in the PYTHONPATH (e.g., in the same directory or an installed package).
+# If you build this into a Bento, you'd want to ensure 'generated' is included.
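+# The stubs under generated/ can be rebuilt with grpcio-tools, e.g. (assumed
+# invocation, matching the layout of this example):
+#   python -m grpc_tools.protoc -Iprotos --python_out=generated --grpc_python_out=generated protos/example_service.proto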
+include: + - "*.py" + - "generated/*.py" + - "protos/*.proto" + +python: + packages: + - grpcio + - grpcio-tools # For local stub generation, not strictly needed by the Bento itself at runtime + - bentoml +``` diff --git a/examples/grpc_streaming/client_example.py b/examples/grpc_streaming/client_example.py new file mode 100644 index 00000000000..65a3a3b6328 --- /dev/null +++ b/examples/grpc_streaming/client_example.py @@ -0,0 +1,50 @@ +import asyncio +import time +import uuid + +import grpc + +# Import generated gRPC stubs +from generated import example_service_pb2 +from generated import example_service_pb2_grpc + + +async def run_client(): + # Target server address + target_address = ( + "localhost:50051" # Default gRPC port for BentoML, adjust if necessary + ) + + # Create a channel + async with grpc.aio.insecure_channel(target_address) as channel: + # Create a stub (client) + stub = example_service_pb2_grpc.SimpleStreamingServiceStub(channel) + + # Prepare a request message + client_message_text = "Hello, stream!" + request_message = example_service_pb2.ChatMessage( + message_id=str(uuid.uuid4()), # Unique ID for the message + text=client_message_text, + timestamp=int(time.time() * 1000), + ) + + print(f"Client sending: {client_message_text}") + + try: + # Call the Chat RPC method and iterate through the streamed responses + async for response in stub.Chat(request_message): + print( + f"Server says: {response.text} (ID: {response.message_id}, TS: {response.timestamp})" + ) + + print("Stream finished.") + + except grpc.aio.AioRpcError as e: + print(f"gRPC call failed: {e.code()} - {e.details()}") + except Exception as e: + print(f"An error occurred: {e}") + + +if __name__ == "__main__": + print("Starting gRPC client example...") + asyncio.run(run_client()) diff --git a/examples/grpc_streaming/generated/__init__.py b/examples/grpc_streaming/generated/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/grpc_streaming/generated/example_service_pb2.py b/examples/grpc_streaming/generated/example_service_pb2.py new file mode 100644 index 00000000000..78b8b9f468f --- /dev/null +++ b/examples/grpc_streaming/generated/example_service_pb2.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# NO CHECKED-IN PROTOBUF GENCODE +# source: example_service.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" + +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, 5, 29, 0, "", "example_service.proto" +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x15\x65xample_service.proto\x12\x14\x62\x65ntoml.example.grpc"B\n\x0b\x43hatMessage\x12\x12\n\nmessage_id\x18\x01 \x01(\t\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x32j\n\x16SimpleStreamingService\x12P\n\x04\x43hat\x12!.bentoml.example.grpc.ChatMessage\x1a!.bentoml.example.grpc.ChatMessage"\x00\x30\x01\x62\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "example_service_pb2", _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals["_CHATMESSAGE"]._serialized_start = 47 + _globals["_CHATMESSAGE"]._serialized_end = 113 + _globals["_SIMPLESTREAMINGSERVICE"]._serialized_start = 115 + _globals["_SIMPLESTREAMINGSERVICE"]._serialized_end = 221 +# @@protoc_insertion_point(module_scope) diff --git a/examples/grpc_streaming/generated/example_service_pb2_grpc.py b/examples/grpc_streaming/generated/example_service_pb2_grpc.py new file mode 100644 index 00000000000..464b655d104 --- /dev/null +++ b/examples/grpc_streaming/generated/example_service_pb2_grpc.py @@ -0,0 +1,106 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" + +import example_service_pb2 as example__service__pb2 +import grpc + +GRPC_GENERATED_VERSION = "1.71.0" +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + + _version_not_supported = first_version_is_lower( + GRPC_VERSION, GRPC_GENERATED_VERSION + ) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f"The grpc package installed is at version {GRPC_VERSION}," + + " but the generated code in example_service_pb2_grpc.py depends on" + + f" grpcio>={GRPC_GENERATED_VERSION}." + + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" + + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." + ) + + +class SimpleStreamingServiceStub(object): + """A simple service for demonstrating gRPC streaming.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.Chat = channel.unary_stream( + "/bentoml.example.grpc.SimpleStreamingService/Chat", + request_serializer=example__service__pb2.ChatMessage.SerializeToString, + response_deserializer=example__service__pb2.ChatMessage.FromString, + _registered_method=True, + ) + + +class SimpleStreamingServiceServicer(object): + """A simple service for demonstrating gRPC streaming.""" + + def Chat(self, request, context): + """Client sends a single message, server streams back multiple messages.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_SimpleStreamingServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + "Chat": grpc.unary_stream_rpc_method_handler( + servicer.Chat, + request_deserializer=example__service__pb2.ChatMessage.FromString, + response_serializer=example__service__pb2.ChatMessage.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + "bentoml.example.grpc.SimpleStreamingService", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers( + "bentoml.example.grpc.SimpleStreamingService", rpc_method_handlers + ) + + +# This class is part of an EXPERIMENTAL API. +class SimpleStreamingService(object): + """A simple service for demonstrating gRPC streaming.""" + + @staticmethod + def Chat( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_stream( + request, + target, + "/bentoml.example.grpc.SimpleStreamingService/Chat", + example__service__pb2.ChatMessage.SerializeToString, + example__service__pb2.ChatMessage.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True, + ) diff --git a/examples/grpc_streaming/protos/example_service.proto b/examples/grpc_streaming/protos/example_service.proto new file mode 100644 index 00000000000..4129d7c354d --- /dev/null +++ b/examples/grpc_streaming/protos/example_service.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package bentoml.example.grpc; + +// A simple service for demonstrating gRPC streaming. +service SimpleStreamingService { + // Client sends a single message, server streams back multiple messages. + rpc Chat(ChatMessage) returns (stream ChatMessage) {} +} + +// Message for chat communications. 
+message ChatMessage {
+  string message_id = 1;
+  string text = 2;
+  int64 timestamp = 3;
+}
diff --git a/examples/grpc_streaming/service.py b/examples/grpc_streaming/service.py
new file mode 100644
index 00000000000..df186d965b0
--- /dev/null
+++ b/examples/grpc_streaming/service.py
@@ -0,0 +1,60 @@
+import asyncio
+import time
+
+# Import generated gRPC stubs
+from generated import example_service_pb2
+from generated import example_service_pb2_grpc
+
+import bentoml
+from bentoml.io import (
+    Text,  # For the optional REST/HTTP endpoint below; not used by gRPC directly
+)
+
+
+# Servicer implementation
+class SimpleStreamingServicerImpl(
+    example_service_pb2_grpc.SimpleStreamingServiceServicer
+):
+    async def Chat(self, request: example_service_pb2.ChatMessage, context):
+        print(
+            f"Received chat message from client: '{request.text}' (ID: {request.message_id})"
+        )
+
+        for i in range(5):  # Stream back 5 messages
+            response_text = f"Response {i + 1} to '{request.text}'"
+            timestamp = int(time.time() * 1000)  # Current timestamp in milliseconds
+            message_id = f"server-msg-{timestamp}-{i}"
+
+            print(f"Sending: '{response_text}' (ID: {message_id})")
+            yield example_service_pb2.ChatMessage(
+                message_id=message_id, text=response_text, timestamp=timestamp
+            )
+            await asyncio.sleep(0.5)  # Simulate some work or delay
+        print("Finished streaming responses for Chat.")
+
+
+# Create a BentoML service
+svc = bentoml.Service(
+    name="custom_grpc_stream_example_service",
+    runners=[],  # No runners needed for this simple example
+)
+
+# Mount the gRPC servicer. BentoML starts the gRPC server when using
+# `bentoml serve-grpc` (or `bentoml serve` with gRPC enabled) and
+# instantiates the servicer class itself.
+svc.mount_grpc_servicer(
+    servicer_cls=SimpleStreamingServicerImpl,  # The servicer class
+    add_servicer_fn=example_service_pb2_grpc.add_SimpleStreamingServiceServicer_to_server,  # Registration function
+    service_names=[
+        example_service_pb2.DESCRIPTOR.services_by_name[
+            "SimpleStreamingService"
+        ].full_name
+    ],  # Fully-qualified service names (used for, e.g., server reflection)
+)
+
+
+# Example of a simple REST endpoint (optional, just to show it can coexist)
+@svc.api(input=Text(), output=Text())
+def greet(input_text: str) -> str:
+    return f"Hello, {input_text}! This is the REST endpoint."
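+
+
+# Optional: run the servicer on a plain grpc.aio server for quick local testing,
+# bypassing `bentoml serve`. A minimal sketch; port 50051 is an assumption here,
+# not a BentoML default.
+if __name__ == "__main__":
+    import grpc
+
+    async def _debug_serve(port: int = 50051) -> None:
+        server = grpc.aio.server()
+        example_service_pb2_grpc.add_SimpleStreamingServiceServicer_to_server(
+            SimpleStreamingServicerImpl(), server
+        )
+        server.add_insecure_port(f"[::]:{port}")
+        await server.start()
+        print(f"Debug gRPC server listening on port {port}")
+        await server.wait_for_termination()
+
+    asyncio.run(_debug_serve())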
diff --git a/src/bentoml/_internal/server/grpc_app.py b/src/bentoml/_internal/server/grpc_app.py index 8b852ae9b96..7e57592ac81 100644 --- a/src/bentoml/_internal/server/grpc_app.py +++ b/src/bentoml/_internal/server/grpc_app.py @@ -85,7 +85,12 @@ def __init__( compression: grpc.Compression | None = None, protocol_version: str = LATEST_PROTOCOL_VERSION, ): - pb, _ = import_generated_stubs(protocol_version) + if protocol_version == "v1alpha1": + pb, _ = import_generated_stubs( + protocol_version, file="bentoml_service_v1alpha1.proto" + ) + else: + pb, _ = import_generated_stubs(protocol_version) self.bento_service = bento_service self.servicer = bento_service.get_grpc_servicer(protocol_version) @@ -290,7 +295,15 @@ def on_startup(self) -> list[LifecycleHook]: async def startup(self) -> None: from ...exceptions import MissingDependencyException - _, services = import_generated_stubs(self.protocol_version) + if self.protocol_version == "v1alpha1": + # For v1alpha1, the servicer returned by get_grpc_servicer is BentoServiceImpl, + # and the pb2_grpc module is bentoml_service_v1alpha1_pb2_grpc + _, services = import_generated_stubs( + self.protocol_version, file="bentoml_service_v1alpha1.proto" + ) + else: + # For other versions, it uses the standard service.proto + _, services = import_generated_stubs(self.protocol_version) # Running on_startup callback. for handler in self.on_startup: diff --git a/src/bentoml/_internal/service/service.py b/src/bentoml/_internal/service/service.py index 0d7da301420..255d22717f3 100644 --- a/src/bentoml/_internal/service/service.py +++ b/src/bentoml/_internal/service/service.py @@ -409,10 +409,18 @@ def get_grpc_servicer( Returns: A bento gRPC servicer implementation. """ - return importlib.import_module( - f".grpc.servicer.{protocol_version}", - package="bentoml._internal.server", - ).create_bento_servicer(self) + if protocol_version == "v1alpha1": + from bentoml.grpc.v1alpha1.server import BentoServiceImpl + + # This part might need adjustment depending on how BentoServiceImpl is instantiated + # and if it needs the `Service` instance (`self`) passed to it. + # For now, assuming BentoServiceImpl() is self-contained or uses context. + return BentoServiceImpl() # type: ignore # TODO: fix type hint + else: + return importlib.import_module( + f".grpc.servicer.{protocol_version}", + package="bentoml._internal.server", + ).create_bento_servicer(self) @property def grpc_servicer(self): diff --git a/src/bentoml/grpc/v1alpha1/bentoml_service_v1alpha1.proto b/src/bentoml/grpc/v1alpha1/bentoml_service_v1alpha1.proto new file mode 100644 index 00000000000..dffa55e15fb --- /dev/null +++ b/src/bentoml/grpc/v1alpha1/bentoml_service_v1alpha1.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +package bentoml.grpc.v1alpha1; + +// The BentoService service definition. +service BentoService { + // A streaming RPC method that accepts a Request message and returns a stream of Response messages. + rpc CallStream(Request) returns (stream Response) {} +} + +// The request message containing the input data. +message Request { + string data = 1; +} + +// The response message containing the output data. +message Response { + string data = 1; +} diff --git a/src/bentoml/grpc/v1alpha1/bentoml_service_v1alpha1_pb2.py b/src/bentoml/grpc/v1alpha1/bentoml_service_v1alpha1_pb2.py new file mode 100644 index 00000000000..3e38775c86a --- /dev/null +++ b/src/bentoml/grpc/v1alpha1/bentoml_service_v1alpha1_pb2.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. 
DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: bentoml_service_v1alpha1.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'bentoml_service_v1alpha1.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1e\x62\x65ntoml_service_v1alpha1.proto\x12\x15\x62\x65ntoml.grpc.v1alpha1\"\x17\n\x07Request\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\t\"\x18\n\x08Response\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\t2a\n\x0c\x42\x65ntoService\x12Q\n\nCallStream\x12\x1e.bentoml.grpc.v1alpha1.Request\x1a\x1f.bentoml.grpc.v1alpha1.Response\"\x00\x30\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'bentoml_service_v1alpha1_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_REQUEST']._serialized_start=57 + _globals['_REQUEST']._serialized_end=80 + _globals['_RESPONSE']._serialized_start=82 + _globals['_RESPONSE']._serialized_end=106 + _globals['_BENTOSERVICE']._serialized_start=108 + _globals['_BENTOSERVICE']._serialized_end=205 +# @@protoc_insertion_point(module_scope) diff --git a/src/bentoml/grpc/v1alpha1/bentoml_service_v1alpha1_pb2_grpc.py b/src/bentoml/grpc/v1alpha1/bentoml_service_v1alpha1_pb2_grpc.py new file mode 100644 index 00000000000..356ef3382d7 --- /dev/null +++ b/src/bentoml/grpc/v1alpha1/bentoml_service_v1alpha1_pb2_grpc.py @@ -0,0 +1,101 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +import bentoml_service_v1alpha1_pb2 as bentoml__service__v1alpha1__pb2 + +GRPC_GENERATED_VERSION = '1.71.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in bentoml_service_v1alpha1_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class BentoServiceStub(object): + """The BentoService service definition. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.CallStream = channel.unary_stream( + '/bentoml.grpc.v1alpha1.BentoService/CallStream', + request_serializer=bentoml__service__v1alpha1__pb2.Request.SerializeToString, + response_deserializer=bentoml__service__v1alpha1__pb2.Response.FromString, + _registered_method=True) + + +class BentoServiceServicer(object): + """The BentoService service definition. 
+ """ + + def CallStream(self, request, context): + """A streaming RPC method that accepts a Request message and returns a stream of Response messages. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_BentoServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'CallStream': grpc.unary_stream_rpc_method_handler( + servicer.CallStream, + request_deserializer=bentoml__service__v1alpha1__pb2.Request.FromString, + response_serializer=bentoml__service__v1alpha1__pb2.Response.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'bentoml.grpc.v1alpha1.BentoService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('bentoml.grpc.v1alpha1.BentoService', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class BentoService(object): + """The BentoService service definition. + """ + + @staticmethod + def CallStream(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/bentoml.grpc.v1alpha1.BentoService/CallStream', + bentoml__service__v1alpha1__pb2.Request.SerializeToString, + bentoml__service__v1alpha1__pb2.Response.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/src/bentoml/grpc/v1alpha1/client.py b/src/bentoml/grpc/v1alpha1/client.py new file mode 100644 index 00000000000..181250720bd --- /dev/null +++ b/src/bentoml/grpc/v1alpha1/client.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import grpc +import asyncio +from typing import AsyncIterator, Any + +from . import bentoml_service_v1alpha1_pb2 as pb +from . import bentoml_service_v1alpha1_pb2_grpc as services + +class BentoMlGrpcClient: + """ + A gRPC client for BentoML v1alpha1 BentoService. + """ + + def __init__(self, host: str, port: int | None = None, channel: grpc.aio.Channel | None = None): + """ + Initialize the BentoML gRPC client. + + Args: + host: The host address of the gRPC server. + port: The port of the gRPC server. Required if 'channel' is not provided. + channel: An existing grpc.aio.Channel. If provided, 'host' and 'port' are ignored. + """ + if channel: + self._channel = channel + else: + if port is None: + raise ValueError("Either 'channel' or 'port' must be provided.") + self._address = f"{host}:{port}" + # TODO(PROTOCOL_BUFFERS_OVER_CHANNEL_SIZE_LIMIT): configure max message length + # See https://github.com/grpc/grpc/blob/master/include/grpc/impl/codegen/grpc_types.h#L320 + # options.append(("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH)) + # options.append(("grpc.max_send_message_length", MAX_MESSAGE_LENGTH)) + self._channel = grpc.aio.insecure_channel(self._address) + + self._stub = services.BentoServiceStub(self._channel) + + async def health_check(self) -> Any: + """ + Performs a health check on the gRPC server. + This typically requires grpc_health_checking to be installed and configured on the server. + For now, this is a placeholder or would need to call a specific health check RPC if available. + The standard health check service might not be directly exposed via BentoServiceStub. 
+ """ + # This is a simplified check; real health check would use HealthStub + # from grpc_health.v1.health_pb2_grpc import HealthStub + # health_stub = HealthStub(self._channel) + # request = health_pb2.HealthCheckRequest(service="bentoml.grpc.v1alpha1.BentoService") + # return await health_stub.Check(request) + print("Health check: Channel connectivity check.") + try: + # Try to connect and see if the channel is ready + # This is a very basic check, not a formal gRPC health check + await self._channel.channel_ready() + return "Channel is ready." + except grpc.aio.AioRpcError as e: + return f"Channel is not ready: {e.code()}" + + + async def call_stream(self, data: str) -> AsyncIterator[pb.Response]: + """ + Calls the CallStream RPC method on the server. + + Args: + data: The string data to send in the request. + + Returns: + An async iterator yielding Response messages from the server. + """ + request = pb.Request(data=data) + try: + async for response in self._stub.CallStream(request): + yield response + except grpc.aio.AioRpcError as e: + # Handle potential errors, e.g., server unavailable, etc. + print(f"gRPC call failed: {e.code()} - {e.details()}") + # Depending on desired error handling, you might raise an exception here + # or yield some error indication. For now, just printing and stopping iteration. + return + + async def close(self): + """ + Closes the gRPC channel. + """ + if self._channel: + await self._channel.close() + +async def main(): + """ + Example usage of the BentoMlGrpcClient. + Assumes a gRPC server is running on localhost:50051. + """ + # Example: Start the server from src/bentoml/grpc/v1alpha1/server.py in a separate terminal + # python src/bentoml/grpc/v1alpha1/server.py + + client = BentoMlGrpcClient(host="localhost", port=50051) + + print("--- Health Check ---") + health_status = await client.health_check() + print(f"Health status: {health_status}") + print("\n--- Calling CallStream ---") + input_data = "Hello from Python client" + print(f"Sending data: '{input_data}'") + + try: + i = 0 + async for response in client.call_stream(data=input_data): + print(f"Response {i}: {response.data}") + i += 1 + finally: + await client.close() + print("\nClient closed.") + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/bentoml/grpc/v1alpha1/examples/grpc_streaming/generated/__init__.py b/src/bentoml/grpc/v1alpha1/examples/grpc_streaming/generated/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/bentoml/grpc/v1alpha1/server.py b/src/bentoml/grpc/v1alpha1/server.py new file mode 100644 index 00000000000..ccf081e769c --- /dev/null +++ b/src/bentoml/grpc/v1alpha1/server.py @@ -0,0 +1,28 @@ +import asyncio +import grpc +from . import bentoml_service_v1alpha1_pb2 as pb +from . import bentoml_service_v1alpha1_pb2_grpc as services + +class BentoServiceImpl(services.BentoServiceServicer): + async def CallStream(self, request: pb.Request, context: grpc.aio.ServicerContext) -> pb.Response: + """ + A streaming RPC method that accepts a Request message and returns a stream of Response messages. 
+ """ + print(f"Received request: {request.data}") + for i in range(3): + await asyncio.sleep(0.5) + yield pb.Response(data=f"Response {i} for request: {request.data}") + +async def start_grpc_server(port: int, service_instances): + server = grpc.aio.server() + services.add_BentoServiceServicer_to_server(BentoServiceImpl(), server) + server.add_insecure_port(f"[::]:{port}") + await server.start() + print(f"gRPC server started on port {port}") + await server.wait_for_termination() + +if __name__ == "__main__": + # This is for testing purposes only + async def main(): + await start_grpc_server(50051, []) + asyncio.run(main()) diff --git a/src/bentoml/grpc/v1alpha1/service.proto b/src/bentoml/grpc/v1alpha1/service.proto deleted file mode 100644 index 3fe8feeccf2..00000000000 --- a/src/bentoml/grpc/v1alpha1/service.proto +++ /dev/null @@ -1,282 +0,0 @@ -syntax = "proto3"; - -package bentoml.grpc.v1alpha1; - -import "google/protobuf/struct.proto"; -import "google/protobuf/wrappers.proto"; - -// cc_enable_arenas pre-allocate memory for given message to improve speed. (C++ only) -option cc_enable_arenas = true; -option go_package = "github.com/bentoml/bentoml/grpc/v1alpha1;service"; -option java_multiple_files = true; -option java_outer_classname = "ServiceProto"; -option java_package = "com.bentoml.grpc.v1alpha1"; -option objc_class_prefix = "SVC"; -option py_generic_services = true; - -// a gRPC BentoServer. -service BentoService { - // Call handles methodcaller of given API entrypoint. - rpc Call(Request) returns (Response) {} -} - -// Request message for incoming Call. -message Request { - // api_name defines the API entrypoint to call. - // api_name is the name of the function defined in bentoml.Service. - // Example: - // - // @svc.api(input=NumpyNdarray(), output=File()) - // def predict(input: NDArray[float]) -> bytes: - // ... - // - // api_name is "predict" in this case. - string api_name = 1; - - oneof content { - // NDArray represents a n-dimensional array of arbitrary type. - NDArray ndarray = 3; - - // DataFrame represents any tabular data type. We are using - // DataFrame as a trivial representation for tabular type. - DataFrame dataframe = 5; - - // Series portrays a series of values. This can be used for - // representing Series types in tabular data. - Series series = 6; - - // File represents for any arbitrary file type. This can be - // plaintext, image, video, audio, etc. - File file = 7; - - // Text represents a string inputs. - google.protobuf.StringValue text = 8; - - // JSON is represented by using google.protobuf.Value. - // see https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/struct.proto - google.protobuf.Value json = 9; - - // Multipart represents a multipart message. - // It comprises of a mapping from given type name to a subset of aforementioned types. - Multipart multipart = 10; - - // serialized_bytes is for data serialized in BentoML's internal serialization format. - bytes serialized_bytes = 2; - } - - // Tensor is similiar to ndarray but with a name - // We are reserving it for now for future use. - // repeated Tensor tensors = 4; - reserved 4, 11 to 13; -} - -// Request message for incoming Call. -message Response { - oneof content { - // NDArray represents a n-dimensional array of arbitrary type. - NDArray ndarray = 1; - - // DataFrame represents any tabular data type. We are using - // DataFrame as a trivial representation for tabular type. - DataFrame dataframe = 3; - - // Series portrays a series of values. 
This can be used for - // representing Series types in tabular data. - Series series = 5; - - // File represents for any arbitrary file type. This can be - // plaintext, image, video, audio, etc. - File file = 6; - - // Text represents a string inputs. - google.protobuf.StringValue text = 7; - - // JSON is represented by using google.protobuf.Value. - // see https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/struct.proto - google.protobuf.Value json = 8; - - // Multipart represents a multipart message. - // It comprises of a mapping from given type name to a subset of aforementioned types. - Multipart multipart = 9; - - // serialized_bytes is for data serialized in BentoML's internal serialization format. - bytes serialized_bytes = 2; - } - // Tensor is similiar to ndarray but with a name - // We are reserving it for now for future use. - // repeated Tensor tensors = 4; - reserved 4, 10 to 13; -} - -// Part represents possible value types for multipart message. -// These are the same as the types in Request message. -message Part { - oneof representation { - // NDArray represents a n-dimensional array of arbitrary type. - NDArray ndarray = 1; - - // DataFrame represents any tabular data type. We are using - // DataFrame as a trivial representation for tabular type. - DataFrame dataframe = 3; - - // Series portrays a series of values. This can be used for - // representing Series types in tabular data. - Series series = 5; - - // File represents for any arbitrary file type. This can be - // plaintext, image, video, audio, etc. - File file = 6; - - // Text represents a string inputs. - google.protobuf.StringValue text = 7; - - // JSON is represented by using google.protobuf.Value. - // see https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/struct.proto - google.protobuf.Value json = 8; - - // serialized_bytes is for data serialized in BentoML's internal serialization format. - bytes serialized_bytes = 4; - } - - // Tensor is similiar to ndarray but with a name - // We are reserving it for now for future use. - // Tensor tensors = 4; - reserved 2, 9 to 13; -} - -// Multipart represents a multipart message. -// It comprises of a mapping from given type name to a subset of aforementioned types. -message Multipart { - map fields = 1; -} - -// File represents for any arbitrary file type. This can be -// plaintext, image, video, audio, etc. -message File { - // FileType represents possible file type to be handled by BentoML. - // Currently, we only support plaintext (Text()), image (Image()), and file (File()). - // TODO: support audio and video streaming file types. - enum FileType { - FILE_TYPE_UNSPECIFIED = 0; - - // file types - FILE_TYPE_CSV = 1; - FILE_TYPE_PLAINTEXT = 2; - FILE_TYPE_JSON = 3; - FILE_TYPE_BYTES = 4; - FILE_TYPE_PDF = 5; - - // image types - FILE_TYPE_PNG = 6; - FILE_TYPE_JPEG = 7; - FILE_TYPE_GIF = 8; - FILE_TYPE_BMP = 9; - FILE_TYPE_TIFF = 10; - FILE_TYPE_WEBP = 11; - FILE_TYPE_SVG = 12; - } - - // optional type of file, let it be csv, text, parquet, etc. - optional FileType kind = 1; - - // contents of file as bytes. - bytes content = 2; -} - -// DataFrame represents any tabular data type. We are using -// DataFrame as a trivial representation for tabular type. -// This message carries given implementation of tabular data based on given orientation. -// TODO: support index, records, etc. -message DataFrame { - // columns name - repeated string column_names = 1; - - // columns orient. 
- // { column ↠ { index ↠ value } } - repeated Series columns = 2; -} - -// Series portrays a series of values. This can be used for -// representing Series types in tabular data. -message Series { - // A bool parameter value - repeated bool bool_values = 1 [packed = true]; - - // A float parameter value - repeated float float_values = 2 [packed = true]; - - // A int32 parameter value - repeated int32 int32_values = 3 [packed = true]; - - // A int64 parameter value - repeated int64 int64_values = 6 [packed = true]; - - // A string parameter value - repeated string string_values = 5; - - // represents a double parameter value. - repeated double double_values = 4 [packed = true]; -} - -// NDArray represents a n-dimensional array of arbitrary type. -message NDArray { - // Represents data type of a given array. - enum DType { - // Represents a None type. - DTYPE_UNSPECIFIED = 0; - - // Represents an float type. - DTYPE_FLOAT = 1; - - // Represents an double type. - DTYPE_DOUBLE = 2; - - // Represents a bool type. - DTYPE_BOOL = 3; - - // Represents an int32 type. - DTYPE_INT32 = 4; - - // Represents an int64 type. - DTYPE_INT64 = 5; - - // Represents a uint32 type. - DTYPE_UINT32 = 6; - - // Represents a uint64 type. - DTYPE_UINT64 = 7; - - // Represents a string type. - DTYPE_STRING = 8; - } - - // DTYPE is the data type of given array - DType dtype = 1; - - // shape is the shape of given array. - repeated int32 shape = 2; - - // represents a string parameter value. - repeated string string_values = 5; - - // represents a float parameter value. - repeated float float_values = 3 [packed = true]; - - // represents a double parameter value. - repeated double double_values = 4 [packed = true]; - - // represents a bool parameter value. - repeated bool bool_values = 6 [packed = true]; - - // represents a int32 parameter value. - repeated int32 int32_values = 7 [packed = true]; - - // represents a int64 parameter value. - repeated int64 int64_values = 8 [packed = true]; - - // represents a uint32 parameter value. - repeated uint32 uint32_values = 9 [packed = true]; - - // represents a uint64 parameter value. - repeated uint64 uint64_values = 10 [packed = true]; -} diff --git a/src/bentoml_cli/call_grpc_stream.py b/src/bentoml_cli/call_grpc_stream.py new file mode 100644 index 00000000000..a07d204e8d4 --- /dev/null +++ b/src/bentoml_cli/call_grpc_stream.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import asyncio +import sys + +import click + +from bentoml.grpc.v1alpha1.client import BentoMlGrpcClient + + +@click.command(name="call-grpc-stream") +@click.option( + "--host", + type=click.STRING, + default="localhost", + help="The host address of the gRPC server.", + show_default=True, +) +@click.option( + "--port", + type=click.INT, + default=50051, + help="The port of the gRPC server.", + show_default=True, +) +@click.option( + "--data", + type=click.STRING, + required=True, + help="The string data to send to the CallStream method.", +) +def call_grpc_stream_command(host: str, port: int, data: str) -> None: + """ + Call the CallStream gRPC method of a BentoML v1alpha1 service. + This command connects to a gRPC server and invokes the CallStream method, + printing each streamed response to standard output. 
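+
+    Example (using this command's own defaults):
+
+        bentoml call-grpc-stream --host localhost --port 50051 --data "hello"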
+ """ + + async def _main(): + client = BentoMlGrpcClient(host=host, port=port) + try: + print(f"Connecting to gRPC server at {host}:{port}...") + print(f"Sending data: '{data}' to CallStream...") + print("--- Streamed Responses ---") + async for response in client.call_stream(data=data): + # Assuming response.data is the field to print. + # Adjust if your Response message structure is different. + print(response.data) + except Exception as e: + print(f"An error occurred: {e}", file=sys.stderr) + finally: + if client: + await client.close() + print("------------------------") + print("Connection closed.") + + asyncio.run(_main()) + + +if __name__ == "__main__": + # This allows running the command directly for testing if needed, + # though it's primarily designed to be invoked via the bentoml CLI group. + call_grpc_stream_command() diff --git a/src/bentoml_cli/cli.py b/src/bentoml_cli/cli.py index a6e2f1de821..e1d63b86348 100644 --- a/src/bentoml_cli/cli.py +++ b/src/bentoml_cli/cli.py @@ -8,6 +8,7 @@ def create_bentoml_cli() -> click.Command: from bentoml._internal.configuration import BENTOML_VERSION from bentoml._internal.context import server_context from bentoml_cli.bentos import bento_command + from bentoml_cli.call_grpc_stream import call_grpc_stream_command # New import from bentoml_cli.cloud import cloud_command from bentoml_cli.containerize import containerize_command from bentoml_cli.deployment import codespace @@ -48,6 +49,7 @@ def bentoml_cli(): bentoml_cli.add_command(codespace) bentoml_cli.add_command(deployment_command) bentoml_cli.add_command(secret_command) + bentoml_cli.add_command(call_grpc_stream_command) # New command # Load commands from extensions for ep in get_entry_points("bentoml.commands"): bentoml_cli.add_command(ep.load()) diff --git a/tests/benchmark/benchmark_streaming.py b/tests/benchmark/benchmark_streaming.py new file mode 100644 index 00000000000..8696d5b5a12 --- /dev/null +++ b/tests/benchmark/benchmark_streaming.py @@ -0,0 +1,390 @@ +import argparse +import asyncio +import statistics +import time +from typing import Any +from typing import Dict + +import grpc # For status codes +import httpx + +# Assuming the client is in the BentoML package and accessible +# Adjust the import path if necessary based on your project structure +from bentoml.grpc.v1alpha1.client import BentoMlGrpcClient + +# --- Constants --- +DEFAULT_ITERATIONS = 10 +DEFAULT_GRPC_PORT = 50051 +DEFAULT_REST_PORT = 3000 # Commonly used by BentoML REST servers +DEFAULT_HOST = "localhost" + + +# --- Helper Functions --- +def generate_payload(size: int) -> str: + """Generates a string payload of a given size.""" + return "a" * size + + +async def safe_close_grpc_client(client: BentoMlGrpcClient | None): + if client: + await client.close() + + +# --- gRPC Benchmark --- +async def benchmark_grpc_stream( + host: str, port: int, payload_size: int, stream_length: int, iterations: int +) -> Dict[str, Any]: + """Benchmarks gRPC CallStream.""" + client = None + timings = [] + first_response_latencies = [] + total_bytes_streamed_list = [] + successful_iterations = 0 + + payload = generate_payload(payload_size) + print( + f"gRPC: Payload size: {payload_size} bytes, Stream length: {stream_length} messages, Iterations: {iterations}" + ) + + try: + client = BentoMlGrpcClient(host=host, port=port) + # Warm-up call (optional, but good for stable measurements) + async for _ in client.call_stream(data="warmup"): + pass + + for i in range(iterations): + start_time = time.perf_counter() + first_response_time = 
None + + message_count = 0 + current_iteration_bytes = 0 + try: + async for response in client.call_stream(data=payload): + if first_response_time is None: + first_response_time = time.perf_counter() + message_count += 1 + current_iteration_bytes += len(response.data.encode("utf-8")) + if ( + message_count >= stream_length + ): # Ensure we don't stream indefinitely if server sends more + break + end_time = time.perf_counter() + + if message_count == stream_length: + timings.append(end_time - start_time) + if first_response_time: + first_response_latencies.append( + first_response_time - start_time + ) + total_bytes_streamed_list.append(current_iteration_bytes) + successful_iterations += 1 + else: + print( + f"gRPC Iteration {i + 1}: Failed - Expected {stream_length} messages, got {message_count}" + ) + + except grpc.aio.AioRpcError as e: + print( + f"gRPC Iteration {i + 1}: Failed with gRPC error - {e.code()}: {e.details()}" + ) + except Exception as e: + print(f"gRPC Iteration {i + 1}: Failed with error - {e}") + await asyncio.sleep(0.01) # Small delay between iterations + + finally: + await safe_close_grpc_client(client) + + if not timings: + return {"error": "No successful gRPC iterations."} + + avg_time = statistics.mean(timings) + std_dev_time = statistics.stdev(timings) if len(timings) > 1 else 0 + avg_first_response_latency = ( + statistics.mean(first_response_latencies) if first_response_latencies else -1 + ) + + total_bytes_per_iteration = statistics.mean(total_bytes_streamed_list) + throughput_mps = ( + successful_iterations / sum(timings) if sum(timings) > 0 else 0 + ) # Total successful messages / total time for successful + throughput_bps = ( + sum(total_bytes_streamed_list) / sum(timings) if sum(timings) > 0 else 0 + ) + + return { + "successful_iterations": successful_iterations, + "avg_stream_time_s": avg_time, + "std_dev_stream_time_s": std_dev_time, + "avg_first_response_latency_s": avg_first_response_latency, + "throughput_streams_per_s": 1 / avg_time if avg_time > 0 else 0, + "throughput_msgs_per_s": throughput_mps + * stream_length, # average messages per second across successful streams + "throughput_bytes_per_s": throughput_bps, + "avg_bytes_per_stream": total_bytes_per_iteration, + } + + +# --- REST (HTTP/1.1 Streaming) Benchmark --- +async def benchmark_rest_stream( + host: str, + port: int, + payload_size: int, + stream_length: int, + iterations: int, + endpoint: str = "/stream", # Assuming a /stream endpoint for REST +) -> Dict[str, Any]: + """Benchmarks REST streaming (e.g., line-delimited JSON or SSE).""" + timings = [] + first_response_latencies = [] + total_bytes_streamed_list = [] + successful_iterations = 0 + + payload_str = generate_payload(payload_size) + # For REST, we'd typically send JSON. The server would need to handle this. + # The 'data' field matches the gRPC Request message. + # We also need to inform the server about the desired stream length, + # as HTTP streaming doesn't inherently have a message count like gRPC stream definition. + # This could be a header or part of the JSON payload. + # For this example, let's assume the server knows to send `stream_length` messages + # or we pass it as a query parameter or in the payload. + # Let's add it to the JSON payload for simplicity here. 
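+    # Assumed contract for the REST endpoint (hypothetical; not part of BentoML):
+    # POST <endpoint> with {"data": ..., "stream_length": N} streams back exactly
+    # N newline-delimited text lines as the response body.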
+ json_payload_with_length = {"data": payload_str, "stream_length": stream_length} + + print( + f"REST: Payload size: {payload_size} bytes, Stream length: {stream_length} messages, Iterations: {iterations}" + ) + + async with httpx.AsyncClient(base_url=f"http://{host}:{port}") as client: + # Warm-up call + try: + async with client.stream( + "POST", endpoint, json={"data": "warmup", "stream_length": 1} + ) as response: + async for _ in response.aiter_lines(): + pass + except httpx.RequestError as e: + print( + f"REST Warmup failed: {e}. Ensure REST server is running and endpoint '{endpoint}' exists." + ) + # return {"error": f"Warmup failed: {e}"} # Or decide to continue + + for i in range(iterations): + start_time = time.perf_counter() + first_response_time = None + message_count = 0 + current_iteration_bytes = 0 + try: + async with client.stream( + "POST", endpoint, json=json_payload_with_length, timeout=30.0 + ) as response: + if response.status_code != 200: + print( + f"REST Iteration {i + 1}: Failed - Status {response.status_code}, {await response.aread()}" + ) + continue + + # Assuming server sends line-delimited text, each line is a "message" + async for line in response.aiter_lines(): + if first_response_time is None: + first_response_time = time.perf_counter() + message_count += 1 + current_iteration_bytes += len(line.encode("utf-8")) + # No break here, assume server sends exactly stream_length messages based on input + end_time = time.perf_counter() + + if ( + message_count == stream_length + ): # Validate if server sent expected number of messages + timings.append(end_time - start_time) + if first_response_time: + first_response_latencies.append( + first_response_time - start_time + ) + total_bytes_streamed_list.append(current_iteration_bytes) + successful_iterations += 1 + else: + print( + f"REST Iteration {i + 1}: Failed - Expected {stream_length} messages, got {message_count}" + ) + + except httpx.RequestError as e: + print(f"REST Iteration {i + 1}: Request failed - {e}") + except Exception as e: + print(f"REST Iteration {i + 1}: Failed with error - {e}") + await asyncio.sleep(0.01) # Small delay + + if not timings: + return {"error": "No successful REST iterations."} + + avg_time = statistics.mean(timings) + std_dev_time = statistics.stdev(timings) if len(timings) > 1 else 0 + avg_first_response_latency = ( + statistics.mean(first_response_latencies) if first_response_latencies else -1 + ) + + total_bytes_per_iteration = statistics.mean(total_bytes_streamed_list) + throughput_mps = successful_iterations / sum(timings) if sum(timings) > 0 else 0 + throughput_bps = ( + sum(total_bytes_streamed_list) / sum(timings) if sum(timings) > 0 else 0 + ) + + return { + "successful_iterations": successful_iterations, + "avg_stream_time_s": avg_time, + "std_dev_stream_time_s": std_dev_time, + "avg_first_response_latency_s": avg_first_response_latency, + "throughput_streams_per_s": 1 / avg_time if avg_time > 0 else 0, + "throughput_msgs_per_s": throughput_mps * stream_length, + "throughput_bytes_per_s": throughput_bps, + "avg_bytes_per_stream": total_bytes_per_iteration, + } + + +# --- Main Execution --- +def display_results( + scenario: Dict[str, Any], grpc_results: Dict[str, Any], rest_results: Dict[str, Any] +): + print("\n--- Benchmark Scenario ---") + print(f" Payload Size: {scenario['payload_size']} bytes") + print(f" Stream Length: {scenario['stream_length']} messages") + print(f" Iterations: {scenario['iterations']}") + + print("\n--- gRPC Results ---") + if "error" in grpc_results: 
+ print(f" Error: {grpc_results['error']}") + else: + print( + f" Successful Iterations: {grpc_results['successful_iterations']}/{scenario['iterations']}" + ) + print( + f" Avg. Stream Time: {grpc_results['avg_stream_time_s']:.4f} s (StdDev: {grpc_results['std_dev_stream_time_s']:.4f} s)" + ) + print( + f" Avg. First Response Latency: {grpc_results['avg_first_response_latency_s']:.4f} s" + ) + print( + f" Throughput (Streams/s): {grpc_results['throughput_streams_per_s']:.2f}" + ) + print(f" Throughput (Msgs/s): {grpc_results['throughput_msgs_per_s']:.2f}") + print(f" Throughput (Bytes/s): {grpc_results['throughput_bytes_per_s']:.2f}") + print(f" Avg. Bytes per Stream: {grpc_results['avg_bytes_per_stream']:.0f}") + + print("\n--- REST Results ---") + if "error" in rest_results: + print(f" Error: {rest_results['error']}") + else: + print( + f" Successful Iterations: {rest_results['successful_iterations']}/{scenario['iterations']}" + ) + print( + f" Avg. Stream Time: {rest_results['avg_stream_time_s']:.4f} s (StdDev: {rest_results['std_dev_stream_time_s']:.4f} s)" + ) + print( + f" Avg. First Response Latency: {rest_results['avg_first_response_latency_s']:.4f} s" + ) + print( + f" Throughput (Streams/s): {rest_results['throughput_streams_per_s']:.2f}" + ) + print(f" Throughput (Msgs/s): {rest_results['throughput_msgs_per_s']:.2f}") + print(f" Throughput (Bytes/s): {rest_results['throughput_bytes_per_s']:.2f}") + print(f" Avg. Bytes per Stream: {rest_results['avg_bytes_per_stream']:.0f}") + print("=" * 40) + + +async def main(): + parser = argparse.ArgumentParser(description="Benchmark gRPC vs REST streaming.") + parser.add_argument( + "--grpc_host", type=str, default=DEFAULT_HOST, help="gRPC server host." + ) + parser.add_argument( + "--grpc_port", type=int, default=DEFAULT_GRPC_PORT, help="gRPC server port." + ) + parser.add_argument( + "--rest_host", type=str, default=DEFAULT_HOST, help="REST server host." + ) + parser.add_argument( + "--rest_port", type=int, default=DEFAULT_REST_PORT, help="REST server port." + ) + parser.add_argument( + "--rest_endpoint", type=str, default="/stream", help="REST streaming endpoint." + ) + parser.add_argument( + "--iterations", + type=int, + default=DEFAULT_ITERATIONS, + help="Number of iterations per scenario.", + ) + parser.add_argument( + "--payload_sizes", + type=str, + default="10,1000,100000", + help="Comma-separated payload sizes in bytes.", + ) + parser.add_argument( + "--stream_lengths", + type=str, + default="10,100,500", + help="Comma-separated stream lengths (number of messages).", + ) + parser.add_argument( + "--skip_grpc", action="store_true", help="Skip gRPC benchmarks." + ) + parser.add_argument( + "--skip_rest", action="store_true", help="Skip REST benchmarks." + ) + + args = parser.parse_args() + + payload_sizes = [int(s) for s in args.payload_sizes.split(",")] + stream_lengths = [int(s) for s in args.stream_lengths.split(",")] + + print("Starting benchmark...") + print(f"Iterations per scenario: {args.iterations}") + print(f"Payload sizes: {payload_sizes}") + print(f"Stream lengths: {stream_lengths}") + print(f"gRPC Server: {args.grpc_host}:{args.grpc_port}") + print(f"REST Server: {args.rest_host}:{args.rest_port}{args.rest_endpoint}") + print( + "Important: Ensure both gRPC and REST servers are running and configured for streaming." 
+ ) + print( + "The REST server must accept POST requests at the specified endpoint, expecting a JSON payload like" + ) + print( + "{'data': 'your_payload', 'stream_length': number_of_messages} and stream back line-delimited text responses." + ) + + scenarios = [ + {"payload_size": ps, "stream_length": sl, "iterations": args.iterations} + for ps in payload_sizes + for sl in stream_lengths + ] + + for scenario in scenarios: + grpc_results = {"error": "Skipped"} + rest_results = {"error": "Skipped"} + + if not args.skip_grpc: + grpc_results = await benchmark_grpc_stream( + args.grpc_host, + args.grpc_port, + scenario["payload_size"], + scenario["stream_length"], + scenario["iterations"], + ) + + if not args.skip_rest: + rest_results = await benchmark_rest_stream( + args.rest_host, + args.rest_port, + scenario["payload_size"], + scenario["stream_length"], + scenario["iterations"], + endpoint=args.rest_endpoint, + ) + + display_results(scenario, grpc_results, rest_results) + + +if __name__ == "__main__": + asyncio.run(main())