125m checkpoint outputting gibberish
Created by: patrickvonplaten
Converting the sharded checkpoints of 125m to a singleton checkpoint with https://github.com/facebookresearch/metaseq/pull/60:
$ ls 125m
dict.txt
gpt2-merges.txt
gpt2-vocab.json
reshard-model_part-0.pt
reshard-model_part-1.pt
$ python -m metaseq.scripts.convert_to_singleton 125m
gives a new restored.pt file.
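A quick way to sanity-check the conversion is to peek at the restored state dict (a minimal sketch, assuming restored.pt is a flat state dict of tensors, which is how it is used below):

import torch

# quick inspection of the converted checkpoint: print a few parameter
# names, shapes and dtypes, plus the total parameter count
state = torch.load("./restored.pt")
for name, value in list(state.items())[:5]:
    if torch.is_tensor(value):
        print(name, tuple(value.shape), value.dtype)
total = sum(v.numel() for v in state.values() if torch.is_tensor(v))
print(f"{len(state)} entries, {total / 1e6:.1f}M parameters in total")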
I then transformed the checkpoint into the same format as the 350m checkpoint to test generation on it:
import torch

orig_state = torch.load("./reshard-model_part-0.pt")
model = torch.load("./restored.pt")

# this format allows one to use the standard `checkpoint_utils.load_model_ensemble_and_task` function
orig_state["model"] = model
# change the architecture name to "transformer_lm" to be able to run it in a non-CUDA environment
orig_state["cfg"]["model"]._name = "transformer_lm"

torch.save(orig_state, "./reshard.pt")
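As a quick check that this re-packaging step itself is lossless, the "model" entry of reshard.pt should match restored.pt exactly (a minimal sketch, assuming both checkpoints fit into CPU memory):

import torch

saved = torch.load("./reshard.pt")
restored = torch.load("./restored.pt")

# the "model" entry of the re-packaged checkpoint should be exactly the
# restored singleton state dict, parameter for parameter
assert saved["model"].keys() == restored.keys()
for name, value in restored.items():
    if torch.is_tensor(value):
        assert torch.equal(saved["model"][name], value), name
print("reshard.pt matches restored.pt")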
I tried running an inference example on the model to see whether generation works as expected. Here is the code:
import os
from transformers import GPT2Tokenizer
from metaseq import checkpoint_utils
import torch
path = "/home/patrick/add_opt"
"""
$ ls path
vocab.json
merges.txt
reshard.pt
"""
tokenizer = GPT2Tokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
tokenizer.save_pretrained(path)
paths = [os.path.join(path, "reshard.pt")]
checkpoint = checkpoint_utils.load_model_ensemble_and_task(
    paths,
    arg_overrides={
        "vocab_filename": os.path.join(path, "vocab.json"),
        "merges_filename": os.path.join(path, "merges.txt"),
    }
)
model = checkpoint[0][0].eval()
# forward passes
def single_batch_forward_logits(prompts):
    input_ids = tokenizer(prompts, return_tensors="pt").input_ids
    input_ids = torch.cat([torch.tensor([[2]]), input_ids], dim=-1)
    logits = model(input_ids)[0]
    return logits

prompts = [
    "Today is a beautiful day and I want to",
    "In the city of",
    "Paris is the capital of France and",
    "Computers and mobile phones have taken",
]

print("Next word generation")
for prompt in prompts:
    print("-------------")
    print(f"Prompt: {prompt}...\n")
    logits = single_batch_forward_logits(prompt)
    pred_next_token = torch.argmax(logits[0, -1], -1)
    next_token = tokenizer.convert_ids_to_tokens([pred_next_token])
    next_token = next_token[0].replace("Ġ", "")
    print(f"Next word: {next_token}")
    print("-------------")
This sadly gives gibberish:
Next word generation
-------------
Prompt: Today is a beautiful day and I want to...
Next word: Robbins
-------------
-------------
Prompt: In the city of...
Next word: of
-------------
-------------
Prompt: Paris is the capital of France and...
Next word: Robbins
-------------
-------------
Prompt: Computers and mobile phones have taken...
Next word: Robbins
-------------
Note that this script works perfectly fine with the 350m checkpoint.
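In case it helps with debugging, looking at the top-5 candidates instead of only the argmax would show whether the distribution is genuinely degenerate or the argmax is just narrowly off (a sketch reusing the variables from the inference script above):

# inspect the top-5 next-token candidates and their probabilities
for prompt in prompts:
    logits = single_batch_forward_logits(prompt)
    probs = torch.softmax(logits[0, -1], dim=-1)
    top_probs, top_ids = probs.topk(5)
    tokens = tokenizer.convert_ids_to_tokens(top_ids.tolist())
    print(prompt)
    print([(t.replace("Ġ", " "), round(p.item(), 3)) for t, p in zip(tokens, top_probs)])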
@stephenroller - any ideas?