Skip to content

Commit f589168

Browse files
cbornetmdrxy
andauthored
refactor(core): use pytest style in TestGetBufferString (#32786)
Co-authored-by: Mason Daugherty <mason@langchain.dev>
1 parent 5840dad commit f589168

File tree

1 file changed

+25
-35
lines changed

1 file changed

+25
-35
lines changed

libs/core/tests/unit_tests/test_messages.py

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import unittest
21
import uuid
32
from typing import Optional, Union
43

@@ -335,54 +334,45 @@ def test_ai_message_chunks() -> None:
335334
)
336335

337336

338-
class TestGetBufferString(unittest.TestCase):
339-
def setUp(self) -> None:
340-
self.human_msg = HumanMessage(content="human")
341-
self.ai_msg = AIMessage(content="ai")
342-
self.sys_msg = SystemMessage(content="system")
343-
self.func_msg = FunctionMessage(name="func", content="function")
344-
self.tool_msg = ToolMessage(tool_call_id="tool_id", content="tool")
345-
self.chat_msg = ChatMessage(role="Chat", content="chat")
346-
self.tool_calls_msg = AIMessage(content="tool")
337+
class TestGetBufferString:
338+
_HUMAN_MSG = HumanMessage(content="human")
339+
_AI_MSG = AIMessage(content="ai")
347340

348341
def test_empty_input(self) -> None:
349342
assert get_buffer_string([]) == ""
350343

351344
def test_valid_single_message(self) -> None:
352-
expected_output = f"Human: {self.human_msg.content}"
353-
assert get_buffer_string([self.human_msg]) == expected_output
345+
expected_output = "Human: human"
346+
assert get_buffer_string([self._HUMAN_MSG]) == expected_output
354347

355348
def test_custom_human_prefix(self) -> None:
356-
prefix = "H"
357-
expected_output = f"{prefix}: {self.human_msg.content}"
358-
assert get_buffer_string([self.human_msg], human_prefix="H") == expected_output
349+
expected_output = "H: human"
350+
assert get_buffer_string([self._HUMAN_MSG], human_prefix="H") == expected_output
359351

360352
def test_custom_ai_prefix(self) -> None:
361-
prefix = "A"
362-
expected_output = f"{prefix}: {self.ai_msg.content}"
363-
assert get_buffer_string([self.ai_msg], ai_prefix="A") == expected_output
353+
expected_output = "A: ai"
354+
assert get_buffer_string([self._AI_MSG], ai_prefix="A") == expected_output
364355

365356
def test_multiple_msg(self) -> None:
366357
msgs = [
367-
self.human_msg,
368-
self.ai_msg,
369-
self.sys_msg,
370-
self.func_msg,
371-
self.tool_msg,
372-
self.chat_msg,
373-
self.tool_calls_msg,
358+
self._HUMAN_MSG,
359+
self._AI_MSG,
360+
SystemMessage(content="system"),
361+
FunctionMessage(name="func", content="function"),
362+
ToolMessage(tool_call_id="tool_id", content="tool"),
363+
ChatMessage(role="Chat", content="chat"),
364+
AIMessage(content="tool"),
374365
]
375-
expected_output = "\n".join( # noqa: FLY002
376-
[
377-
"Human: human",
378-
"AI: ai",
379-
"System: system",
380-
"Function: function",
381-
"Tool: tool",
382-
"Chat: chat",
383-
"AI: tool",
384-
]
366+
expected_output = (
367+
"Human: human\n"
368+
"AI: ai\n"
369+
"System: system\n"
370+
"Function: function\n"
371+
"Tool: tool\n"
372+
"Chat: chat\n"
373+
"AI: tool"
385374
)
375+
386376
assert get_buffer_string(msgs) == expected_output
387377

388378

0 commit comments

Comments
 (0)