jayelm / gisting

Learning to Compress Prompts with Gist Tokens - https://arxiv.org/abs/2304.08467

Why is there a parameter "offset"?

pengfeiwu1999 opened this issue

commented

query_states, key_states = apply_rotary_pos_emb(

The apply_rotary_pos_emb() function doesn't seem to accept an offset argument:

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can squeeze them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

The gist_offset parameter is needed because, when caching an instruction, the model needs to know the length of the cached instruction to apply the position embeddings correctly. (It's not needed for T5 due to T5's relative position embedding scheme). You can see how the position embeddings are shifted here:

gisting/src/gist_llama.py

Lines 114 to 125 in acd78b4

offset = 0
if gist_offset is not None:
    # XXX: THIS WILL BREAK FOR BATCH SIZE > 1.
    offset = gist_offset[0].item()
return (
    self.cos_cached[:, :, offset : offset + seq_len, ...].to(
        dtype=x.dtype, device=x.device
    ),
    self.sin_cached[:, :, offset : offset + seq_len, ...].to(
        dtype=x.dtype, device=x.device
    ),
)

apply_rotary_pos_emb does not require the gist offset argument because it transforms the cos and sin tensors, which already have the offset applied.
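To see why pre-shifted cos/sin tables are enough, here is a simplified, self-contained sketch in plain PyTorch (not the repo's code). It builds LLaMA-style cos/sin caches, rotates a whole sequence in one pass, then rotates only the tokens that follow a cached prefix: once with tables sliced at the prefix length and once without. `build_rope_cache`, `rope_tables`, and `apply_rope` are hypothetical names; the base of 10000 and the tensor shapes are assumptions chosen to mirror the excerpts above.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim in half and map (x1, x2) -> (-x2, x1), as in LLaMA-style RoPE.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def build_rope_cache(max_len: int, dim: int, base: float = 10000.0):
    # Hypothetical helper: cos/sin tables of shape [1, 1, max_len, dim],
    # analogous to cos_cached / sin_cached in the excerpt above.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_len).float()
    freqs = torch.outer(t, inv_freq)  # [max_len, dim / 2]
    emb = torch.cat((freqs, freqs), dim=-1)  # [max_len, dim]
    return emb.cos()[None, None], emb.sin()[None, None]


def rope_tables(cos_cached, sin_cached, seq_len: int, offset: int = 0):
    # Slice positions [offset, offset + seq_len), like the gist_offset shift.
    return (
        cos_cached[:, :, offset : offset + seq_len, ...],
        sin_cached[:, :, offset : offset + seq_len, ...],
    )


def apply_rope(x, cos, sin):
    # No offset argument: cos/sin already correspond to the right absolute positions.
    return (x * cos) + (rotate_half(x) * sin)


torch.manual_seed(0)
bsz, heads, dim = 1, 2, 8
prefix_len, suffix_len = 5, 3
cos_cached, sin_cached = build_rope_cache(max_len=32, dim=dim)
x = torch.randn(bsz, heads, prefix_len + suffix_len, dim)  # stand-in for q or k

# Full pass: the whole sequence is embedded at positions 0..7.
cos, sin = rope_tables(cos_cached, sin_cached, prefix_len + suffix_len)
full = apply_rope(x, cos, sin)

# Cached pass: only the suffix is processed, so its tables must start at prefix_len.
suffix = x[:, :, prefix_len:]
cos_s, sin_s = rope_tables(cos_cached, sin_cached, suffix_len, offset=prefix_len)
shifted = apply_rope(suffix, cos_s, sin_s)

# Without the offset, the suffix would be embedded at positions 0..2 instead.
cos_0, sin_0 = rope_tables(cos_cached, sin_cached, suffix_len)
unshifted = apply_rope(suffix, cos_0, sin_0)

print(torch.allclose(full[:, :, prefix_len:], shifted))    # True
print(torch.allclose(full[:, :, prefix_len:], unshifted))  # False
```

Only the offset-sliced tables reproduce the full-pass embeddings, which is the property the gist_offset shift is meant to preserve when earlier tokens are served from a cache instead of being recomputed.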

Note the offset parameter is not used during standard training or evaluation, because we don't actually modify any sequence lengths—the model gets the entire instruction/input/output in one go, with attention masking used to control compression, and the position embeddings are correctly applied with the original instruction length.
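A minimal sketch of "attention masking used to control compression", assuming a single example laid out as [instruction][gist tokens][input/output] and a boolean mask where True means "may attend". `make_gist_mask` is a hypothetical name, and the repo's actual mask construction (padding, batching, the separate attention_mask_gist argument) may differ. Because no tokens are removed, every token keeps its original position, which is why no offset is needed here.

```python
import torch


def make_gist_mask(seq_len: int, gist_start: int, num_gist_tokens: int) -> torch.Tensor:
    """Hypothetical helper: boolean [seq_len, seq_len] mask, True = query may attend to key."""
    # Start from an ordinary causal (lower-triangular) mask.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    gist_end = gist_start + num_gist_tokens
    # Queries after the gist tokens may not look at the instruction before them,
    # so the instruction can only influence them through the gist tokens.
    mask[gist_end:, :gist_start] = False
    return mask


# Instruction at positions 0-2, one gist token at position 3, input/output at 4-6.
mask = make_gist_mask(seq_len=7, gist_start=3, num_gist_tokens=1)
print(mask.int())
```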

The offset parameter is only used in compress.py—the GistActivations class accepts a "gist offset" argument which records the length of the instruction before caching:

gisting/src/gist_llama.py

Lines 633 to 656 in acd78b4

@torch.no_grad()
def get_gist_activations(
    self,
    input_ids: torch.LongTensor,
    attention_mask: torch.FloatTensor,
    attention_mask_gist: torch.FloatTensor,
    gist_token: int,
    num_gist_tokens: int,
    cache_all: bool = False,
) -> GistActivations:
    model_outputs = self.model(
        input_ids,
        attention_mask=attention_mask,
        attention_mask_gist=attention_mask_gist,
        output_hidden_states=True,
        use_cache=True,
    )
    return GistActivations.from_model_outputs(
        model_outputs=model_outputs,
        input_ids=input_ids,
        gist_token=gist_token,
        num_gist_tokens=num_gist_tokens,
        cache_all=cache_all,
    )

gisting/src/gist_caching.py

Lines 112 to 113 in acd78b4

# Keep track of gist starts for the batch.
gist_starts.append(gist_start)
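The caching code records one gist start per example, while the rotary excerpt reads only gist_offset[0] and is flagged as breaking for batch size > 1. A hedged sketch of both sides, assuming a row's gist start is the index of its first `gist_token`: `find_gist_starts` and `single_offset` are hypothetical helpers, not functions from the repo.

```python
import torch


def find_gist_starts(input_ids: torch.Tensor, gist_token: int) -> list[int]:
    """Hypothetical helper: index of the first gist token in each row of [bsz, seq_len]."""
    starts = []
    for row in input_ids:
        positions = (row == gist_token).nonzero(as_tuple=True)[0]
        if positions.numel() == 0:
            raise ValueError("row contains no gist token")
        # Keep one gist start per example in the batch.
        starts.append(int(positions[0]))
    return starts


def single_offset(gist_offset: torch.Tensor) -> int:
    # Mirrors `offset = gist_offset[0].item()`, but refuses to silently
    # apply row 0's offset to rows that need a different one.
    if not torch.all(gist_offset == gist_offset[0]):
        raise ValueError("rows have different offsets; a single scalar offset would be wrong")
    return int(gist_offset[0].item())


GIST = 99  # hypothetical gist token id
batch = torch.tensor([
    [5, 6, 7, GIST, 8, 9],  # gist token at index 3
    [5, 6, GIST, 8, 9, 0],  # gist token at index 2
])
starts = find_gist_starts(batch, GIST)
print(starts)  # [3, 2]
```

Here `single_offset(torch.tensor(starts))` would raise, since the two rows disagree; with batch size 1 (or identical offsets) the scalar shift is safe.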

commented

Thanks for the reply, but my question is: apply_rotary_pos_emb() doesn't take an offset parameter, yet in the code I mentioned above (line 206 of gisting/src/gist_llama.py) it is called with offset as an argument. Doesn't that cause an error?

commented

So when I run the LLaMA model, I get "apply_rotary_pos_emb() got an unexpected keyword argument 'offset'".

commented

The error occurs during my LLaMA training stage:
  File "/data/wupf/gisting/src/gist_llama.py", line 206, in forward
    query_states, key_states = apply_rotary_pos_emb(
TypeError: apply_rotary_pos_emb() got an unexpected keyword argument 'offset'

Are you using the version of transformers specified in requirements.txt, specifically commit fb366b9a? You'll see that the function signature is different:

https://github.com/huggingface/transformers/blob/fb366b9a/src/transformers/models/llama/modeling_llama.py#L136-L141

(and more generally if you run into other issues it could be because you're not using the package versions specified in requirements.txt)
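A quick way to check which signature is installed, without starting a training run, is to inspect it. A minimal sketch using the standard-library `inspect` module; `accepts_kwarg` is a hypothetical helper, and the import path assumes the usual `transformers.models.llama.modeling_llama` module layout.

```python
import inspect


def accepts_kwarg(fn, name: str) -> bool:
    """Return True if `fn` can be called with keyword argument `name`."""
    for p in inspect.signature(fn).parameters.values():
        # A **kwargs catch-all accepts any keyword.
        if p.kind == p.VAR_KEYWORD:
            return True
        # A named parameter must not be positional-only to be passed by keyword.
        if p.name == name and p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY):
            return True
    return False


try:
    from transformers.models.llama import modeling_llama
except ImportError:
    print("transformers (with LLaMA support) is not importable here")
else:
    ok = accepts_kwarg(modeling_llama.apply_rotary_pos_emb, "offset")
    print("apply_rotary_pos_emb accepts offset:", ok)
```

If this prints False, the installed transformers is not the pinned commit, and the call at line 206 of gist_llama.py will fail with the TypeError above.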

commented

I tried to install transformers at commit fb366b9a, but I can't install it on my server. All other packages are installed according to requirements.txt except transformers.

Unfortunately the codebase is only verified to work with commit fb366b9a. You might be able to get around this specific issue by just pasting in the apply_rotary_pos_emb function from the link above instead of importing it from modeling_llama, but I can't guarantee you won't run into additional issues.
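For illustration, a local, offset-aware replacement might look like the sketch below. It assumes the fb366b9a version slices cos and sin along the sequence dimension by `offset` before rotating; that is an assumption, so copy the real function from the linked lines rather than relying on this sketch.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Rotate half the hidden dims: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    # ASSUMPTION: mirrors the older offset-based signature; cos/sin are
    # [1, 1, seq_len_total, dim] and the query block starts at `offset`.
    cos = cos[..., offset : q.shape[-2] + offset, :]
    sin = sin[..., offset : q.shape[-2] + offset, :]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

Defining this in gist_llama.py instead of importing it from modeling_llama is the workaround suggested above; other incompatibilities with newer transformers versions may still surface.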

commented

OK, I fixed it, thanks!

Great!