@@ -54,12 +54,7 @@ def __init__(
54
54
)
55
55
56
56
# We do not use from_pretrained as we modified the model internal module layout
57
- try :
58
- filenames = weight_files (model_id , revision , ".bin" )
59
- # Local files not found
60
- except LocalEntryNotFoundError :
61
- hub_files = weight_hub_files (model_id , revision , ".bin" )
62
- filenames = download_weights (hub_files , model_id , revision )
57
+ filenames = weight_files (model_id , revision , ".safetensors" )
63
58
64
59
with init_empty_weights ():
65
60
model = FlashSantacoderForCausalLM (config )
@@ -91,85 +86,100 @@ def load_weights(
91
86
transpose : bool ,
92
87
):
93
88
for filename in filenames :
94
- state_dict = torch .load (filename , map_location = "cpu" )
95
- for key , value in state_dict .items ():
96
- value = value .to (device if quantize is None else "cpu" ).to (dtype )
97
-
98
- layer_name = "." .join (key .split ("." )[:4 ])
99
-
100
- # Fused qkv
101
- if "q_attn.weight" in key or "kv_attn.weight" in key :
102
- final_key = layer_name + ".c_attn.weight"
103
- elif "q_attn.bias" in key or "kv_attn.bias" in key :
104
- final_key = layer_name + ".c_attn.bias"
105
-
106
- else :
107
- final_key = key
108
-
109
- module_name , param_name = final_key .rsplit ("." , 1 )
110
- module = model .get_submodule (module_name )
111
-
112
- try :
113
- current_parameter_tensor = module ._parameters [param_name ]
114
- except KeyError :
115
- current_parameter_tensor = None
116
-
117
- if current_parameter_tensor is not None :
118
- if transpose and (
119
- "c_fc.weight" in key
120
- or "c_proj.weight" in key
121
- or "q_attn.weight" in key
122
- or "kv_attn.weight" in key
123
- or "c_attn.weight" in key
124
- ):
125
- # Tranpose as we use nn.Linear instead of Conv1D
126
- value = value .T
127
-
128
- if current_parameter_tensor .device == torch .device ("meta" ):
129
- # Init qkv
130
- if "c_attn.weight" in final_key :
131
- module ._parameters [param_name ] = value .new_empty (
132
- (
133
- model .transformer .head_size
134
- * (model .transformer .num_heads + 2 ),
135
- value .shape [1 ],
89
+ with safe_open (
90
+ filename , framework = "pt" , device = str (device ) if quantize is None else "cpu"
91
+ ) as f :
92
+ for key in f .keys ():
93
+ value = f .get_tensor (key )
94
+ value = value .to (device if quantize is None else "cpu" ).to (dtype )
95
+
96
+ layer_name = "." .join (key .split ("." )[:4 ])
97
+
98
+ # Fused qkv
99
+ if "q_attn.weight" in key or "kv_attn.weight" in key :
100
+ final_key = layer_name + ".c_attn.weight"
101
+ elif "q_attn.bias" in key or "kv_attn.bias" in key :
102
+ final_key = layer_name + ".c_attn.bias"
103
+
104
+ else :
105
+ final_key = key
106
+
107
+ module_name , param_name = final_key .rsplit ("." , 1 )
108
+ module = model .get_submodule (module_name )
109
+
110
+ try :
111
+ current_parameter_tensor = module ._parameters [param_name ]
112
+ except KeyError :
113
+ current_parameter_tensor = None
114
+
115
+ if current_parameter_tensor is not None :
116
+ if transpose and (
117
+ "c_fc.weight" in key
118
+ or "c_proj.weight" in key
119
+ or "q_attn.weight" in key
120
+ or "kv_attn.weight" in key
121
+ or "c_attn.weight" in key
122
+ ):
123
+ # Tranpose as we use nn.Linear instead of Conv1D
124
+ value = value .T
125
+
126
+ if current_parameter_tensor .device == torch .device ("meta" ):
127
+ # Init qkv
128
+ if "c_attn.weight" in final_key :
129
+ module ._parameters [param_name ] = value .new_empty (
130
+ (
131
+ model .transformer .head_size
132
+ * (model .transformer .num_heads + 2 ),
133
+ value .shape [1 ],
134
+ )
136
135
)
137
- )
138
- elif "c_attn.bias" in final_key :
139
- module . _parameters [ param_name ] = value . new_empty (
140
- (
141
- model .transformer .head_size
142
- * ( model . transformer . num_heads + 2 )
136
+ elif "c_attn.bias" in final_key :
137
+ module . _parameters [ param_name ] = value . new_empty (
138
+ (
139
+ model . transformer . head_size
140
+ * ( model .transformer .num_heads + 2 )
141
+ )
143
142
)
144
- )
145
143
146
- # Copy to correct slice
147
- if "q_attn.weight" in key :
148
- module ._parameters [param_name ][: value .shape [0 ]] = value
149
- elif "q_attn.bias" in key :
150
- module ._parameters [param_name ][: value .shape [0 ]] = value
151
- elif "kv_attn.weight" in key :
152
- module ._parameters [param_name ][
153
- model .transformer .head_size * model .transformer .num_heads :
154
- ] = value
155
- elif "kv_attn.bias" in key :
156
- module ._parameters [param_name ][
157
- model .transformer .head_size * model .transformer .num_heads :
158
- ] = value
144
+ # Copy to correct slice
145
+ if "q_attn.weight" in key :
146
+ module ._parameters [param_name ][: value .shape [0 ]] = value
147
+ elif "q_attn.bias" in key :
148
+ module ._parameters [param_name ][: value .shape [0 ]] = value
149
+ elif "kv_attn.weight" in key :
150
+ module ._parameters [param_name ][
151
+ model .transformer .head_size * model .transformer .num_heads :
152
+ ] = value
153
+ elif "kv_attn.bias" in key :
154
+ module ._parameters [param_name ][
155
+ model .transformer .head_size * model .transformer .num_heads :
156
+ ] = value
157
+ else :
158
+ if current_parameter_tensor .shape != value .shape :
159
+ raise ValueError (
160
+ f"Name { final_key } -- Current { current_parameter_tensor .shape } and got { value .shape } "
161
+ )
162
+ module ._parameters [param_name ] = value
159
163
else :
160
- if current_parameter_tensor .shape != value .shape :
161
- raise ValueError (
162
- f"Name { final_key } -- Current { current_parameter_tensor .shape } and got { value .shape } "
163
- )
164
- module ._parameters [param_name ] = value
165
- else :
166
- module ._buffers [param_name ] = value
164
+ module ._buffers [param_name ] = value
167
165
168
- del value
166
+ del value
167
+
168
+ if model .lm_head .weight .device == torch .device ("meta" ):
169
+ model .lm_head .weight = torch .nn .Parameter (model .transformer .wte .weight )
169
170
170
171
torch .cuda .empty_cache ()
171
172
model .post_load_weights (quantize )
172
173
174
+ uninitialized_parameters = []
175
+ for n , p in model .named_parameters ():
176
+ if p .data .device == torch .device ("meta" ):
177
+ uninitialized_parameters .append (n )
178
+ if uninitialized_parameters :
179
+ raise RuntimeError (
180
+ f"found uninitialized parameters in model : { uninitialized_parameters } "
181
+ )
182
+
173
183
def decode (self , generated_ids : List [int ]) -> str :
174
184
# Do not skip special tokens as they are used for custom parsing rules of the generated text
175
185
return self .tokenizer .decode (
@@ -389,6 +399,8 @@ def load_weights(
389
399
else :
390
400
module ._buffers [param_name ] = tensor
391
401
392
- model .lm_head .weight = torch .nn .Parameter (model .transformer .wte .weight )
402
+ if model .lm_head .weight .device == torch .device ("meta" ):
403
+ model .lm_head .weight = torch .nn .Parameter (model .transformer .wte .weight )
404
+
393
405
torch .cuda .empty_cache ()
394
406
model .post_load_weights (quantize )
0 commit comments