metaseq issue #89 (Closed)

Issue created May 10, 2022 by Administrator (@root, Owner)

assert key_padding_mask.size(1) == src_len in 350M model

Created by: Mrs-Hudson

Because I am unable to run the CLI with the 350M model (#86 (closed)), I am running the generation task with the ad hoc script below. However, it fails with an assertion error, even though the same data works for next-token prediction (#73 (closed)).

Code

```python
import os

from transformers import GPT2Tokenizer
from metaseq import checkpoint_utils
import torch
import queue
import pkg_resources
import random
import shutil
import threading

from metaseq import options
from metaseq.dataclass.configs import MetaseqConfig
from metaseq.dataclass.utils import convert_namespace_to_omegaconf
from metaseq.distributed import utils as dist_utils
from metaseq.hub_utils import GeneratorInterface
from metaseq.service.queue import PriorityQueueRingShard
from metaseq.service.workers import WorkItem
from metaseq.service.constants import (
    MAX_SEQ_LEN,
    MAX_BATCH_TOKENS,
    DEFAULT_PORT,
    TOTAL_WORLD_SIZE,
    CHECKPOINT_LOCAL,
    CHECKPOINT_FOLDER,
    LAUNCH_ARGS,
)
from metaseq.service.utils import get_my_ip, encode_fn, build_logger
from metaseq.service.responses import OAIResponse

logger = build_logger()
path = "/home/azureuser/350_model_info"

"""
$ ls path
vocab.json merges.txt reshard.pt
"""

tokenizer = GPT2Tokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
tokenizer.save_pretrained(path)

paths = [os.path.join(path, "reshard.pt")]

checkpoint = checkpoint_utils.load_model_ensemble_and_task(
    paths,
    arg_overrides={
        "vocab_filename": os.path.join(path, "vocab.json"),
        "merges_filename": os.path.join(path, "merges.txt"),
    },
)

model = checkpoint[0][0].eval()


# forward passes
def single_batch_forward_logits(prompts):
    input_ids = tokenizer(prompts, return_tensors="pt").input_ids
    input_ids = torch.cat([torch.tensor([[2]]), input_ids], dim=-1)
    logits = model(input_ids)[0]
    return logits


def _copy_checkpoint_cache():
    if CHECKPOINT_LOCAL == CHECKPOINT_FOLDER:
        # user didn't have a local SSD
        return
    if os.path.exists(os.path.dirname(CHECKPOINT_LOCAL)):
        logger.info("Local checkpoint copy already exists, skipping copy")
    else:
        logger.info(
            f"Making a local copy of the checkpoint. {CHECKPOINT_FOLDER} -> {CHECKPOINT_LOCAL}"
        )
        shutil.copytree(CHECKPOINT_FOLDER, os.path.dirname(CHECKPOINT_LOCAL))


def worker_main(cfg1: MetaseqConfig, namespace_args=None):
    shutil.copytree(CHECKPOINT_FOLDER, os.path.dirname(CHECKPOINT_LOCAL))
    # disable multithreading in tokenizers and torch, as different Flask threads
    # may then fight for resources.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    torch.set_num_threads(1)
    global generator
    global MODE
    # make sure generations are stochastic since we have many workers
    torch.manual_seed(random.randint(1, 20000))
    torch.cuda.manual_seed(random.randint(1, 20000))
    MODE = "worker"
    cfg = cfg1
    print("In worker main")
    generator = GeneratorInterface(cfg)
    models = generator.load_model()  # noqa: F841
    print("\n Model loaded \n")
    logger.info(f"loaded model {cfg.distributed_training.distributed_rank}")

    prompts = [
        "Today is a beautiful day and I want to",
        "In the city of",
        "Paris is the capital of France and",
        "Computers and mobile phones have taken",
        "LinkedIn is a great company and I am ",
    ]
    inputs = [encode_fn(generator, p) for p in prompts]
    min_tokens = [5, 5, 5, 5, 5]
    max_tokens = [32, 32, 32, 32, 32]
    print(inputs)
    retval = generator.generate(inputs, min_tokens, max_tokens)
    print(retval)


def cli_main():
    """Hosted version of the web UI for generation."""
    _copy_checkpoint_cache()

    global port, MODE, cfg
    parser = options.get_generation_parser()

    # dumb defaults overriding
    parser.set_defaults(lr_scheduler=None, criterion=None)
    flat_launch_args = []
    for s in LAUNCH_ARGS:
        flat_launch_args += s.split()
    args = options.parse_args_and_arch(parser, input_args=flat_launch_args)
    args.data = os.path.dirname(args.path)  # hardcode the data arg
    port = DEFAULT_PORT
    cfg = convert_namespace_to_omegaconf(args)
    cfg.distributed_training.distributed_world_size = TOTAL_WORLD_SIZE
    print("Calling main\n")
    dist_utils.call_main(cfg, worker_main, namespace_args=args)


if __name__ == "__main__":
    cli_main()
```
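A quick way to check whether batching prompts of different lengths is involved is to call `generator.generate` with one prompt per call, so no padding is needed inside a batch. This is an untested debugging sketch, not a known fix: the helper `generate_one_at_a_time` is hypothetical, and it only reuses the `generate(inputs, min_tokens, max_tokens)` call from the script above without assuming anything about the structure of its return value.

```python
from typing import Any, List, Sequence


def generate_one_at_a_time(
    generator: Any,
    inputs: Sequence[Any],
    min_tokens: Sequence[int],
    max_tokens: Sequence[int],
) -> List[Any]:
    """Hypothetical debugging helper: call generator.generate with batch size 1.

    With a single prompt per call there is no padding inside the batch, so if
    padding is what triggers the shape mismatch (an assumption), it should not fire.
    """
    if not (len(inputs) == len(min_tokens) == len(max_tokens)):
        raise ValueError("inputs, min_tokens and max_tokens must have equal length")
    results = []
    for inp, lo, hi in zip(inputs, min_tokens, max_tokens):
        # Same call shape as in the script: lists of inputs, min and max tokens.
        out = generator.generate([inp], [lo], [hi])
        # Keep the raw per-prompt output; its structure is not assumed here.
        results.append(out)
    return results
```

In `worker_main`, the call would become `retval = generate_one_at_a_time(generator, inputs, min_tokens, max_tokens)`. If the assertion goes away at batch size 1, that narrows the problem to how padded batches are handled during generation; if it persists, padding is probably not the cause.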

Stacktrace:

```
Traceback (most recent call last):
  File "inf_generate.py", line 134, in <module>
    cli_main()
  File "inf_generate.py", line 131, in cli_main
    dist_utils.call_main(cfg, worker_main, namespace_args=args)
  File "/home/azureuser/metaseq/metaseq/distributed/utils.py", line 256, in call_main
    return _spawn_helper(main, cfg, kwargs)
  File "/home/azureuser/metaseq/metaseq/distributed/utils.py", line 234, in _spawn_helper
    retval = distributed_main(-1, main, cfg, kwargs)
  File "/home/azureuser/metaseq/metaseq/distributed/utils.py", line 203, in distributed_main
    main(cfg, **kwargs)
  File "inf_generate.py", line 108, in worker_main
    retval = generator.generate(inputs, min_tokens,max_tokens)
  File "/home/azureuser/metaseq/metaseq/hub_utils.py", line 604, in generate
    translations = self.task.inference_step(generator, self.models, batch)
  File "/home/azureuser/metaseq/metaseq/tasks/language_modeling.py", line 326, in inference_step
    return generator.generate(
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/azureuser/metaseq/metaseq/sequence_generator.py", line 93, in generate
    return self._generate(sample, **kwargs)
  File "/home/azureuser/metaseq/metaseq/sequence_generator.py", line 286, in _generate
    model_out = self.model.decoder(
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/azureuser/metaseq/metaseq/models/transformer.py", line 639, in forward
    x, extra = self.extract_features(
  File "/home/azureuser/metaseq/metaseq/models/transformer.py", line 664, in extract_features
    return self.extract_features_scriptable(
  File "/home/azureuser/metaseq/metaseq/models/transformer.py", line 728, in extract_features_scriptable
    x, layer_attn, _, l_aux_i = layer(
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/azureuser/metaseq/metaseq/modules/transformer_layer.py", line 509, in forward
    x, attn = self.forward_attention(
  File "/home/azureuser/metaseq/metaseq/modules/transformer_layer.py", line 422, in forward_attention
    x, attn = self.self_attn(
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/azureuser/metaseq/metaseq/modules/multihead_attention.py", line 331, in forward
    assert key_padding_mask.size(1) == src_len
AssertionError
```
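The failing check compares the second dimension of `key_padding_mask` with `src_len`, the number of key positions attended over. Assuming metaseq follows fairseq's convention (a mask of shape `(batch, src_len)` where `True` marks padding, with the mask from earlier decoding steps concatenated with the current step's mask during incremental decoding), the mask has to grow in lockstep with the cached keys. The plain-Python sketch below models only that bookkeeping; `check_mask` and `append_step` are hypothetical names, not metaseq functions, and the sketch does not explain the specific sizes that were printed.

```python
from typing import List

# Hypothetical model of a key padding mask: shape (batch, src_len),
# True marks a padded key position that attention should ignore.
Mask = List[List[bool]]


def check_mask(mask: Mask, bsz: int, src_len: int) -> None:
    """Illustrative version of the batch/length shape checks on the mask."""
    assert len(mask) == bsz, f"batch {len(mask)} != {bsz}"
    for row in mask:
        assert len(row) == src_len, f"mask length {len(row)} != src_len {src_len}"


def append_step(prev_mask: Mask, step_mask: Mask) -> Mask:
    """Concatenate the cached mask with the current step's mask along the key axis."""
    assert len(prev_mask) == len(step_mask), "batch sizes differ"
    return [prev + cur for prev, cur in zip(prev_mask, step_mask)]


# Two prompts padded to a common length of 3 (the first has one pad position).
prompt_mask: Mask = [[True, False, False], [False, False, False]]
src_len = 3
check_mask(prompt_mask, bsz=2, src_len=src_len)

# Each decoding step adds one key per sequence, so the mask gains one column.
mask = prompt_mask
for _ in range(2):
    mask = append_step(mask, [[False], [False]])
    src_len += 1
    check_mask(mask, bsz=2, src_len=src_len)  # passes: keys and mask grew together

# If the keys grow but the mask is not extended, the length check fails.
try:
    check_mask(prompt_mask, bsz=2, src_len=4)
except AssertionError as err:
    print("mismatch:", err)  # prints: mismatch: mask length 3 != src_len 4
```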

I also printed the `key_padding_mask` sizes and `src_len`:

```
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 10 10 5
key padding size is: 2 5 21 11 5
```
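Reading the printout as runs of identical lines makes the pattern easier to see: 24 identical lines, then one line with different sizes just before the failure. If the debug print fires once per self-attention call, that would fit one full pass through the decoder layers (OPT-350M has 24) followed by a failure in the first layer of the next step; this reading is an inference, not something the log confirms. The stdlib helper below (`summarize_runs`, a hypothetical name) collapses such logs.

```python
from itertools import groupby
from typing import Iterable, List, Tuple


def summarize_runs(lines: Iterable[str]) -> List[Tuple[str, int]]:
    """Collapse consecutive identical, non-empty log lines into (line, count) pairs."""
    cleaned = (ln.strip() for ln in lines)
    non_empty = (ln for ln in cleaned if ln)  # blank lines do not break a run
    return [(key, sum(1 for _ in group)) for key, group in groupby(non_empty)]


# The printout above: 24 identical lines separated by blanks, then one outlier.
log = ["key padding size is: 2 5 10 10 5", ""] * 24 + ["key padding size is: 2 5 21 11 5"]
for line, count in summarize_runs(log):
    print(f"{count:>3} x {line}")
```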

metaseq Version (e.g., 1.0 or master): 0.0.1

PyTorch Version (e.g., 1.0): '1.10.1+cu113'

OS (e.g., Linux): Linux, Ubuntu 18.04.6 LTS (Bionic Beaver)

How you installed metaseq (pip, source): Same as setup instructions

Build command you used (if compiling from source): Same as setup instructions

Python version: 3.8.5

CUDA/cuDNN version: CUDA 11.2 according to `nvcc --version` (Cuda compilation tools, release 11.2, V11.2.152; build cuda_11.2.r11.2/compiler.29618528_0)

GPU models and configuration: Azure compute node, virtual machine size Standard_ND40rs_v2 (40 cores, 672 GB RAM, 2900 GB disk), with 8 x NVIDIA Tesla V100 GPUs
