
Commit b09d4cc

port huggingface#3188 to gaudi backend
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent: 8394776

File tree: 11 files changed, +892 additions, -510 deletions


backends/gaudi/server/text_generation_server/models/__init__.py

Lines changed: 6 additions & 5 deletions
@@ -83,9 +83,6 @@
 from text_generation_server.models.custom_modeling.flash_neox_modeling import (
     FlashGPTNeoXForCausalLM,
 )
-from text_generation_server.models.pali_gemma import (
-    PaliGemmaBatch,
-)
 from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
     PaliGemmaForConditionalGeneration,
 )
@@ -153,7 +150,6 @@
 )

 VLM_BATCH_TYPES = {
-    PaliGemmaBatch,
     FlashVlmCausalLMBatch,
     FlashMllamaCausalLMBatch,
 }
@@ -635,6 +631,7 @@ def get_model(
             default_dtype=torch.bfloat16,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
+            support_chunking=False,
         )
     elif model_type == BAICHUAN:
         return FlashCausalLM(
@@ -784,6 +781,8 @@ def get_model(
             kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
+            # TODO: Fix bug in rust image_text_replacement implementation
+            support_chunking=False,
         )
     elif model_type == QWEN2_5_VL:
         return FlashVlmCausalLM(
@@ -799,6 +798,8 @@ def get_model(
             lora_adapter_ids=lora_adapter_ids,
             config_class=Qwen2_5_VLConfig,
             processor_class=Qwen2_5_VLProcessor,
+            # TODO: Fix bug in rust image_text_replacement implementation
+            support_chunking=False,
         )
     elif model_type == QWEN3:
         return FlashCausalLM(
@@ -824,6 +825,7 @@ def get_model(
             default_dtype=torch.bfloat16,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
+            support_chunking=False,
         )
     elif model_type == IDEFICS2:
         return FlashVlmCausalLM(
@@ -868,7 +870,6 @@ def get_model(
             default_dtype=torch.bfloat16,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
-            batch_class=PaliGemmaBatch,
         )
     elif model_type == LLAVA_NEXT:
         return FlashVlmCausalLM(
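
Note: the functional additions in this file are the repeated support_chunking=False arguments, which opt these model types out of chunked prefill (the accompanying TODO points at a bug in the Rust image_text_replacement implementation). As a rough, hedged illustration of what such a constructor flag typically gates, the sketch below uses made-up names (plan_prefill, PrefillPlan); it is not TGI's actual scheduler API.

# Illustrative only: a per-model flag that disables chunked prefill.
# All names here are hypothetical and do not exist in text-generation-inference.
from dataclasses import dataclass
from typing import List


@dataclass
class PrefillPlan:
    chunks: List[List[int]]


def plan_prefill(token_ids: List[int], support_chunking: bool, chunk_size: int = 4) -> PrefillPlan:
    """Split a prompt into prefill chunks only when the model supports it."""
    if not support_chunking:
        # Models constructed with support_chunking=False prefill the whole prompt in one pass.
        return PrefillPlan(chunks=[token_ids])
    return PrefillPlan(
        chunks=[token_ids[i : i + chunk_size] for i in range(0, len(token_ids), chunk_size)]
    )


print(plan_prefill(list(range(10)), support_chunking=False).chunks)  # one chunk
print(plan_prefill(list(range(10)), support_chunking=True).chunks)   # chunks of at most 4 tokens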

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py

Lines changed: 104 additions & 91 deletions
@@ -163,111 +163,124 @@ def _merge_input_ids_with_image_features(
         )
         return inputs_embeds

-    def forward(
+    def get_vision_embeds(
         self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        cu_seqlen_prefill: Optional[torch.Tensor],
-        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        slots: torch.Tensor,
-        seqlen: Seqlen,
-        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
-        lm_head_indices: Optional[torch.Tensor] = None,
-        pixel_values: torch.FloatTensor = None,
-        # Unused for this model
-        pixel_attention_mask=None,
-        image_sizes: Optional[torch.LongTensor] = None,
-        adapter_data: Optional[torch.Tensor] = None,
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
     ):
-        inputs_embeds = self.text_model.embed_tokens(input_ids)
-        if pixel_values is not None and len(pixel_values) > 0:
-            # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
-            # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
-            # 1. Extract the input embeddings
-
-            # 2. Merge text and images
-            num_images, num_patches, channels, height, width = pixel_values.shape
-            pixel_values = pixel_values.view(
-                num_images * num_patches, channels, height, width
-            )
-            image_features = self.vision_tower(pixel_values)
+        # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
+        # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
+        # 1. Extract the input embeddings
+
+        # 2. Merge text and images
+        num_images, num_patches, channels, height, width = pixel_values.shape
+        pixel_values = pixel_values.view(
+            num_images * num_patches, channels, height, width
+        )
+        image_features = self.vision_tower(pixel_values)

-            # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
-            # Already done within the clip model
-            selected_image_feature = image_features.last_hidden_state
+        # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
+        # Already done within the clip model
+        selected_image_feature = image_features.last_hidden_state

-            if self.config.vision_feature_select_strategy == "default":
-                selected_image_feature = selected_image_feature[:, 1:]
-            elif self.config.vision_feature_select_strategy == "full":
-                selected_image_feature = selected_image_feature
-            else:
-                raise RuntimeError(
-                    f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
-                )
+        if self.config.vision_feature_select_strategy == "default":
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif self.config.vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+        else:
+            raise RuntimeError(
+                f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
+            )

-            image_features = self.multi_modal_projector(selected_image_feature)
+        image_features = self.multi_modal_projector(selected_image_feature)

-            # split up image_features for each of the individual images
-            # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
-            # if we assume each image has 5 image features (base image + 4 patches)
-            split_sizes = [num_patches] * num_images
-            image_features = torch.split(image_features, split_sizes, dim=0)
+        # split up image_features for each of the individual images
+        # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
+        # if we assume each image has 5 image features (base image + 4 patches)
+        split_sizes = [num_patches] * num_images
+        image_features = torch.split(image_features, split_sizes, dim=0)

-            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-            height = width = (
-                self.config.vision_config.image_size
-                // self.config.vision_config.patch_size
-            )
+        # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+        height = width = (
+            self.config.vision_config.image_size // self.config.vision_config.patch_size
+        )

-            new_image_features = []
-            for image_idx, image_feature in enumerate(image_features):
-                if image_feature.shape[0] > 1:
-                    base_image_feature = image_feature[0]
-                    image_feature = image_feature[1:]
-
-                    if height * width != base_image_feature.shape[0]:
-                        raise ValueError(
-                            "The number of patches is not consistent with the image size."
-                        )
-
-                    # Dimensions are intentionally swapped to be bug-compatible with
-                    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
-                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                        image_sizes[image_idx],
-                        self.config.image_grid_pinpoints,
-                        self.config.vision_config.image_size,
-                    )
-                    image_feature = image_feature.view(
-                        num_patch_height, num_patch_width, height, width, -1
+        new_image_features = []
+        for image_idx, image_feature in enumerate(image_features):
+            if image_feature.shape[0] > 1:
+                base_image_feature = image_feature[0]
+                image_feature = image_feature[1:]
+
+                if height * width != base_image_feature.shape[0]:
+                    raise ValueError(
+                        "The number of patches is not consistent with the image size."
                     )
-                    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                    image_feature = unpad_image(image_feature, image_sizes[image_idx])
-                    image_feature = torch.cat(
-                        (
-                            image_feature,
-                            self.image_newline[:, None, None].expand(
-                                *image_feature.shape[:-1], 1
-                            ),
+
+                # Dimensions are intentionally swapped to be bug-compatible with
+                # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                    image_sizes[image_idx],
+                    self.config.image_grid_pinpoints,
+                    self.config.vision_config.image_size,
+                )
+                image_feature = image_feature.view(
+                    num_patch_height, num_patch_width, height, width, -1
+                )
+                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                image_feature = torch.cat(
+                    (
+                        image_feature,
+                        self.image_newline[:, None, None].expand(
+                            *image_feature.shape[:-1], 1
                         ),
-                        dim=-1,
-                    )
-                    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                    image_feature = torch.cat(
-                        (base_image_feature, image_feature), dim=0
-                    )
-                else:
-                    image_feature = image_feature[0]
-                    image_feature = torch.cat(
-                        (image_feature, self.image_newline[None]), dim=0
-                    )
-                new_image_features.append(image_feature)
-            image_features = torch.stack(new_image_features, dim=0)
+                    ),
+                    dim=-1,
+                )
+                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+            else:
+                image_feature = image_feature[0]
+                image_feature = torch.cat(
+                    (image_feature, self.image_newline[None]), dim=0
+                )
+            new_image_features.append(image_feature)
+        image_features = torch.stack(new_image_features, dim=0)
+        return image_features.view(-1, image_features.shape[-1])
+
+    def get_inputs_embeds(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeds: torch.Tensor = None,
+        pixel_values: torch.FloatTensor = None,
+        image_sizes: Optional[torch.LongTensor] = None,
+    ):
+        inputs_embeds = self.text_model.embed_tokens(input_ids)

+        if vision_embeds is not None:
+            # When we generate, we don't want to replace the potential image_token_id that we generated by images
+            # that simply don't exist
             inputs_embeds = self._merge_input_ids_with_image_features(
-                input_ids, inputs_embeds, image_features
+                input_ids, inputs_embeds, vision_embeds
             )
+        return inputs_embeds
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        slots: torch.Tensor,
+        seqlen: Seqlen,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+        lm_head_indices: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
+    ):

         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
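
Note: this file replaces the old monolithic forward with a three-step interface: get_vision_embeds encodes the images once, get_inputs_embeds scatters those embeddings over the image placeholder tokens, and forward then runs the text model on embeddings only. The sketch below is a toy illustration of that calling order, not the actual FlashVlmCausalLM wiring; ToyVlm, its shapes, and the placeholder token id are made up for illustration.

# Hypothetical sketch of the new three-step call pattern; not TGI code.
import torch


class ToyVlm:
    """Stub with the same three-method surface as the patched model classes."""

    def __init__(self, hidden_size: int = 16, image_token_id: int = 7, vocab_size: int = 32):
        self.image_token_id = image_token_id
        self.embed = torch.nn.Embedding(vocab_size, hidden_size)
        self.tokens_per_image = 3  # pretend each image expands to 3 placeholder tokens

    def get_vision_embeds(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # The real model runs a vision tower + projector; here we fake the
        # flattened (num_image_tokens, hidden_size) tensor it returns.
        num_images = pixel_values.shape[0]
        return torch.randn(num_images * self.tokens_per_image, self.embed.embedding_dim)

    def get_inputs_embeds(self, input_ids: torch.Tensor, vision_embeds=None) -> torch.Tensor:
        inputs_embeds = self.embed(input_ids)
        if vision_embeds is not None:
            mask = input_ids == self.image_token_id
            inputs_embeds[mask] = vision_embeds  # one vision row per image placeholder token
        return inputs_embeds

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        return inputs_embeds.sum(dim=-1)  # stand-in for the language model


with torch.no_grad():
    vlm = ToyVlm()
    pixel_values = torch.randn(1, 3, 224, 224)            # one image
    input_ids = torch.tensor([1, 7, 7, 7, 2, 5])           # three image placeholder tokens
    vision_embeds = vlm.get_vision_embeds(pixel_values)    # 1) encode images once
    inputs_embeds = vlm.get_inputs_embeds(input_ids, vision_embeds)  # 2) splice into token embeddings
    logits = vlm.forward(inputs_embeds)                    # 3) run the text model on embeddings only
    print(logits.shape)  # torch.Size([6])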

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py

Lines changed: 32 additions & 21 deletions
@@ -62,43 +62,54 @@ def __init__(self, prefix, config, weights):
         self.pad_token_id = (
             config.pad_token_id if config.pad_token_id is not None else -1
         )
+        self.dtype = weights.dtype

-    def forward(
+    def get_vision_embeds(
+        self,
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+    ):
+        pixel_values = pixel_values.to(dtype=self.dtype)
+        image_outputs = self.vision_tower(pixel_values)
+        last_hidden_state = self.post_vision_tower_layernorm(
+            image_outputs.last_hidden_state
+        )
+        image_features = self.multi_modal_projector(last_hidden_state)
+        image_features = image_features.view(-1, image_features.shape[-1])
+        return image_features
+
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
+        vision_embeds: torch.Tensor = None,
+    ):
+        inputs_embeds = self.text_model.embed_tokens(input_ids)
+
+        if vision_embeds is not None:
+            mask = input_ids == self.config.image_token_index
+            inputs_embeds[mask] = vision_embeds
+
+        return inputs_embeds
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         slots: torch.Tensor,
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor] = None,
-        pixel_values: torch.FloatTensor = None,
-        # Unused here
-        pixel_attention_mask: Optional[torch.BoolTensor] = None,
-        image_sizes: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        image_grid_thw: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        inputs_embeds = self.text_model.embed_tokens(input_ids)
         # TODO This is odd but apparently pali gemma position ids start at 1.
         if cu_seqlen_prefill is not None:
            position_ids += 1

-        if pixel_values is not None:
-            pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
-            image_outputs = self.vision_tower(pixel_values)
-            last_hidden_state = self.post_vision_tower_layernorm(
-                image_outputs.last_hidden_state
-            )
-            image_features = self.multi_modal_projector(last_hidden_state)
-
-            # mask where image or padding tokens
-            mask = input_ids == self.config.image_token_index
-
-            # insert image features into input embeddings
-            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
-
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
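
Note: PaliGemma gets the same split, and its get_inputs_embeds relies on a boolean masked scatter: every position in input_ids equal to image_token_index receives, in order, one row of vision_embeds. A standalone toy example of that scatter is below; the sizes and the token id value are made up, not the real config values.

# Toy illustration of the masked scatter used by get_inputs_embeds.
# The number of image tokens in input_ids must equal the number of vision rows.
import torch

image_token_index = 99   # hypothetical placeholder id
hidden = 8

input_ids = torch.tensor([99, 99, 99, 5, 6, 7])        # 3 image tokens, then text tokens
inputs_embeds = torch.zeros(input_ids.shape[0], hidden)
vision_embeds = torch.arange(3 * hidden, dtype=torch.float32).view(3, hidden)

mask = input_ids == image_token_index                  # shape (seq_len,)
assert int(mask.sum()) == vision_embeds.shape[0]       # one embedding row per image token
inputs_embeds[mask] = vision_embeds                    # rows 0..2 now hold vision features

print(inputs_embeds[:4])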
