Commit ab96b9a

feat(server): support new falcon config (#712)
1 parent 2efd46e · commit ab96b9a

File tree

2 files changed (+38, −26 lines)


server/text_generation_server/models/__init__.py

Lines changed: 3 additions & 8 deletions
@@ -200,13 +200,10 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
 
-    if model_type in ["RefinedWeb", "RefinedWebModel"]:
+    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config_dict.get("alibi", False) or (
-                    model_type == "RefinedWebModel"
-                    and config_dict.get("multi_query", True)
-                ):
+                if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
                     model_id,
@@ -215,9 +212,7 @@ def get_model(
                     dtype=dtype,
                     trust_remote_code=trust_remote_code,
                 )
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
-            )
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
         else:
             if FLASH_ATTENTION and not config_dict.get("alibi", False):
                 return FlashRWSharded(
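
To illustrate the routing change in get_model: new-format Falcon checkpoints report model_type == "falcon", and sharding is now rejected only for alibi checkpoints. Below is a minimal, self-contained sketch of that dispatch; pick_falcon_path, the hard-coded FLASH_ATTENTION flag, and the string return values are illustrative stand-ins, not the server's real API.

FLASH_ATTENTION = True  # assumption: the flash-attention kernels the server imports are available


def pick_falcon_path(model_type: str, config_dict: dict, sharded: bool) -> str:
    """Hypothetical helper mirroring the dispatch logic in the diff above."""
    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                # Only alibi checkpoints are rejected now; multi-query
                # "RefinedWebModel" checkpoints are no longer excluded from sharding.
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return "FlashRWSharded"
            raise NotImplementedError("Sharded Falcon requires flash attention")
        if FLASH_ATTENTION and not config_dict.get("alibi", False):
            return "FlashRWSharded"
    return "other model path"


# A new-format Falcon config now shares the RefinedWeb path:
print(pick_falcon_path("falcon", {"alibi": False}, sharded=True))  # FlashRWSharded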

server/text_generation_server/models/custom_modeling/flash_rw_modeling.py

Lines changed: 35 additions & 18 deletions
@@ -49,18 +49,19 @@ def __init__(
         model_type="RefinedWeb",
         vocab_size=250880,
         hidden_size=64,
-        n_layer=2,
-        n_head=8,
+        num_hidden_layers=None,
+        num_attention_heads=None,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=True,
         bos_token_id=1,
         eos_token_id=2,
         hidden_dropout=0.0,
         attention_dropout=0.0,
-        n_head_kv=None,
+        num_kv_heads=None,
         multi_query=False,
         alibi=False,
+        new_decoder_architecture=None,
         bias=False,
         parallel_attn=False,
         **kwargs,
@@ -78,8 +79,16 @@ def __init__(
         # Backward compatibility with n_embed kwarg
         n_embed = kwargs.pop("n_embed", None)
         self.hidden_size = hidden_size if n_embed is None else n_embed
-        self.n_layer = n_layer
-        self.n_head = n_head
+        self.n_layer = (
+            num_hidden_layers
+            if num_hidden_layers is not None
+            else kwargs.pop("n_layer", 2)
+        )
+        self.n_head = (
+            num_attention_heads
+            if num_attention_heads is not None
+            else kwargs.pop("n_head", 8)
+        )
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.use_cache = use_cache
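
The old RWConfig read layer and head counts from n_layer / n_head (defaulting to 2 and 8); the new Falcon config format calls them num_hidden_layers / num_attention_heads. A small standalone sketch of the same fallback order follows, with a hypothetical helper name (resolve_layer_and_head_counts is not part of the modeling file); the example values are arbitrary.

def resolve_layer_and_head_counts(num_hidden_layers=None, num_attention_heads=None, **kwargs):
    # Prefer the new Falcon-style keys, then the legacy n_layer / n_head
    # kwargs, then the historical defaults (2 layers, 8 heads).
    n_layer = (
        num_hidden_layers
        if num_hidden_layers is not None
        else kwargs.pop("n_layer", 2)
    )
    n_head = (
        num_attention_heads
        if num_attention_heads is not None
        else kwargs.pop("n_head", 8)
    )
    return n_layer, n_head


# Legacy RefinedWeb-style keys still resolve:
print(resolve_layer_and_head_counts(n_layer=60, n_head=64))                          # (60, 64)
# New-style Falcon keys are read first:
print(resolve_layer_and_head_counts(num_hidden_layers=32, num_attention_heads=71))  # (32, 71)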
@@ -91,10 +100,21 @@ def __init__(
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
 
-        if n_head_kv is not None:
-            self.n_head_kv = n_head_kv
+        if num_kv_heads is not None:
+            self.n_head_kv = num_kv_heads
         else:
-            self.n_head_kv = 1 if multi_query else n_head
+            old_n_head_kv = kwargs.pop("n_head_kv", None)
+            if old_n_head_kv is not None:
+                self.n_head_kv = old_n_head_kv
+            else:
+                self.n_head_kv = 1 if multi_query else self.n_head
+
+        if new_decoder_architecture is not None:
+            self.new_decoder_architecture = new_decoder_architecture
+        elif model_type == "RefinedWeb":
+            self.new_decoder_architecture = True
+        else:
+            self.new_decoder_architecture = False
 
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
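
Two more pieces of backward compatibility land in this hunk: the KV-head count now prefers num_kv_heads, then the legacy n_head_kv kwarg, then falls back to 1 for multi-query models (otherwise n_head); and new_decoder_architecture, when the config does not set it, is inferred to be True only for model_type "RefinedWeb". A compact sketch of both rules, with hypothetical helper names (resolve_kv_heads, resolve_decoder_arch):

def resolve_kv_heads(n_head, num_kv_heads=None, multi_query=False, **kwargs):
    # Precedence mirrors the diff: new key, then legacy n_head_kv kwarg,
    # then 1 for multi-query attention, otherwise one KV head per head.
    if num_kv_heads is not None:
        return num_kv_heads
    old_n_head_kv = kwargs.pop("n_head_kv", None)
    if old_n_head_kv is not None:
        return old_n_head_kv
    return 1 if multi_query else n_head


def resolve_decoder_arch(model_type, new_decoder_architecture=None):
    # Explicit flag wins; otherwise only "RefinedWeb" defaults to the
    # grouped ("large") decoder architecture.
    if new_decoder_architecture is not None:
        return new_decoder_architecture
    return model_type == "RefinedWeb"


print(resolve_kv_heads(n_head=71, multi_query=True))                   # 1
print(resolve_kv_heads(n_head=128, num_kv_heads=8))                    # 8
print(resolve_decoder_arch("falcon"))                                  # False
print(resolve_decoder_arch("falcon", new_decoder_architecture=True))   # True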

@@ -530,26 +550,23 @@ def __init__(self, config, weights):
         self.word_embeddings = TensorParallelEmbedding(
             prefix="transformer.word_embeddings", weights=weights
         )
-        if config.model_type == "RefinedWebModel":
+
+        if config.new_decoder_architecture:
             self.h = nn.ModuleList(
                 [
-                    FlashRWLayer(layer_id, config, weights)
+                    FlashRWLargeLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = self.h[0].self_attention.num_heads_kv
-        elif config.model_type == "RefinedWeb":
+            self.cache_size = self.h[0].self_attention.num_groups
+        else:
             self.h = nn.ModuleList(
                 [
-                    FlashRWLargeLayer(layer_id, config, weights)
+                    FlashRWLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = self.h[0].self_attention.num_groups
-        else:
-            raise NotImplementedError(
-                f"model_type {config.model_type} is not supported."
-            )
+            self.cache_size = self.h[0].self_attention.num_heads_kv
 
         self.ln_f = FastLayerNorm.load(
             prefix="transformer.ln_f",
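
With the flag in place, FlashRWModel picks its decoder layer from config.new_decoder_architecture instead of config.model_type, so new-format "falcon" configs no longer fall through to the old NotImplementedError. A toy illustration of the branch follows; select_layer_class and its string results are just labels for the two branches above, not a real API.

def select_layer_class(new_decoder_architecture: bool) -> str:
    # The model no longer branches on config.model_type, so "falcon",
    # "RefinedWeb", and "RefinedWebModel" all map onto one of two layers.
    if new_decoder_architecture:
        # Grouped KV heads; cache sized by self_attention.num_groups.
        return "FlashRWLargeLayer"
    # Classic layer; cache sized by self_attention.num_heads_kv.
    return "FlashRWLayer"


# A new-format config with new_decoder_architecture=False:
print(select_layer_class(False))  # FlashRWLayer
# A new-format config with new_decoder_architecture=True:
print(select_layer_class(True))   # FlashRWLargeLayer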
