
Commit 37787ac

Merge branch 'vllm-project:main' into main
2 parents: 8431e82 + c66c7f8

2 files changed: +20 -14 lines changed

vllm/model_executor/models/paligemma.py

Lines changed: 2 additions & 5 deletions
@@ -9,7 +9,6 @@
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -133,12 +132,10 @@ class PaliGemmaMultiModalProjector(nn.Module):
     def __init__(self, vision_hidden_size: int, projection_dim: int):
         super().__init__()

-        self.linear = ColumnParallelLinear(vision_hidden_size,
-                                           projection_dim,
-                                           bias=True)
+        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

     def forward(self, image_features: torch.Tensor) -> torch.Tensor:
-        hidden_states, _ = self.linear(image_features)
+        hidden_states = self.linear(image_features)
         return hidden_states

vllm/worker/tpu_model_runner.py

Lines changed: 18 additions & 9 deletions
@@ -28,7 +28,9 @@

 logger = init_logger(__name__)

-_PAD_SLOT_ID = -1  # NOTE(woosuk): In PyTorch XLA, index -1 is ignored.
+# Here we utilize the behavior that out-of-bound index is ignored.
+# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
+_PAD_SLOT_ID = 1_000_000_000
 # FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
 _ENABLE_TOP_P = False
 # FIXME(woosuk): A temporary hack to support `n > 1`.
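The new sentinel leans on the behavior noted in the comment: a write through an out-of-range index is dropped on the XLA backend, so padded positions never touch a real KV-cache slot. A rough, hypothetical illustration of how such a sentinel is used when padding a slot mapping (pad_slot_mapping is not a vLLM function, just a sketch):

import torch

# Sentinel from the diff: any index this large is out of bounds for the
# KV cache, and on the XLA backend the corresponding write is dropped.
_PAD_SLOT_ID = 1_000_000_000


def pad_slot_mapping(slot_mapping: list, padded_len: int) -> torch.Tensor:
    """Pad a per-token slot mapping up to a fixed, compile-friendly length."""
    num_padding = padded_len - len(slot_mapping)
    assert num_padding >= 0, "padded_len must cover all real tokens"
    return torch.tensor(slot_mapping + [_PAD_SLOT_ID] * num_padding,
                        dtype=torch.int64)


# Example: 3 real tokens padded out to a bucket of 8 slots.
print(pad_slot_mapping([17, 18, 19], padded_len=8))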
@@ -414,10 +416,7 @@ def _prepare_sample(
         best_of = []
         for seq_group_metadata in seq_group_metadata_list:
             sampling_params = seq_group_metadata.sampling_params
-            # NOTE(woosuk): Here we mimic argmax sampling by applying a very
-            # low temperature. This is not accurate.
-            t.append(sampling_params.temperature
-                     if sampling_params.temperature >= 1e-5 else 1e-5)
+            t.append(sampling_params.temperature)
             if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                 raise NotImplementedError(
                     "Top-p sampling is currently disabled for the TPU backend "
@@ -678,13 +677,23 @@ def forward(
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, sampling_metadata)

-        logits = logits / t.unsqueeze(dim=1)
+        # Argmax sampling.
+        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
+
+        # Zero temperature means greedy decoding. Avoid division by zero.
+        nonzero_t = torch.where(t != 0, t, 1.0)
+        logits = logits / nonzero_t.unsqueeze(dim=1)
         if _ENABLE_TOP_P:
             logits = _apply_top_p(logits, p.unsqueeze(dim=1))
+
+        # Random sampling.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
-        next_token_ids = torch.multinomial(probs,
-                                           num_samples,
-                                           replacement=True)
+        sampled_token_ids = torch.multinomial(probs,
+                                              num_samples,
+                                              replacement=True)
+        next_token_ids = torch.where(t != 0, sampled_token_ids,
+                                     argmax_token_ids)
         return next_token_ids

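The rewritten body computes both an argmax candidate and a temperature-scaled multinomial sample for every request, then selects per row with torch.where, so there is no data-dependent control flow for XLA to recompile on. A self-contained sketch of the same pattern; the unsqueeze on the temperature mask and all shapes/values are just for this standalone example, not a claim about the runner's internals:

import torch


def sample_tokens(logits: torch.Tensor, t: torch.Tensor,
                  num_samples: int = 1) -> torch.Tensor:
    """Branch-free mix of greedy (t == 0) and temperature sampling (t > 0)."""
    # Greedy candidates, replicated to match the sampled branch's shape.
    argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
    argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

    # Replace zero temperatures with 1.0 to avoid dividing by zero; those
    # rows' probabilities are discarded by the final torch.where anyway.
    nonzero_t = torch.where(t != 0, t, torch.ones_like(t))
    probs = torch.softmax(logits / nonzero_t.unsqueeze(dim=1), dim=-1,
                          dtype=torch.float32)
    sampled_token_ids = torch.multinomial(probs, num_samples, replacement=True)

    # Per-row selection: random sample where t != 0, argmax where t == 0.
    return torch.where(t.unsqueeze(dim=1) != 0, sampled_token_ids,
                       argmax_token_ids)


# Example: first request greedy (t = 0), second sampled at t = 0.8.
logits = torch.randn(2, 32000)
t = torch.tensor([0.0, 0.8])
print(sample_tokens(logits, t))  # shape (2, 1)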
