Commit 95d3546

feat(server): load santacoder/starcoder models with safetensors (#393)
Fix #366
1 parent: c0928e6

File tree: 2 files changed (+91, −91 lines)

launcher/src/main.rs

Lines changed: 2 additions & 14 deletions
@@ -546,11 +546,7 @@ enum LauncherError {
     WebserverCannotStart,
 }
 
-fn download_convert_model(
-    args: &Args,
-    auto_convert: bool,
-    running: Arc<AtomicBool>,
-) -> Result<(), LauncherError> {
+fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     let mut download_argv = vec![
         "text-generation-server".to_string(),
         "download-weights".to_string(),
@@ -562,11 +558,6 @@ fn download_convert_model(
         "--json-output".to_string(),
     ];
 
-    // Auto convert weights to safetensors
-    if auto_convert {
-        download_argv.push("--auto-convert".to_string());
-    }
-
     // Model optional revision
     if let Some(revision) = &args.revision {
         download_argv.push("--revision".to_string());
@@ -932,11 +923,8 @@ fn main() -> Result<(), LauncherError> {
     })
     .expect("Error setting Ctrl-C handler");
 
-    // auto_convert is only needed for sharded models as we do not require safetensors in
-    // single shard mode
-    let auto_convert = num_shard > 1;
     // Download and convert model weights
-    download_convert_model(&args, auto_convert, running.clone())?;
+    download_convert_model(&args, running.clone())?;
 
     // Shared shutdown bool
     let shutdown = Arc::new(Mutex::new(false));
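With the server now converting weights unconditionally, the launcher no longer threads an `auto_convert` flag through `download_convert_model`. A rough Python sketch of the resulting download step (illustrative only: the launcher is Rust, the flag names come from this diff, and the explicit `model_id` argument plus the `subprocess` scaffolding are assumptions):

# Illustrative sketch, not the launcher itself: mirrors the argv built after
# this change. Only the flags visible in this hunk are shown.
import subprocess
from typing import Optional

def download_convert_model(model_id: str, revision: Optional[str] = None) -> None:
    argv = [
        "text-generation-server",
        "download-weights",
        model_id,
        "--json-output",
    ]
    # "--auto-convert" is gone: conversion to safetensors is no longer gated
    # on num_shard > 1, since single-shard loading now needs safetensors too.
    if revision is not None:
        argv.extend(["--revision", revision])
    subprocess.run(argv, check=True)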

server/text_generation_server/models/flash_santacoder.py

Lines changed: 89 additions & 77 deletions
@@ -54,12 +54,7 @@ def __init__(
         )
 
         # We do not use from_pretrained as we modified the model internal module layout
-        try:
-            filenames = weight_files(model_id, revision, ".bin")
-        # Local files not found
-        except LocalEntryNotFoundError:
-            hub_files = weight_hub_files(model_id, revision, ".bin")
-            filenames = download_weights(hub_files, model_id, revision)
+        filenames = weight_files(model_id, revision, ".safetensors")
 
         with init_empty_weights():
             model = FlashSantacoderForCausalLM(config)
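The constructor no longer falls back to fetching `.bin` checkpoints from the Hub; it resolves `.safetensors` files only, relying on the launcher's download step to have converted them beforehand. A minimal usage sketch, assuming the `weight_files` helper as called above (the import path and model id are illustrative):

# Hypothetical usage of the server's weight_files helper after this change;
# the import path is an assumption based on how flash_santacoder.py uses it.
from text_generation_server.utils import weight_files

# Raises if no local .safetensors files exist for this model/revision.
filenames = weight_files("bigcode/santacoder", None, ".safetensors")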
@@ -91,85 +86,100 @@ def load_weights(
         transpose: bool,
     ):
         for filename in filenames:
-            state_dict = torch.load(filename, map_location="cpu")
-            for key, value in state_dict.items():
-                value = value.to(device if quantize is None else "cpu").to(dtype)
-
-                layer_name = ".".join(key.split(".")[:4])
-
-                # Fused qkv
-                if "q_attn.weight" in key or "kv_attn.weight" in key:
-                    final_key = layer_name + ".c_attn.weight"
-                elif "q_attn.bias" in key or "kv_attn.bias" in key:
-                    final_key = layer_name + ".c_attn.bias"
-
-                else:
-                    final_key = key
-
-                module_name, param_name = final_key.rsplit(".", 1)
-                module = model.get_submodule(module_name)
-
-                try:
-                    current_parameter_tensor = module._parameters[param_name]
-                except KeyError:
-                    current_parameter_tensor = None
-
-                if current_parameter_tensor is not None:
-                    if transpose and (
-                        "c_fc.weight" in key
-                        or "c_proj.weight" in key
-                        or "q_attn.weight" in key
-                        or "kv_attn.weight" in key
-                        or "c_attn.weight" in key
-                    ):
-                        # Transpose as we use nn.Linear instead of Conv1D
-                        value = value.T
-
-                    if current_parameter_tensor.device == torch.device("meta"):
-                        # Init qkv
-                        if "c_attn.weight" in final_key:
-                            module._parameters[param_name] = value.new_empty(
-                                (
-                                    model.transformer.head_size
-                                    * (model.transformer.num_heads + 2),
-                                    value.shape[1],
+            with safe_open(
+                filename, framework="pt", device=str(device) if quantize is None else "cpu"
+            ) as f:
+                for key in f.keys():
+                    value = f.get_tensor(key)
+                    value = value.to(device if quantize is None else "cpu").to(dtype)
+
+                    layer_name = ".".join(key.split(".")[:4])
+
+                    # Fused qkv
+                    if "q_attn.weight" in key or "kv_attn.weight" in key:
+                        final_key = layer_name + ".c_attn.weight"
+                    elif "q_attn.bias" in key or "kv_attn.bias" in key:
+                        final_key = layer_name + ".c_attn.bias"
+
+                    else:
+                        final_key = key
+
+                    module_name, param_name = final_key.rsplit(".", 1)
+                    module = model.get_submodule(module_name)
+
+                    try:
+                        current_parameter_tensor = module._parameters[param_name]
+                    except KeyError:
+                        current_parameter_tensor = None
+
+                    if current_parameter_tensor is not None:
+                        if transpose and (
+                            "c_fc.weight" in key
+                            or "c_proj.weight" in key
+                            or "q_attn.weight" in key
+                            or "kv_attn.weight" in key
+                            or "c_attn.weight" in key
+                        ):
+                            # Transpose as we use nn.Linear instead of Conv1D
+                            value = value.T
+
+                        if current_parameter_tensor.device == torch.device("meta"):
+                            # Init qkv
+                            if "c_attn.weight" in final_key:
+                                module._parameters[param_name] = value.new_empty(
+                                    (
+                                        model.transformer.head_size
+                                        * (model.transformer.num_heads + 2),
+                                        value.shape[1],
+                                    )
                                 )
-                            )
-                        elif "c_attn.bias" in final_key:
-                            module._parameters[param_name] = value.new_empty(
-                                (
-                                    model.transformer.head_size
-                                    * (model.transformer.num_heads + 2)
+                            elif "c_attn.bias" in final_key:
+                                module._parameters[param_name] = value.new_empty(
+                                    (
+                                        model.transformer.head_size
+                                        * (model.transformer.num_heads + 2)
+                                    )
                                 )
-                            )
 
-                    # Copy to correct slice
-                    if "q_attn.weight" in key:
-                        module._parameters[param_name][: value.shape[0]] = value
-                    elif "q_attn.bias" in key:
-                        module._parameters[param_name][: value.shape[0]] = value
-                    elif "kv_attn.weight" in key:
-                        module._parameters[param_name][
-                            model.transformer.head_size * model.transformer.num_heads :
-                        ] = value
-                    elif "kv_attn.bias" in key:
-                        module._parameters[param_name][
-                            model.transformer.head_size * model.transformer.num_heads :
-                        ] = value
+                        # Copy to correct slice
+                        if "q_attn.weight" in key:
+                            module._parameters[param_name][: value.shape[0]] = value
+                        elif "q_attn.bias" in key:
+                            module._parameters[param_name][: value.shape[0]] = value
+                        elif "kv_attn.weight" in key:
+                            module._parameters[param_name][
+                                model.transformer.head_size * model.transformer.num_heads :
+                            ] = value
+                        elif "kv_attn.bias" in key:
+                            module._parameters[param_name][
+                                model.transformer.head_size * model.transformer.num_heads :
+                            ] = value
+                        else:
+                            if current_parameter_tensor.shape != value.shape:
+                                raise ValueError(
+                                    f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
+                                )
+                            module._parameters[param_name] = value
                     else:
-                        if current_parameter_tensor.shape != value.shape:
-                            raise ValueError(
-                                f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
-                            )
-                        module._parameters[param_name] = value
-                else:
-                    module._buffers[param_name] = value
+                        module._buffers[param_name] = value
 
-                del value
+                    del value
+
+        if model.lm_head.weight.device == torch.device("meta"):
+            model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
 
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model : {uninitialized_parameters}"
+            )
+
     def decode(self, generated_ids: List[int]) -> str:
         # Do not skip special tokens as they are used for custom parsing rules of the generated text
         return self.tokenizer.decode(
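The heart of the change is replacing `torch.load`, which materializes an entire `.bin` state dict in host memory, with `safe_open`, which reads tensors lazily one key at a time. A standalone sketch of that access pattern, assuming a local `.safetensors` file (`safe_open`, `keys()`, and `get_tensor()` are the standard safetensors API):

import torch
from safetensors import safe_open

def iter_weights(filename: str, device: str = "cpu", dtype=torch.float16):
    # safe_open maps the file; tensors are only materialized on get_tensor(),
    # so peak host memory stays near one tensor instead of a full state dict.
    with safe_open(filename, framework="pt", device=device) as f:
        for key in f.keys():
            yield key, f.get_tensor(key).to(dtype)

The fused-qkv logic is otherwise unchanged: `q_attn`/`kv_attn` tensors are copied into slices of a single `c_attn` parameter sized `head_size * (num_heads + 2)`, and the new trailing check fails fast if any parameter is still on the `meta` device after loading.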
@@ -389,6 +399,8 @@ def load_weights(
                     else:
                         module._buffers[param_name] = tensor
 
-        model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
+        if model.lm_head.weight.device == torch.device("meta"):
+            model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
+
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
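The sharded loader gains the same guard as the single-shard path: `lm_head` is tied to the token embeddings only when the checkpoint supplied no `lm_head.weight` of its own, i.e. the parameter is still on the `meta` device from `init_empty_weights`. A minimal sketch of the guard, assuming a meta-initialized model:

import torch

def maybe_tie_lm_head(model: torch.nn.Module) -> None:
    # A parameter left on "meta" was never loaded from the checkpoint; fall
    # back to tying lm_head to the input embeddings (wte).
    if model.lm_head.weight.device == torch.device("meta"):
        model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)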
