@@ -49,18 +49,19 @@ def __init__(
         model_type="RefinedWeb",
         vocab_size=250880,
         hidden_size=64,
-        n_layer=2,
-        n_head=8,
+        num_hidden_layers=None,
+        num_attention_heads=None,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=True,
         bos_token_id=1,
         eos_token_id=2,
         hidden_dropout=0.0,
         attention_dropout=0.0,
-        n_head_kv=None,
+        num_kv_heads=None,
         multi_query=False,
         alibi=False,
+        new_decoder_architecture=None,
         bias=False,
         parallel_attn=False,
         **kwargs,
@@ -78,8 +79,16 @@ def __init__(
         # Backward compatibility with n_embed kwarg
         n_embed = kwargs.pop("n_embed", None)
         self.hidden_size = hidden_size if n_embed is None else n_embed
-        self.n_layer = n_layer
-        self.n_head = n_head
+        self.n_layer = (
+            num_hidden_layers
+            if num_hidden_layers is not None
+            else kwargs.pop("n_layer", 2)
+        )
+        self.n_head = (
+            num_attention_heads
+            if num_attention_heads is not None
+            else kwargs.pop("n_head", 8)
+        )
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.use_cache = use_cache
@@ -91,10 +100,21 @@ def __init__(
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id

-        if n_head_kv is not None:
-            self.n_head_kv = n_head_kv
+        if num_kv_heads is not None:
+            self.n_head_kv = num_kv_heads
         else:
-            self.n_head_kv = 1 if multi_query else n_head
+            old_n_head_kv = kwargs.pop("n_head_kv", None)
+            if old_n_head_kv is not None:
+                self.n_head_kv = old_n_head_kv
+            else:
+                self.n_head_kv = 1 if multi_query else self.n_head
+
+        if new_decoder_architecture is not None:
+            self.new_decoder_architecture = new_decoder_architecture
+        elif model_type == "RefinedWeb":
+            self.new_decoder_architecture = True
+        else:
+            self.new_decoder_architecture = False

         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

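Taken together, the config hunks keep old checkpoints loadable while accepting HF-style argument names. A minimal sketch of the resulting behaviour, assuming `RWConfig` is importable as shown (the import path is a guess, and the layer/head counts are illustrative placeholders rather than values from a real checkpoint):

```python
# Import path is an assumption; adjust to wherever RWConfig lives in this repo.
from text_generation_server.models.custom_modeling.flash_rw_modeling import RWConfig

# New-style arguments: explicit counts plus the decoder-architecture flag.
new_style = RWConfig(
    num_hidden_layers=32,       # illustrative values only
    num_attention_heads=71,
    num_kv_heads=1,
    new_decoder_architecture=False,
)
assert new_style.n_layer == 32 and new_style.n_head == 71
assert new_style.n_head_kv == 1 and new_style.new_decoder_architecture is False

# Old-style arguments still work: n_layer / n_head / n_head_kv arrive via **kwargs
# and are popped off, and model_type == "RefinedWeb" implies the new decoder.
old_style = RWConfig(model_type="RefinedWeb", n_layer=60, n_head=128, n_head_kv=8)
assert old_style.n_layer == 60 and old_style.n_head == 128
assert old_style.n_head_kv == 8 and old_style.new_decoder_architecture is True
```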
@@ -530,26 +550,23 @@ def __init__(self, config, weights):
         self.word_embeddings = TensorParallelEmbedding(
             prefix="transformer.word_embeddings", weights=weights
         )
-        if config.model_type == "RefinedWebModel":
+
+        if config.new_decoder_architecture:
             self.h = nn.ModuleList(
                 [
-                    FlashRWLayer(layer_id, config, weights)
+                    FlashRWLargeLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = self.h[0].self_attention.num_heads_kv
-        elif config.model_type == "RefinedWeb":
+            self.cache_size = self.h[0].self_attention.num_groups
+        else:
             self.h = nn.ModuleList(
                 [
-                    FlashRWLargeLayer(layer_id, config, weights)
+                    FlashRWLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = self.h[0].self_attention.num_groups
-        else:
-            raise NotImplementedError(
-                f"model_type {config.model_type} is not supported."
-            )
+            self.cache_size = self.h[0].self_attention.num_heads_kv

         self.ln_f = FastLayerNorm.load(
             prefix="transformer.ln_f",