@@ -163,111 +163,124 @@ def _merge_input_ids_with_image_features(
             )
         return inputs_embeds
 
-    def forward(
+    def get_vision_embeds(
         self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        cu_seqlen_prefill: Optional[torch.Tensor],
-        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        slots: torch.Tensor,
-        seqlen: Seqlen,
-        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
-        lm_head_indices: Optional[torch.Tensor] = None,
-        pixel_values: torch.FloatTensor = None,
-        # Unused for this model
-        pixel_attention_mask=None,
-        image_sizes: Optional[torch.LongTensor] = None,
-        adapter_data: Optional[torch.Tensor] = None,
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
     ):
-        inputs_embeds = self.text_model.embed_tokens(input_ids)
-        if pixel_values is not None and len(pixel_values) > 0:
-            # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
-            # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
-            # 1. Extract the input embeddings
-
-            # 2. Merge text and images
-            num_images, num_patches, channels, height, width = pixel_values.shape
-            pixel_values = pixel_values.view(
-                num_images * num_patches, channels, height, width
-            )
-            image_features = self.vision_tower(pixel_values)
+        # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
+        # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
+        # 1. Extract the input embeddings
+
+        # 2. Merge text and images
+        num_images, num_patches, channels, height, width = pixel_values.shape
+        pixel_values = pixel_values.view(
+            num_images * num_patches, channels, height, width
+        )
+        image_features = self.vision_tower(pixel_values)
 
-            # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
-            # Already done within the clip model
-            selected_image_feature = image_features.last_hidden_state
+        # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
+        # Already done within the clip model
+        selected_image_feature = image_features.last_hidden_state
 
-            if self.config.vision_feature_select_strategy == "default":
-                selected_image_feature = selected_image_feature[:, 1:]
-            elif self.config.vision_feature_select_strategy == "full":
-                selected_image_feature = selected_image_feature
-            else:
-                raise RuntimeError(
-                    f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
-                )
+        if self.config.vision_feature_select_strategy == "default":
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif self.config.vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+        else:
+            raise RuntimeError(
+                f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
+            )
 
-            image_features = self.multi_modal_projector(selected_image_feature)
+        image_features = self.multi_modal_projector(selected_image_feature)
 
-            # split up image_features for each of the individual images
-            # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
-            # if we assume each image has 5 image features (base image + 4 patches)
-            split_sizes = [num_patches] * num_images
-            image_features = torch.split(image_features, split_sizes, dim=0)
+        # split up image_features for each of the individual images
+        # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
+        # if we assume each image has 5 image features (base image + 4 patches)
+        split_sizes = [num_patches] * num_images
+        image_features = torch.split(image_features, split_sizes, dim=0)
 
-            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-            height = width = (
-                self.config.vision_config.image_size
-                // self.config.vision_config.patch_size
-            )
+        # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+        height = width = (
+            self.config.vision_config.image_size // self.config.vision_config.patch_size
+        )
 
-            new_image_features = []
-            for image_idx, image_feature in enumerate(image_features):
-                if image_feature.shape[0] > 1:
-                    base_image_feature = image_feature[0]
-                    image_feature = image_feature[1:]
-
-                    if height * width != base_image_feature.shape[0]:
-                        raise ValueError(
-                            "The number of patches is not consistent with the image size."
-                        )
-
-                    # Dimensions are intentionally swapped to be bug-compatible with
-                    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
-                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                        image_sizes[image_idx],
-                        self.config.image_grid_pinpoints,
-                        self.config.vision_config.image_size,
-                    )
-                    image_feature = image_feature.view(
-                        num_patch_height, num_patch_width, height, width, -1
+        new_image_features = []
+        for image_idx, image_feature in enumerate(image_features):
+            if image_feature.shape[0] > 1:
+                base_image_feature = image_feature[0]
+                image_feature = image_feature[1:]
+
+                if height * width != base_image_feature.shape[0]:
+                    raise ValueError(
+                        "The number of patches is not consistent with the image size."
                     )
-                    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                    image_feature = unpad_image(image_feature, image_sizes[image_idx])
-                    image_feature = torch.cat(
-                        (
-                            image_feature,
-                            self.image_newline[:, None, None].expand(
-                                *image_feature.shape[:-1], 1
-                            ),
+
+                # Dimensions are intentionally swapped to be bug-compatible with
+                # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                    image_sizes[image_idx],
+                    self.config.image_grid_pinpoints,
+                    self.config.vision_config.image_size,
+                )
+                image_feature = image_feature.view(
+                    num_patch_height, num_patch_width, height, width, -1
+                )
+                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                image_feature = torch.cat(
+                    (
+                        image_feature,
+                        self.image_newline[:, None, None].expand(
+                            *image_feature.shape[:-1], 1
                         ),
-                        dim=-1,
-                    )
-                    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                    image_feature = torch.cat(
-                        (base_image_feature, image_feature), dim=0
-                    )
-                else:
-                    image_feature = image_feature[0]
-                    image_feature = torch.cat(
-                        (image_feature, self.image_newline[None]), dim=0
-                    )
-                new_image_features.append(image_feature)
-            image_features = torch.stack(new_image_features, dim=0)
+                    ),
+                    dim=-1,
+                )
+                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+            else:
+                image_feature = image_feature[0]
+                image_feature = torch.cat(
+                    (image_feature, self.image_newline[None]), dim=0
+                )
+            new_image_features.append(image_feature)
+        image_features = torch.stack(new_image_features, dim=0)
+        return image_features.view(-1, image_features.shape[-1])
+
+    def get_inputs_embeds(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeds: torch.Tensor = None,
+        pixel_values: torch.FloatTensor = None,
+        image_sizes: Optional[torch.LongTensor] = None,
+    ):
+        inputs_embeds = self.text_model.embed_tokens(input_ids)
 
+        if vision_embeds is not None:
+            # When we generate, we don't want to replace the potential image_token_id that we generated by images
+            # that simply don't exist
             inputs_embeds = self._merge_input_ids_with_image_features(
-                input_ids, inputs_embeds, image_features
+                input_ids, inputs_embeds, vision_embeds
             )
+        return inputs_embeds
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        slots: torch.Tensor,
+        seqlen: Seqlen,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+        lm_head_indices: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
+    ):
 
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
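For orientation, a minimal caller-side sketch of how the three split-out methods are meant to compose during prefill. This is not part of the diff; the model, batch, kv_cache, seqlen and hpu_attention_meta names are illustrative placeholders, not identifiers from the repository.

# Hypothetical wiring of the refactored API; all caller names are assumed.
vision_embeds = None
if batch.pixel_values is not None and len(batch.pixel_values) > 0:
    # Run the vision tower and projector once for the images in the batch.
    vision_embeds = model.get_vision_embeds(
        pixel_values=batch.pixel_values,
        pixel_attention_mask=batch.pixel_attention_mask,
        image_sizes=batch.image_sizes,
    )

# Embed the token ids and splice the precomputed image features over the image tokens.
inputs_embeds = model.get_inputs_embeds(
    input_ids=batch.input_ids,
    vision_embeds=vision_embeds,
)

# The language-model pass now consumes embeddings only, so decode steps
# can skip the vision tower entirely.
outputs = model.forward(
    inputs_embeds=inputs_embeds,
    position_ids=batch.position_ids,
    cu_seqlen_prefill=batch.cu_seqlen_prefill,
    kv_cache=kv_cache,
    slots=batch.slots,
    seqlen=seqlen,
    hpu_attention_meta=hpu_attention_meta,
)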