diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py index 7999d8a32161a..bac037fbce52c 100644 --- a/libs/core/langchain_core/_api/deprecation.py +++ b/libs/core/langchain_core/_api/deprecation.py @@ -361,7 +361,8 @@ def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # Modify the docstring to include a deprecation notice. if ( _alternative - and _alternative.split(".")[-1].lower() == _alternative.split(".")[-1] + and _alternative.rsplit(".", maxsplit=1)[-1].lower() + == _alternative.rsplit(".", maxsplit=1)[-1] ): _alternative = f":meth:`~{_alternative}`" elif _alternative: @@ -369,8 +370,8 @@ def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: if ( _alternative_import - and _alternative_import.split(".")[-1].lower() - == _alternative_import.split(".")[-1] + and _alternative_import.rsplit(".", maxsplit=1)[-1].lower() + == _alternative_import.rsplit(".", maxsplit=1)[-1] ): _alternative_import = f":meth:`~{_alternative_import}`" elif _alternative_import: @@ -474,7 +475,7 @@ def warn_deprecated( if not message: message = "" package_ = ( - package or name.split(".")[0].replace("_", "-") + package or name.split(".", maxsplit=1)[0].replace("_", "-") if "." in name else "LangChain" ) @@ -493,7 +494,7 @@ def warn_deprecated( message += f" and will be removed {removal}" if alternative_import: - alt_package = alternative_import.split(".")[0].replace("_", "-") + alt_package = alternative_import.split(".", maxsplit=1)[0].replace("_", "-") if alt_package == package_: message += f". Use {alternative_import} instead." else: diff --git a/libs/core/langchain_core/embeddings/fake.py b/libs/core/langchain_core/embeddings/fake.py index 73b605dfa4b98..cd644bf8c69a6 100644 --- a/libs/core/langchain_core/embeddings/fake.py +++ b/libs/core/langchain_core/embeddings/fake.py @@ -119,7 +119,8 @@ def _get_embedding(self, seed: int) -> list[float]: rng = np.random.default_rng(seed) return list(rng.normal(size=self.size)) - def _get_seed(self, text: str) -> int: + @staticmethod + def _get_seed(text: str) -> int: """Get a seed for the random generator, using the hash of the text.""" return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8 diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index f4ec3d48af939..a07115196bdd0 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -148,8 +148,6 @@ def _format_for_tracing(messages: list[BaseMessage]) -> list[BaseMessage]: "type": key, key: block[key], } - else: - pass messages_to_trace.append(message_to_trace) return messages_to_trace diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index e864054564930..e0208d5f5a374 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -352,10 +352,7 @@ def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None: for chunk in self.tool_call_chunks: try: - if chunk["args"] is not None and chunk["args"] != "": - args_ = parse_partial_json(chunk["args"]) - else: - args_ = {} + args_ = parse_partial_json(chunk["args"]) if chunk["args"] else {} if isinstance(args_, dict): tool_calls.append( create_tool_call( diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py index 2f831d79d025e..8810d1ecf131b 100644 --- a/libs/core/langchain_core/messages/base.py +++ b/libs/core/langchain_core/messages/base.py @@ -179,9 +179,7 @@ def merge_content( elif merged and isinstance(merged[-1], str): merged[-1] += content # If second content is an empty string, treat as a no-op - elif content == "": - pass - else: + elif content: # Otherwise, add the second content as a new element of the list merged.append(content) return merged diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py index 54577037759d9..0d1513e6d8622 100644 --- a/libs/core/langchain_core/output_parsers/json.py +++ b/libs/core/langchain_core/output_parsers/json.py @@ -46,11 +46,13 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]): def _diff(self, prev: Optional[Any], next: Any) -> Any: return jsonpatch.make_patch(prev, next).patch - def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]: + @staticmethod + def _get_schema(pydantic_object: type[TBaseModel]) -> dict[str, Any]: if issubclass(pydantic_object, pydantic.BaseModel): return pydantic_object.model_json_schema() return pydantic_object.schema() + @override def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: """Parse the result of an LLM call to a JSON object. diff --git a/libs/core/langchain_core/output_parsers/list.py b/libs/core/langchain_core/output_parsers/list.py index cc8ca4238cb50..b0d1ba4bfd4cf 100644 --- a/libs/core/langchain_core/output_parsers/list.py +++ b/libs/core/langchain_core/output_parsers/list.py @@ -155,6 +155,7 @@ def get_lc_namespace(cls) -> list[str]: """ return ["langchain", "output_parsers", "list"] + @override def get_format_instructions(self) -> str: """Return the format instructions for the comma-separated list output.""" return ( @@ -162,6 +163,7 @@ def get_format_instructions(self) -> str: "eg: `foo, bar, baz` or `foo,bar,baz`" ) + @override def parse(self, text: str) -> list[str]: """Parse the output of an LLM call. @@ -224,6 +226,7 @@ class MarkdownListOutputParser(ListOutputParser): pattern: str = r"^\s*[-*]\s([^\n]+)$" """The pattern to match a Markdown list item.""" + @override def get_format_instructions(self) -> str: """Return the format instructions for the Markdown list output.""" return "Your response should be a markdown list, eg: `- foo\n- bar\n- baz`" diff --git a/libs/core/langchain_core/output_parsers/string.py b/libs/core/langchain_core/output_parsers/string.py index c2194e8bddffd..35c9aab89f4a1 100644 --- a/libs/core/langchain_core/output_parsers/string.py +++ b/libs/core/langchain_core/output_parsers/string.py @@ -1,5 +1,7 @@ """String output parser.""" +from typing_extensions import override + from langchain_core.output_parsers.transform import BaseTransformOutputParser @@ -29,6 +31,7 @@ def _type(self) -> str: """Return the output parser type for serialization.""" return "default" + @override def parse(self, text: str) -> str: """Returns the input text with no changes.""" return text diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index 6581c56902c49..370380d444be2 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -210,7 +210,7 @@ def invoke( if self.metadata: config["metadata"] = {**config["metadata"], **self.metadata} if self.tags: - config["tags"] = config["tags"] + self.tags + config["tags"] += self.tags return self._call_with_config( self._format_prompt_with_error_handling, input, diff --git a/libs/core/langchain_core/prompts/string.py b/libs/core/langchain_core/prompts/string.py index 12357bf3a2721..b6f1cdae28f22 100644 --- a/libs/core/langchain_core/prompts/string.py +++ b/libs/core/langchain_core/prompts/string.py @@ -166,7 +166,7 @@ def mustache_schema( prefix = section_stack.pop() elif type_ in {"section", "inverted section"}: section_stack.append(prefix) - prefix = prefix + tuple(key.split(".")) + prefix += tuple(key.split(".")) fields[prefix] = False elif type_ in {"variable", "no escape"}: fields[prefix + tuple(key.split("."))] = True diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index a7655072d696b..17ebc55a5183a 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -155,7 +155,7 @@ def add_subgraph(edges: list[Edge], prefix: str) -> None: nonlocal mermaid_graph self_loop = len(edges) == 1 and edges[0].source == edges[0].target if prefix and not self_loop: - subgraph = prefix.split(":")[-1] + subgraph = prefix.rsplit(":", maxsplit=1)[-1] if subgraph in seen_subgraphs: msg = ( f"Found duplicate subgraph '{subgraph}' -- this likely means that " @@ -214,7 +214,7 @@ def add_subgraph(edges: list[Edge], prefix: str) -> None: # Add remaining subgraphs with edges for prefix, edges_ in edge_groups.items(): - if ":" in prefix or prefix == "": + if not prefix or ":" in prefix: continue add_subgraph(edges_, prefix) seen_subgraphs.add(prefix) diff --git a/libs/core/langchain_core/tracers/core.py b/libs/core/langchain_core/tracers/core.py index 54261a2c22489..9e680172a8f04 100644 --- a/libs/core/langchain_core/tracers/core.py +++ b/libs/core/langchain_core/tracers/core.py @@ -82,7 +82,7 @@ def __init__( """Map of run ID to (trace_id, dotted_order). Cleared when tracer GCed.""" @abstractmethod - def _persist_run(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _persist_run(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: """Persist a run.""" @staticmethod @@ -108,7 +108,7 @@ def _get_stacktrace(error: BaseException) -> str: except: # noqa: E722 return msg - def _start_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # type: ignore[return] + def _start_trace(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # type: ignore[return] current_dotted_order = run.start_time.strftime("%Y%m%dT%H%M%S%fZ") + str(run.id) if run.parent_run_id: if parent := self.order_map.get(run.parent_run_id): @@ -538,7 +538,7 @@ def __copy__(self) -> _TracerCore: """Return self copied.""" return self - def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 + def _end_trace(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 """End a trace for a run. Args: @@ -546,7 +546,7 @@ def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noq """ return None - def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 + def _on_run_create(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 """Process a run upon creation. Args: @@ -554,7 +554,7 @@ def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # """ return None - def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 + def _on_run_update(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 """Process a run upon update. Args: @@ -562,7 +562,7 @@ def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # """ return None - def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 + def _on_llm_start(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 """Process the LLM Run upon start. Args: @@ -575,7 +575,7 @@ def _on_llm_new_token( run: Run, # noqa: ARG002 token: str, # noqa: ARG002 chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], # noqa: ARG002 - ) -> Union[None, Coroutine[Any, Any, None]]: + ) -> Union[Coroutine[Any, Any, None], None]: """Process new LLM token. Args: @@ -585,7 +585,7 @@ def _on_llm_new_token( """ return None - def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 + def _on_llm_end(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 """Process the LLM Run. Args: @@ -593,7 +593,7 @@ def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # no """ return None - def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 + def _on_llm_error(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 """Process the LLM Run upon error. Args: @@ -601,7 +601,7 @@ def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # """ return None - def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 + def _on_chain_start(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 """Process the Chain Run upon start. Args: @@ -609,7 +609,7 @@ def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: """ return None - def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 + def _on_chain_end(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 """Process the Chain Run. Args: @@ -617,7 +617,7 @@ def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # """ return None - def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 + def _on_chain_error(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 """Process the Chain Run upon error. Args: @@ -625,7 +625,7 @@ def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: """ return None - def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 + def _on_tool_start(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 """Process the Tool Run upon start. Args: @@ -633,7 +633,7 @@ def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # """ return None - def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 + def _on_tool_end(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 """Process the Tool Run. Args: @@ -641,7 +641,7 @@ def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # n """ return None - def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 + def _on_tool_error(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 """Process the Tool Run upon error. Args: @@ -649,7 +649,7 @@ def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # """ return None - def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 + def _on_chat_model_start(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 """Process the Chat Model Run upon start. Args: @@ -657,7 +657,7 @@ def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None """ return None - def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 + def _on_retriever_start(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 """Process the Retriever Run upon start. Args: @@ -665,7 +665,7 @@ def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None] """ return None - def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 + def _on_retriever_end(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 """Process the Retriever Run. Args: @@ -673,7 +673,7 @@ def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: """ return None - def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 + def _on_retriever_error(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 """Process the Retriever Run upon error. Args: diff --git a/libs/core/langchain_core/utils/aiter.py b/libs/core/langchain_core/utils/aiter.py index 7a0477c658e86..e319fb0891113 100644 --- a/libs/core/langchain_core/utils/aiter.py +++ b/libs/core/langchain_core/utils/aiter.py @@ -37,7 +37,7 @@ # before 3.10, the builtin anext() was not available def py_anext( iterator: AsyncIterator[T], default: Union[T, Any] = _no_default -) -> Awaitable[Union[T, None, Any]]: +) -> Awaitable[Union[T, Any, None]]: """Pure-Python implementation of anext() for testing purposes. Closely matches the builtin anext() C implementation. diff --git a/libs/core/langchain_core/utils/mustache.py b/libs/core/langchain_core/utils/mustache.py index fd2ae3591763d..9c8a93cacc8ef 100644 --- a/libs/core/langchain_core/utils/mustache.py +++ b/libs/core/langchain_core/utils/mustache.py @@ -82,7 +82,7 @@ def l_sa_check( """ # If there is a newline, or the previous tag was a standalone if literal.find("\n") != -1 or is_standalone: - padding = literal.split("\n")[-1] + padding = literal.rsplit("\n", maxsplit=1)[-1] # If all the characters since the last newline are spaces # Then the next tag could be a standalone diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index 1a41b63dc285c..122f1c0af2de7 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -134,7 +134,7 @@ def guard_import( try: module = importlib.import_module(module_name, package) except (ImportError, ModuleNotFoundError) as e: - pip_name = pip_name or module_name.split(".")[0].replace("_", "-") + pip_name = pip_name or module_name.split(".", maxsplit=1)[0].replace("_", "-") msg = ( f"Could not import {module_name} python package. " f"Please install it with `pip install {pip_name}`." diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 1ea42cee12324..8f48ba5d5e858 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -83,27 +83,39 @@ docstring-code-format = true [tool.ruff.lint] select = [ "ALL",] ignore = [ - "C90", # McCabe complexity - "COM812", # Messes with the formatter - "FA100", # Can't activate since we exclude UP007 for now - "FIX002", # Line contains TODO - "ISC001", # Messes with the formatter + "C90", # McCabe complexity + "COM812", # Messes with the formatter + "CPY", # No copyright + "FIX002", # Line contains TODO + "ISC001", # Messes with the formatter + "C90", # McCabe complexity + "COM812", # Messes with the formatter + "CPY", # No copyright + "FIX002", # Line contains TODO + "ISC001", # Messes with the formatter "PERF203", # Rarely useful - "PLC0414", # Enable re-export - "PLR09", # Too many something (arg, statements, etc) - "RUF012", # Doesn't play well with Pydantic - "TC001", # Doesn't play well with Pydantic - "TC002", # Doesn't play well with Pydantic - "TC003", # Doesn't play well with Pydantic - "TD002", # Missing author in TODO - "TD003", # Missing issue link in TODO + "PLR09", # Too many something (arg, statements, etc) + "RUF012", # Doesn't play well with Pydantic + "TC001", # Doesn't play well with Pydantic + "TC002", # Doesn't play well with Pydantic + "TC003", # Doesn't play well with Pydantic + "TD002", # Missing author in TODO + "TD003", # Missing issue link in TODO + "PLR09", # Too many something (arg, statements, etc) + "RUF012", # Doesn't play well with Pydantic + "TC001", # Doesn't play well with Pydantic + "TC002", # Doesn't play well with Pydantic + "TC003", # Doesn't play well with Pydantic + "TD002", # Missing author in TODO + "TD003", # Missing issue link in TODO # TODO rules - "ANN401", - "BLE", - "ERA", - "PLC0415", - "PLR2004", + "ANN401", # No Any types + "BLE", # Blind exceptions + "DOC", # Docstrings (preview) + "ERA", # No commented-out code + "PLC0415", # Imports outside top level + "PLR2004", # Comparison to magic number ] unfixable = ["PLW1510",] diff --git a/libs/core/tests/unit_tests/document_loaders/test_base.py b/libs/core/tests/unit_tests/document_loaders/test_base.py index 87ab70e9e092d..c7d5811e18835 100644 --- a/libs/core/tests/unit_tests/document_loaders/test_base.py +++ b/libs/core/tests/unit_tests/document_loaders/test_base.py @@ -35,6 +35,7 @@ def lazy_parse(self, blob: Blob) -> Iterator[Document]: def test_default_lazy_load() -> None: class FakeLoader(BaseLoader): + @override def load(self) -> list[Document]: return [ Document(page_content="foo"), @@ -57,6 +58,7 @@ class FakeLoader(BaseLoader): async def test_default_aload() -> None: class FakeLoader(BaseLoader): + @override def lazy_load(self) -> Iterator[Document]: yield from [ Document(page_content="foo"), diff --git a/libs/core/tests/unit_tests/example_selectors/test_base.py b/libs/core/tests/unit_tests/example_selectors/test_base.py index 54793627987e0..c49e7745f0718 100644 --- a/libs/core/tests/unit_tests/example_selectors/test_base.py +++ b/libs/core/tests/unit_tests/example_selectors/test_base.py @@ -1,5 +1,7 @@ from typing import Optional +from typing_extensions import override + from langchain_core.example_selectors import BaseExampleSelector @@ -10,6 +12,7 @@ def __init__(self) -> None: def add_example(self, example: dict[str, str]) -> None: self.example = example + @override def select_examples(self, input_variables: dict[str, str]) -> list[dict]: return [input_variables] diff --git a/libs/core/tests/unit_tests/indexing/test_in_memory_indexer.py b/libs/core/tests/unit_tests/indexing/test_in_memory_indexer.py index b16823a762626..decf9f67ac76c 100644 --- a/libs/core/tests/unit_tests/indexing/test_in_memory_indexer.py +++ b/libs/core/tests/unit_tests/indexing/test_in_memory_indexer.py @@ -7,6 +7,7 @@ AsyncDocumentIndexTestSuite, DocumentIndexerTestSuite, ) +from typing_extensions import override from langchain_core.documents import Document from langchain_core.indexing.base import DocumentIndex @@ -17,6 +18,7 @@ class TestDocumentIndexerTestSuite(DocumentIndexerTestSuite): @pytest.fixture + @override def index(self) -> Generator[DocumentIndex, None, None]: yield InMemoryDocumentIndex() # noqa: PT022 @@ -24,6 +26,7 @@ def index(self) -> Generator[DocumentIndex, None, None]: class TestAsyncDocumentIndexerTestSuite(AsyncDocumentIndexTestSuite): # Something funky is going on with mypy and async pytest fixture @pytest.fixture + @override async def index(self) -> AsyncGenerator[DocumentIndex, None]: yield InMemoryDocumentIndex() # noqa: PT022 diff --git a/libs/core/tests/unit_tests/indexing/test_indexing.py b/libs/core/tests/unit_tests/indexing/test_indexing.py index cc579d4d032ca..a3e88e7535be4 100644 --- a/libs/core/tests/unit_tests/indexing/test_indexing.py +++ b/libs/core/tests/unit_tests/indexing/test_indexing.py @@ -501,15 +501,14 @@ def test_incremental_fails_with_bad_source_ids( with pytest.raises( ValueError, match="Source id key is required when cleanup mode is " - "incremental or scoped_full.", + "incremental or scoped_full", ): # Should raise an error because no source id function was specified index(loader, record_manager, vector_store, cleanup="incremental") with pytest.raises( ValueError, - match="Source ids are required when cleanup mode " - "is incremental or scoped_full.", + match="Source ids are required when cleanup mode is incremental or scoped_full", ): # Should raise an error because no source id function was specified index( @@ -545,7 +544,7 @@ async def test_aincremental_fails_with_bad_source_ids( with pytest.raises( ValueError, match="Source id key is required when cleanup mode " - "is incremental or scoped_full.", + "is incremental or scoped_full", ): # Should raise an error because no source id function was specified await aindex( @@ -557,8 +556,7 @@ async def test_aincremental_fails_with_bad_source_ids( with pytest.raises( ValueError, - match="Source ids are required when cleanup mode " - "is incremental or scoped_full.", + match="Source ids are required when cleanup mode is incremental or scoped_full", ): # Should raise an error because no source id function was specified await aindex( @@ -838,15 +836,14 @@ def test_scoped_full_fails_with_bad_source_ids( with pytest.raises( ValueError, match="Source id key is required when cleanup mode " - "is incremental or scoped_full.", + "is incremental or scoped_full", ): # Should raise an error because no source id function was specified index(loader, record_manager, vector_store, cleanup="scoped_full") with pytest.raises( ValueError, - match="Source ids are required when cleanup mode " - "is incremental or scoped_full.", + match="Source ids are required when cleanup mode is incremental or scoped_full", ): # Should raise an error because no source id function was specified index( @@ -882,15 +879,14 @@ async def test_ascoped_full_fails_with_bad_source_ids( with pytest.raises( ValueError, match="Source id key is required when cleanup mode " - "is incremental or scoped_full.", + "is incremental or scoped_full", ): # Should raise an error because no source id function was specified await aindex(loader, arecord_manager, vector_store, cleanup="scoped_full") with pytest.raises( ValueError, - match="Source ids are required when cleanup mode " - "is incremental or scoped_full.", + match="Source ids are required when cleanup mode is incremental or scoped_full", ): # Should raise an error because no source id function was specified await aindex( diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index f873deb0877a7..1b1fee580f411 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -173,7 +173,7 @@ def test_create_system_message_prompt_list_template_partial_variables_not_null() """ with pytest.raises( - ValueError, match="Partial variables are not supported for list of templates." + ValueError, match="Partial variables are not supported for list of templates" ): _ = SystemMessagePromptTemplate.from_template( template=[graph_creator_content1, graph_creator_content2], @@ -568,7 +568,7 @@ def test_chat_prompt_template_append_and_extend() -> None: def test_convert_to_message_is_strict() -> None: """Verify that _convert_to_message is strict.""" - with pytest.raises(ValueError, match="Unexpected message type: meow."): + with pytest.raises(ValueError, match="Unexpected message type: meow"): # meow does not correspond to a valid message type. # this test is here to ensure that functionality to interpret `meow` # as a role is NOT added. @@ -758,7 +758,7 @@ async def test_chat_tmpl_from_messages_multipart_image() -> None: async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None: """Verify that we cannot pass `path` for an image as a variable.""" - in_mem = "base64mem" + in_mem_ = "base64mem" template = ChatPromptTemplate.from_messages( [ @@ -781,23 +781,27 @@ async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None: ) with pytest.raises( ValueError, - match="Loading images from 'path' has been removed " - "as of 0.3.15 for security reasons.", + match=re.escape( + "Loading images from 'path' has been removed as of 0.3.15 " + "for security reasons." + ), ): template.format_messages( name="R2D2", - in_mem=in_mem, + in_mem=in_mem_, file_path="some/path", ) with pytest.raises( ValueError, - match="Loading images from 'path' has been removed " - "as of 0.3.15 for security reasons.", + match=re.escape( + "Loading images from 'path' has been removed as of 0.3.15 " + "for security reasons." + ), ): await template.aformat_messages( name="R2D2", - in_mem=in_mem, + in_mem=in_mem_, file_path="some/path", ) @@ -900,7 +904,7 @@ def test_chat_prompt_message_dict() -> None: with pytest.raises(ValueError, match="Invalid template: False"): ChatPromptTemplate([{"role": "system", "content": False}]) - with pytest.raises(ValueError, match="Unexpected message type: foo."): + with pytest.raises(ValueError, match="Unexpected message type: foo"): ChatPromptTemplate([{"role": "foo", "content": "foo"}]) diff --git a/libs/core/tests/unit_tests/prompts/test_loading.py b/libs/core/tests/unit_tests/prompts/test_loading.py index 76d753147758d..3e230c54bbb31 100644 --- a/libs/core/tests/unit_tests/prompts/test_loading.py +++ b/libs/core/tests/unit_tests/prompts/test_loading.py @@ -49,14 +49,14 @@ def test_loading_from_json() -> None: def test_loading_jinja_from_json() -> None: """Test that loading jinja2 format prompts from JSON raises ValueError.""" prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.json" - with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"): + with pytest.raises(ValueError, match=r".*can lead to arbitrary code execution.*"): load_prompt(prompt_path) def test_loading_jinja_from_yaml() -> None: """Test that loading jinja2 format prompts from YAML raises ValueError.""" prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.yaml" - with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"): + with pytest.raises(ValueError, match=r".*can lead to arbitrary code execution.*"): load_prompt(prompt_path) diff --git a/libs/core/tests/unit_tests/prompts/test_structured.py b/libs/core/tests/unit_tests/prompts/test_structured.py index 4f5cbc9ab792f..9758f1d1d1ea3 100644 --- a/libs/core/tests/unit_tests/prompts/test_structured.py +++ b/libs/core/tests/unit_tests/prompts/test_structured.py @@ -4,6 +4,7 @@ import pytest from pydantic import BaseModel +from typing_extensions import override from langchain_core.language_models import FakeListChatModel from langchain_core.load.dump import dumps @@ -27,6 +28,7 @@ def _fake_runnable( class FakeStructuredChatModel(FakeListChatModel): """Fake ChatModel for testing purposes.""" + @override def with_structured_output( self, schema: Union[dict, type[BaseModel]], **kwargs: Any ) -> Runnable: diff --git a/libs/core/tests/unit_tests/runnables/test_context.py b/libs/core/tests/unit_tests/runnables/test_context.py index b638ac78f57ff..073b972fe0fa4 100644 --- a/libs/core/tests/unit_tests/runnables/test_context.py +++ b/libs/core/tests/unit_tests/runnables/test_context.py @@ -368,7 +368,7 @@ def test_runnable_context_seq_key_order() -> None: with pytest.raises( ValueError, - match="Context setter for key foo must be defined after all getters.", + match="Context setter for key foo must be defined after all getters", ): seq.invoke("foo") diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 5c33c5a99c4a2..530ce185743ac 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -3799,12 +3799,14 @@ def is_lc_serializable(cls) -> bool: """Return whether or not the class is serializable.""" return True + @override def get_format_instructions(self) -> str: return ( "Your response should be a list of comma separated values, " "eg: `foo, bar, baz`" ) + @override def parse(self, text: str) -> list[str]: """Parse the output of an LLM call.""" return text.strip().split(", ") diff --git a/libs/core/tests/unit_tests/stores/test_in_memory.py b/libs/core/tests/unit_tests/stores/test_in_memory.py index cc8ec684fbe72..7a0036b578782 100644 --- a/libs/core/tests/unit_tests/stores/test_in_memory.py +++ b/libs/core/tests/unit_tests/stores/test_in_memory.py @@ -3,6 +3,7 @@ BaseStoreAsyncTests, BaseStoreSyncTests, ) +from typing_extensions import override from langchain_core.stores import InMemoryStore @@ -10,20 +11,24 @@ # Check against standard tests class TestSyncInMemoryStore(BaseStoreSyncTests): @pytest.fixture + @override def kv_store(self) -> InMemoryStore: return InMemoryStore() @pytest.fixture + @override def three_values(self) -> tuple[str, str, str]: return "value1", "value2", "value3" class TestAsyncInMemoryStore(BaseStoreAsyncTests): @pytest.fixture + @override async def kv_store(self) -> InMemoryStore: return InMemoryStore() @pytest.fixture + @override def three_values(self) -> tuple[str, str, str]: return "value1", "value2", "value3" diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index b1df495523280..9f0f43474e9ba 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -202,7 +202,7 @@ def test_chat_message_chunks() -> None: ) with pytest.raises( - ValueError, match="Cannot concatenate ChatMessageChunks with different roles." + ValueError, match="Cannot concatenate ChatMessageChunks with different roles" ): ChatMessageChunk(role="User", content="I am") + ChatMessageChunk( role="Assistant", content=" indeed." @@ -311,7 +311,7 @@ def test_function_message_chunks() -> None: with pytest.raises( ValueError, - match="Cannot concatenate FunctionMessageChunks with different names.", + match="Cannot concatenate FunctionMessageChunks with different names", ): FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk( name="bye", content=" indeed." @@ -327,7 +327,7 @@ def test_ai_message_chunks() -> None: with pytest.raises( ValueError, - match="Cannot concatenate AIMessageChunks with different example values.", + match="Cannot concatenate AIMessageChunks with different example values", ): AIMessageChunk(example=True, content="I am") + AIMessageChunk( example=False, content=" indeed." @@ -339,7 +339,7 @@ class TestGetBufferString: _AI_MSG = AIMessage(content="ai") def test_empty_input(self) -> None: - assert get_buffer_string([]) == "" + assert not get_buffer_string([]) def test_valid_single_message(self) -> None: expected_output = "Human: human" @@ -1052,7 +1052,7 @@ def test_message_text() -> None: # content dict types: [text], [not text], [no type] assert HumanMessage(content="foo").text() == "foo" - assert AIMessage(content=[]).text() == "" + assert not AIMessage(content=[]).text() assert AIMessage(content=["foo", "bar"]).text() == "foobar" assert ( AIMessage( @@ -1092,14 +1092,11 @@ def test_message_text() -> None: assert ( AIMessage(content=[{"text": "hi there"}, "hi"]).text() == "hi" ) # missing type: text - assert AIMessage(content=[{"type": "nottext", "text": "hi"}]).text() == "" - assert AIMessage(content=[]).text() == "" - assert ( - AIMessage( - content="", tool_calls=[create_tool_call(name="a", args={"b": 1}, id=None)] - ).text() - == "" - ) + assert not AIMessage(content=[{"type": "nottext", "text": "hi"}]).text() + assert not AIMessage(content=[]).text() + assert not AIMessage( + content="", tool_calls=[create_tool_call(name="a", args={"b": 1}, id=None)] + ).text() def test_is_data_content_block() -> None: diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 72c6a5a387cfb..63107361224c4 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -24,7 +24,7 @@ from pydantic import BaseModel, Field, ValidationError from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import ValidationError as ValidationErrorV1 -from typing_extensions import TypedDict +from typing_extensions import TypedDict, override from langchain_core import tools from langchain_core.callbacks import ( @@ -119,6 +119,7 @@ class _MockStructuredTool(BaseTool): args_schema: type[BaseModel] = _MockSchema description: str = "A Structured Tool" + @override def _run(self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: return f"{arg1} {arg2} {arg3}" @@ -146,6 +147,7 @@ class _MisAnnotatedTool(BaseTool): args_schema: BaseModel = _MockSchema # type: ignore[assignment] description: str = "A Structured Tool" + @override def _run( self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None ) -> str: @@ -165,6 +167,7 @@ class _ForwardRefAnnotatedTool(BaseTool): args_schema: "type[BaseModel]" = _MockSchema description: str = "A Structured Tool" + @override def _run(self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: return f"{arg1} {arg2} {arg3}" @@ -182,6 +185,7 @@ class _ForwardRefAnnotatedTool(BaseTool): args_schema: type[_MockSchema] = _MockSchema description: str = "A Structured Tool" + @override def _run(self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: return f"{arg1} {arg2} {arg3}" @@ -237,6 +241,7 @@ class _SingleArgToolWithKwargs(BaseTool): name: str = "single_arg_tool" description: str = "A single arged tool with kwargs" + @override def _run( self, some_arg: str, @@ -260,6 +265,7 @@ class _VarArgToolWithKwargs(BaseTool): name: str = "single_arg_tool" description: str = "A single arged tool with kwargs" + @override def _run( self, *args: Any, @@ -396,9 +402,11 @@ class _MockSimpleTool(BaseTool): name: str = "simple_tool" description: str = "A Simple Tool" + @override def _run(self, tool_input: str) -> str: return f"{tool_input}" + @override async def _arun(self, tool_input: str) -> str: raise NotImplementedError @@ -705,7 +713,7 @@ def search_api(query: str) -> str: class MyTool(BaseModel): foo: str - assert MyTool.description == "" # type: ignore[attr-defined] + assert not MyTool.description # type: ignore[attr-defined] def test_create_tool_positional_args() -> None: @@ -909,9 +917,11 @@ def _parse_input( ) -> Union[str, dict[str, Any]]: raise NotImplementedError + @override def _run(self) -> str: return "dummy" + @override async def _arun(self) -> str: return "dummy" @@ -975,9 +985,11 @@ def _parse_input( ) -> Union[str, dict[str, Any]]: raise NotImplementedError + @override def _run(self) -> str: return "dummy" + @override async def _arun(self) -> str: return "dummy" @@ -1096,11 +1108,13 @@ class FooBase(BaseTool): name: str = "Foo" description: str = "Foo" + @override def _run(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any: return assert_bar(bar, bar_config) class AFooBase(FooBase): + @override async def _arun(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any: return assert_bar(bar, bar_config) @@ -1121,6 +1135,7 @@ def test_tool_pass_config(tool: BaseTool) -> None: class FooBaseNonPickleable(FooBase): + @override def _run(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any: return True @@ -1332,7 +1347,7 @@ def foo4(bar: str, baz: int) -> str: return bar for func in {foo3, foo4}: - with pytest.raises(ValueError, match="Found invalid Google-Style docstring."): + with pytest.raises(ValueError, match="Found invalid Google-Style docstring"): _ = tool(func, parse_docstring=True) def foo5(bar: str, baz: int) -> str: # noqa: D417 @@ -1345,7 +1360,7 @@ def foo5(bar: str, baz: int) -> str: # noqa: D417 return bar with pytest.raises( - ValueError, match="Arg banana in docstring not found in function signature." + ValueError, match="Arg banana in docstring not found in function signature" ): _ = tool(foo5, parse_docstring=True) @@ -1404,10 +1419,11 @@ class _MockStructuredToolWithRawOutput(BaseTool): description: str = "A Structured Tool" response_format: Literal["content_and_artifact"] = "content_and_artifact" + @override def _run( self, arg1: int, - arg2: bool, # noqa: FBT001 + arg2: bool, arg3: Optional[dict] = None, ) -> tuple[str, dict]: return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3} @@ -1554,6 +1570,7 @@ class InjectedTool(BaseTool): name: str = "foo" description: str = "foo." + @override def _run(self, x: int, y: Annotated[str, InjectedToolArg]) -> Any: """Foo. @@ -1578,6 +1595,7 @@ class InjectedToolWithSchema(BaseTool): description: str = "foo." args_schema: type[BaseModel] = fooSchema + @override def _run(self, x: int, y: str) -> Any: return y @@ -1743,6 +1761,7 @@ class InheritedInjectedArgTool(BaseTool): description: str = "foo." args_schema: type[BaseModel] = FooSchema + @override def _run(self, x: int, y: str) -> Any: return y @@ -1854,6 +1873,7 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None: class SomeTool(BaseTool): args_schema: type[pydantic_model] = pydantic_model + @override def _run(self, *args: Any, **kwargs: Any) -> str: return "foo" @@ -1913,6 +1933,7 @@ class SomeTool(BaseTool): # for pydantic 2! args_schema: type[BaseModel] = Foo + @override def _run(self, *args: Any, **kwargs: Any) -> str: return "foo" @@ -2136,6 +2157,7 @@ def my_tool(val: int, other_val: Annotated[dict, "my annotation"]) -> str: def test_create_retriever_tool() -> None: class MyRetriever(BaseRetriever): + @override def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> list[Document]: @@ -2364,6 +2386,7 @@ class MyTool(BaseTool): name: str = "MyTool" description: str = "a tool" + @override def _run( self, x: str, diff --git a/libs/core/tests/unit_tests/utils/test_env.py b/libs/core/tests/unit_tests/utils/test_env.py index ea49268cee8ce..0d52e8534c68f 100644 --- a/libs/core/tests/unit_tests/utils/test_env.py +++ b/libs/core/tests/unit_tests/utils/test_env.py @@ -55,7 +55,7 @@ def test_get_from_dict_or_env() -> None: ValueError, match="Did not find not exists, " "please add an environment variable `__SOME_KEY_IN_ENV` which contains it, " - "or pass `not exists` as a named parameter.", + "or pass `not exists` as a named parameter", ): assert ( get_from_dict_or_env( diff --git a/libs/core/tests/unit_tests/utils/test_html.py b/libs/core/tests/unit_tests/utils/test_html.py index a6332e4b606b1..7d6c00dab694b 100644 --- a/libs/core/tests/unit_tests/utils/test_html.py +++ b/libs/core/tests/unit_tests/utils/test_html.py @@ -36,31 +36,31 @@ def test_find_all_links_multiple() -> None: def test_find_all_links_ignore_suffix() -> None: html = 'href="foobar{suffix}"' - for suffix in SUFFIXES_TO_IGNORE: - actual = find_all_links(html.format(suffix=suffix)) + for suffix_ in SUFFIXES_TO_IGNORE: + actual = find_all_links(html.format(suffix=suffix_)) assert actual == [] # Don't ignore if pattern doesn't occur at end of link. html = 'href="foobar{suffix}more"' - for suffix in SUFFIXES_TO_IGNORE: - actual = find_all_links(html.format(suffix=suffix)) - assert actual == [f"foobar{suffix}more"] + for suffix_ in SUFFIXES_TO_IGNORE: + actual = find_all_links(html.format(suffix=suffix_)) + assert actual == [f"foobar{suffix_}more"] def test_find_all_links_ignore_prefix() -> None: html = 'href="{prefix}foobar"' - for prefix in PREFIXES_TO_IGNORE: - actual = find_all_links(html.format(prefix=prefix)) + for prefix_ in PREFIXES_TO_IGNORE: + actual = find_all_links(html.format(prefix=prefix_)) assert actual == [] # Don't ignore if pattern doesn't occur at beginning of link. html = 'href="foobar{prefix}more"' - for prefix in PREFIXES_TO_IGNORE: + for prefix_ in PREFIXES_TO_IGNORE: # Pound signs are split on when not prefixes. - if prefix == "#": + if prefix_ == "#": continue - actual = find_all_links(html.format(prefix=prefix)) - assert actual == [f"foobar{prefix}more"] + actual = find_all_links(html.format(prefix=prefix_)) + assert actual == [f"foobar{prefix_}more"] def test_find_all_links_drop_fragment() -> None: diff --git a/libs/core/tests/unit_tests/utils/test_strings.py b/libs/core/tests/unit_tests/utils/test_strings.py index 2162fb3efe8c2..85c328a4758ed 100644 --- a/libs/core/tests/unit_tests/utils/test_strings.py +++ b/libs/core/tests/unit_tests/utils/test_strings.py @@ -24,7 +24,7 @@ def test_sanitize_for_postgres() -> None: assert sanitize_for_postgres(clean_text) == clean_text # Test empty string - assert sanitize_for_postgres("") == "" + assert not sanitize_for_postgres("") # Test with multiple consecutive NUL bytes text_with_multiple_nuls = "Hello\x00\x00\x00world" diff --git a/libs/core/tests/unit_tests/utils/test_usage.py b/libs/core/tests/unit_tests/utils/test_usage.py index 04c89b0537ec6..1ad3500d6d830 100644 --- a/libs/core/tests/unit_tests/utils/test_usage.py +++ b/libs/core/tests/unit_tests/utils/test_usage.py @@ -30,7 +30,7 @@ def test_dict_int_op_max_depth_exceeded() -> None: left = {"a": {"b": {"c": 1}}} right = {"a": {"b": {"c": 2}}} with pytest.raises( - ValueError, match="max_depth=2 exceeded, unable to combine dicts." + ValueError, match="max_depth=2 exceeded, unable to combine dicts" ): _dict_int_op(left, right, operator.add, max_depth=2) @@ -40,6 +40,6 @@ def test_dict_int_op_invalid_types() -> None: right = {"a": 2, "b": 3} with pytest.raises( ValueError, - match="Only dict and int values are supported.", + match="Only dict and int values are supported", ): _dict_int_op(left, right, operator.add) diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py index c38b951da6558..51967b82a1974 100644 --- a/libs/core/tests/unit_tests/utils/test_utils.py +++ b/libs/core/tests/unit_tests/utils/test_utils.py @@ -96,7 +96,7 @@ def test_check_package_version( TypeError, match=( "Additional kwargs key a already exists in left dict and value " - "has unsupported type .+tuple.+." + r"has unsupported type .+tuple.+." ), ), ), @@ -136,7 +136,7 @@ def test_merge_dicts( ( {"type": "foo"}, {"type": "bar"}, - pytest.raises(ValueError, match="Unable to merge."), + pytest.raises(ValueError, match="Unable to merge"), ), ], )