metaseq · Issues · #73 (Closed)

Issue created May 09, 2022

125m checkpoint outputting gibberish

Created by: patrickvonplaten

Converting the sharded checkpoints of 125m to a singleton checkpoint with https://github.com/facebookresearch/metaseq/pull/60:

    $ ls 125m
    dict.txt
    gpt2-merges.txt
    gpt2-vocab.json
    reshard-model_part-0.pt
    reshard-model_part-1.pt
    $ python -m metaseq.scripts.convert_to_singleton 125m

gives a new restored.pt file.
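To see what the conversion produced, restored.pt can be inspected directly (a quick sketch; it assumes, consistent with how the file is used below, that it is a flat parameter-name-to-tensor state dict):

    import torch

    restored = torch.load("./restored.pt")

    # restored.pt should map parameter names to full (unsharded) tensors
    print(type(restored))
    for name, tensor in list(restored.items())[:5]:
        print(name, tuple(tensor.shape), tensor.dtype)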

I then transformed the checkpoint into the same format as the 350m checkpoint so that I could test generation on it:

    import torch

    # merge the converted weights back into the shard's checkpoint wrapper;
    # this format allows one to use the standard
    # `checkpoint_utils.load_model_ensemble_and_task` function
    orig_state = torch.load("./reshard-model_part-0.pt")
    model = torch.load("./restored.pt")
    orig_state["model"] = model

    # change the architecture name to "transformer_lm" so the model
    # can be run in a non-CUDA environment
    orig_state["cfg"]["model"]._name = "transformer_lm"
    torch.save(orig_state, "./reshard.pt")
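As a sanity check (a small sketch covering only the fields touched above), the repackaged file can be loaded back to confirm the structure survived the round trip:

    import torch

    ckpt = torch.load("./reshard.pt")

    # the two fields written above should be present and consistent
    print(ckpt["cfg"]["model"]._name)  # expected: "transformer_lm"
    print(len(ckpt["model"]))          # number of parameter tensors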

I tried running an inference example on the model to see whether generation works as expected. Here is the code:

    import os

    import torch
    from metaseq import checkpoint_utils
    from transformers import GPT2Tokenizer

    path = "/home/patrick/add_opt"

    """
    $ ls path
    vocab.json
    merges.txt
    reshard.pt
    """


    tokenizer = GPT2Tokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
    tokenizer.save_pretrained(path)

    paths = [os.path.join(path, "reshard.pt")]

    checkpoint = checkpoint_utils.load_model_ensemble_and_task(
        paths,
        arg_overrides={
            "vocab_filename": os.path.join(path, "vocab.json"),
            "merges_filename": os.path.join(path, "merges.txt"),
        }
    )

    # first element is the list of models; take the (single) model and eval it
    model = checkpoint[0][0].eval()


    # forward passes
    def single_batch_forward_logits(prompts):
        input_ids = tokenizer(prompts, return_tensors="pt").input_ids
        # prepend token id 2 (</s>), which OPT uses as the BOS token
        input_ids = torch.cat([torch.tensor([[2]]), input_ids], dim=-1)
        logits = model(input_ids)[0]
        return logits


    prompts = [
        "Today is a beautiful day and I want to",
        "In the city of",
        "Paris is the capital of France and",
        "Computers and mobile phones have taken",
    ]


    print("Next word generation")
    for prompt in prompts:
        print("-------------")
        print(f"Prompt: {prompt}...\n")
        logits = single_batch_forward_logits(prompt)
        # greedy pick of the most likely next token
        pred_next_token = torch.argmax(logits[0, -1], -1)
        next_token = tokenizer.convert_ids_to_tokens([pred_next_token])
        next_token = next_token[0].replace("Ġ", "")
        print(f"Next word: {next_token}")
        print("-------------")

This sadly gives gibberish:

    Next word generation
    -------------
    Prompt: Today is a beautiful day and I want to...

    Next word: Robbins
    -------------
    -------------
    Prompt: In the city of...

    Next word: of
    -------------
    -------------
    Prompt: Paris is the capital of France and...

    Next word: Robbins
    -------------
    -------------
    Prompt: Computers and mobile phones have taken...

    Next word: Robbins
    -------------
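For reference, the same helper logic extends to a short greedy loop (a sketch reusing the tokenizer and model loaded above; greedy_generate is a hypothetical helper, not a metaseq API), which makes the degenerate output easier to see over several tokens:

    def greedy_generate(prompt, max_new_tokens=8):
        # BOS-prefixed prompt ids, exactly as in single_batch_forward_logits
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = torch.cat([torch.tensor([[2]]), input_ids], dim=-1)
        # repeatedly append the argmax token and re-run the forward pass
        for _ in range(max_new_tokens):
            logits = model(input_ids)[0]
            next_id = torch.argmax(logits[0, -1], -1).reshape(1, 1)
            input_ids = torch.cat([input_ids, next_id], dim=-1)
        # drop the BOS token before decoding
        return tokenizer.decode(input_ids[0, 1:])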

Note that this script works perfectly fine with the 350m checkpoint.
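Since the exact same script works for 350m, the merge step seems the prime suspect. A quick scan of restored.pt for obviously broken tensors (a minimal sketch, again assuming a flat state dict) might help narrow it down:

    import torch

    restored = torch.load("./restored.pt")

    # look for tensors the merge may have corrupted
    for name, t in restored.items():
        if not torch.isfinite(t).all():
            print("non-finite values in", name)
        elif (t == 0).all():
            print("all-zero tensor:", name)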

@stephenroller - any ideas?
