Skip to content

Commit cc26b49

Browse files
committed
simpler type hints when plotly is not installed
1 parent 8e5a7f2 commit cc26b49

File tree

4 files changed

+39
-55
lines changed

4 files changed

+39
-55
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ test:
44
test-no-plotly:
55
uv sync --extra test
66
uv pip uninstall plotly
7-
pytest tests/test_other.py -k plotly
7+
pytest tests/test_other.py -k plotly --pdb
88
uv sync --extra test
99
pytest tests/test_other.py -k plotly
1010

bertopic/_bertopic.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from collections import defaultdict, Counter
2727
from scipy.sparse import csr_matrix
2828
from scipy.cluster import hierarchy as sch
29+
from importlib.util import find_spec
2930

3031
# Typing
3132
import sys
@@ -34,7 +35,19 @@
3435
from typing import Literal
3536
else:
3637
from typing_extensions import Literal
37-
from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable
38+
from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable, TYPE_CHECKING
39+
40+
# Plotting
41+
if find_spec("plotly") is None:
42+
from bertopic._utils import MockPlotlyModule
43+
plotting = MockPlotlyModule()
44+
45+
else:
46+
from bertopic import plotting
47+
if TYPE_CHECKING:
48+
import plotly.graph_objs as go
49+
import matplotlib.figure as fig
50+
3851

3952
# Models
4053
try:
@@ -72,23 +85,9 @@
7285
)
7386
import bertopic._save_utils as save_utils
7487

75-
7688
logger = MyLogger()
7789
logger.configure("WARNING")
7890

79-
try:
80-
from bertopic import plotting
81-
import plotly.graph_objects as go
82-
83-
except ModuleNotFoundError as e:
84-
if "No module named 'plotly'" in str(e):
85-
logger.warning("Plotly is not installed. Please install it to use the plotting functions.")
86-
from bertopic._utils import mock_plotly_go as go, MockPlotting
87-
88-
plotting = MockPlotting(logger)
89-
else:
90-
raise ModuleNotFoundError(e)
91-
9291

9392
class BERTopic:
9493
"""BERTopic is a topic modeling technique that leverages BERT embeddings and
@@ -2415,7 +2414,7 @@ def visualize_topics(
24152414
title: str = "<b>Intertopic Distance Map</b>",
24162415
width: int = 650,
24172416
height: int = 650,
2418-
) -> go.Figure:
2417+
) -> "go.Figure":
24192418
"""Visualize topics, their sizes, and their corresponding words.
24202419
24212420
This visualization is highly inspired by LDAvis, a great visualization
@@ -2473,7 +2472,7 @@ def visualize_documents(
24732472
title: str = "<b>Documents and Topics</b>",
24742473
width: int = 1200,
24752474
height: int = 750,
2476-
) -> go.Figure:
2475+
) -> "go.Figure":
24772476
"""Visualize documents and their topics in 2D.
24782477
24792478
Arguments:
@@ -2575,7 +2574,7 @@ def visualize_document_datamap(
25752574
topic_prefix: bool = False,
25762575
datamap_kwds: dict = {},
25772576
int_datamap_kwds: dict = {},
2578-
):
2577+
) -> "fig.Figure":
25792578
"""Visualize documents and their topics in 2D as a static plot for publication using
25802579
DataMapPlot. This works best if there are between 5 and 60 topics. It is therefore best
25812580
to use a sufficiently large `min_topic_size` or set `nr_topics` when building the model.
@@ -2686,7 +2685,7 @@ def visualize_hierarchical_documents(
26862685
title: str = "<b>Hierarchical Documents and Topics</b>",
26872686
width: int = 1200,
26882687
height: int = 750,
2689-
) -> go.Figure:
2688+
) -> "go.Figure":
26902689
"""Visualize documents and their topics in 2D at different levels of hierarchy.
26912690
26922691
Arguments:
@@ -2798,7 +2797,7 @@ def visualize_term_rank(
27982797
title: str = "<b>Term score decline per Topic</b>",
27992798
width: int = 800,
28002799
height: int = 500,
2801-
) -> go.Figure:
2800+
) -> "go.Figure":
28022801
"""Visualize the ranks of all terms across all topics.
28032802
28042803
Each topic is represented by a set of words. These words, however,
@@ -2863,7 +2862,7 @@ def visualize_topics_over_time(
28632862
title: str = "<b>Topics over Time</b>",
28642863
width: int = 1250,
28652864
height: int = 450,
2866-
) -> go.Figure:
2865+
) -> "go.Figure":
28672866
"""Visualize topics over time.
28682867
28692868
Arguments:
@@ -2919,7 +2918,7 @@ def visualize_topics_per_class(
29192918
title: str = "<b>Topics per Class</b>",
29202919
width: int = 1250,
29212920
height: int = 900,
2922-
) -> go.Figure:
2921+
) -> "go.Figure":
29232922
"""Visualize topics per class.
29242923
29252924
Arguments:
@@ -2973,7 +2972,7 @@ def visualize_distribution(
29732972
title: str = "<b>Topic Probability Distribution</b>",
29742973
width: int = 800,
29752974
height: int = 600,
2976-
) -> go.Figure:
2975+
) -> "go.Figure":
29772976
"""Visualize the distribution of topic probabilities.
29782977
29792978
Arguments:
@@ -3080,7 +3079,7 @@ def visualize_hierarchy(
30803079
linkage_function: Callable[[csr_matrix], np.ndarray] = None,
30813080
distance_function: Callable[[csr_matrix], csr_matrix] = None,
30823081
color_threshold: int = 1,
3083-
) -> go.Figure:
3082+
) -> "go.Figure":
30843083
"""Visualize a hierarchical structure of the topics.
30853084
30863085
A ward linkage function is used to perform the
@@ -3176,7 +3175,7 @@ def visualize_heatmap(
31763175
title: str = "<b>Similarity Matrix</b>",
31773176
width: int = 800,
31783177
height: int = 800,
3179-
) -> go.Figure:
3178+
) -> "go.Figure":
31803179
"""Visualize a heatmap of the topic's similarity matrix.
31813180
31823181
Based on the cosine similarity matrix between c-TF-IDFs or semantic embeddings of the topics,
@@ -3236,7 +3235,7 @@ def visualize_barchart(
32363235
width: int = 250,
32373236
height: int = 250,
32383237
autoscale: bool = False,
3239-
) -> go.Figure:
3238+
) -> "go.Figure":
32403239
"""Visualize a barchart of selected topics.
32413240
32423241
Arguments:

bertopic/_utils.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Iterable
55
from scipy.sparse import csr_matrix
66
from scipy.spatial.distance import squareform
7-
from typing import Optional, Union, Tuple
7+
from typing import Optional, Union, Tuple, Any
88

99

1010
class MyLogger:
@@ -228,26 +228,11 @@ def to_ndarray(array: Union[np.ndarray, csr_matrix]) -> np.ndarray:
228228
return to_ndarray(repr_) if output_ndarray else repr_, ctfidf_used
229229

230230

231-
# Visualization mocks in case plotly is not installed
232-
class MockPlotting:
233-
"""Mock plotting module when plotly is not installed."""
231+
class MockPlotlyModule:
232+
"""Mock module that raises an error when plotly functions are called."""
234233

235-
def __init__(self, logger: MyLogger):
236-
self.logger = logger
237-
238-
def __getattr__(self, name):
234+
def __getattr__(self, name: str) -> Any:
239235
def mock_function(*args, **kwargs):
240-
self.logger.warning(f"Plotly is not installed. Cannot use {name} visualization function.")
241-
return MockFigure()
236+
raise ImportError(f"Plotly is required to use '{name}'. " "Install it with uv pip install plotly")
242237

243238
return mock_function
244-
245-
246-
class MockFigure:
247-
"""Mock class for plotly.graph_objects.Figure when plotly is not installed."""
248-
249-
def __init__(self, *args, **kwargs):
250-
pass
251-
252-
253-
mock_plotly_go = type("MockPlotly", (), {"Figure": MockFigure})()

tests/test_other.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,9 @@
22
from bertopic.dimensionality import BaseDimensionalityReduction
33

44
try:
5-
import plotly.graph_objects as go
6-
7-
figure_type = go.Figure
5+
from plotly.graph_objects import Figure
86
except ImportError:
9-
from bertopic._utils import MockFigure
10-
11-
figure_type = MockFigure
7+
Figure = None
128

139

1410
def test_load_save_model():
@@ -41,5 +37,9 @@ def test_no_plotly():
4137
umap_model=BaseDimensionalityReduction(),
4238
)
4339
model.fit(["hello", "hi", "goodbye", "goodbye", "whats up"] * 10)
44-
out = model.visualize_topics()
45-
assert isinstance(out, figure_type)
40+
41+
try:
42+
out = model.visualize_topics()
43+
assert isinstance(out, Figure) if Figure else False
44+
except ImportError as e:
45+
assert "Plotly is required to use" in str(e)

0 commit comments

Comments
 (0)