load_state_dict error when loading from MP models resharded from a singleton
Created by: jxmsML
❓ Questions and Help
I'm running `interactive_cli` on a `reshard.pt` (resharded from a singleton consolidated file) and hit a `load_state_dict` error: the model (`transformer_lm_megatron`) reports missing keys `qkv_proj.weight` and `qkv_proj.bias`. In `consolidated.pt`, however, `state['model']` contains the unconcatenated `q_proj`, `k_proj`, `v_proj` entries (saved when consolidating the original FSDP shards into a singleton, see https://github.com/facebookresearch/metaseq/blob/4629c56c467c1c40ef518a86f12799062e2551fa/metaseq/distributed/stitch_fsdp_ckpt.py#L287-L290).
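For reference, this is roughly how I confirmed the key mismatch (the checkpoint path is a placeholder for my local file):

```python
import torch

# Placeholder path; substitute your own consolidated checkpoint location.
state = torch.load("consolidated.pt", map_location="cpu")

# Print the attention projection keys in the singleton checkpoint:
# I see q_proj / k_proj / v_proj entries here, but no qkv_proj.
for key in state["model"]:
    if "proj" in key:
        print(key)
```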
I also noticed that Alpa has the same renaming logic; however, Alpa concatenates `q_proj`, `k_proj`, `v_proj` into `qkv_proj` when loading from the `consolidated.pt` singleton file.
So far I can't find any corresponding logic in metaseq that concatenates `q_proj`, `k_proj`, `v_proj` back into `qkv_proj`.
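For context, here is a minimal sketch of what I expected such concatenation to look like. The key names and the plain `dim=0` concatenation are my assumptions; the actual Megatron fused-QKV layout may interleave per attention head, so this is not a verified fix:

```python
import torch

def fuse_qkv(state_dict):
    """Concatenate separate q/k/v projection tensors into fused qkv_proj entries.

    Assumes keys like '...self_attn.q_proj.weight' and that the fused layer
    expects simple concatenation along dim 0; the real Megatron layout may
    differ (e.g. per-head interleaving).
    """
    new_state = dict(state_dict)
    q_keys = [k for k in state_dict if ".q_proj." in k]
    for q_key in q_keys:
        k_key = q_key.replace(".q_proj.", ".k_proj.")
        v_key = q_key.replace(".q_proj.", ".v_proj.")
        qkv_key = q_key.replace(".q_proj.", ".qkv_proj.")
        # Fuse the three projections into a single qkv_proj tensor.
        new_state[qkv_key] = torch.cat(
            [state_dict[q_key], state_dict[k_key], state_dict[v_key]], dim=0
        )
        # Drop the unfused entries so load_state_dict sees no unexpected keys.
        for old in (q_key, k_key, v_key):
            new_state.pop(old)
    return new_state
```

In my case I would expect something like this to be applied to `state['model']` before `load_state_dict` is called on the resharded MP model.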
Before asking:
- search the issues.
- search the docs.
What is your question?
Code
What have you tried?
What's your environment?
- metaseq Version (e.g., 1.0 or master):
- PyTorch Version (e.g., 1.0):
- OS (e.g., Linux):
- How you installed metaseq (`pip`, source):
- Build command you used (if compiling from source):
- Python version:
- CUDA/cuDNN version:
- GPU models and configuration:
- Any other relevant information: