
FSDP and tensor-parallel compatibility #36

Open
asahni04 opened this issue Dec 16, 2024 · 10 comments

Comments


asahni04 commented Dec 16, 2024

Hi, @zyushun

Adam-mini looks promising, but regarding #14, could you please help me debug and set up the following?

Which group should a weight like this, from a 3D conv or 2D conv, be put in?
torch.Size([1024, 3, 1, 4, 4])

Also, what about norm.scales?

Can my embedding layers have both norms and MLPs, or shall I put only the MLP parts in the mlp set?

Can MLP layers have norms in between, or shall I put only the MLP parts in the MLP set? The MLP layers have norms in, say, the feedforward blocks.

Where shall I put adaptive layer normalization layers? They have an MLP to compute scale and shift.

Say my output layer is a block composed of norms, linears, etc. Do I need to split them into different sets like mlp_names, etc.?

Also, I get something like the following, from File "/opt/conda/lib/python3.11/site-packages/adam_mini/adam_mini.py", line 320, in step:

[rank4]: super().optimizer_step(*args, **kwargs)
[rank4]: File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/core/module.py", line 1306, in optimizer_step
[rank4]: optimizer.step(closure=optimizer_closure)
[rank4]: File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/core/optimizer.py", line 153, in step
[rank4]: step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 238, in optimizer_step
[rank4]: return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/precision.py", line 122, in optimizer_step
[rank4]: return optimizer.step(closure=closure, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 137, in wrapper
[rank4]: return func.__get__(opt, opt.__class__)(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/optim/optimizer.py", line 487, in wrapper
[rank4]: out = func(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank4]: return func(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/adam_mini/adam_mini.py", line 320, in step
[rank4]: stepsize = ((1 / bias_correction_1) / h).view(neuron_per_gpu, 1)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/_compile.py", line 32, in inner
[rank4]: return disable_fn(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank4]: return fn(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 340, in torch_dispatch
[rank4]: return DTensor._op_dispatcher.dispatch(
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 169, in dispatch
[rank4]: self.sharding_propagator.propagate(op_info)
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 201, in propagate
[rank4]: OutputSharding, self.propagate_op_sharding(op_info.schema)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 46, in call
[rank4]: return self.cache(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 214, in propagate_op_sharding_non_cached
[rank4]: out_tensor_meta = self._propagate_tensor_meta(op_schema)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 124, in _propagate_tensor_meta
[rank4]: fake_out = op_schema.op(*fake_args, **fake_kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/_ops.py", line 716, in call
[rank4]: return self._op(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/utils/_stats.py", line 21, in wrapper
[rank4]: return fn(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1238, in torch_dispatch
[rank4]: return self.dispatch(func, types, args, kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1692, in dispatch
[rank4]: return self._cached_dispatch_impl(func, types, args, kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1339, in _cached_dispatch_impl
[rank4]: output = self._dispatch_impl(func, types, args, kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2013, in _dispatch_impl
[rank4]: r = func(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/_ops.py", line 716, in call
[rank4]: return self._op(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: RuntimeError: shape '[1024, 1]' is invalid for input of size 16384

zyushun (Owner) commented Dec 16, 2024

@asahni04 Hi! Thanks for the support!

Q1: Which group should a weight like torch.Size([1024, 3, 1, 4, 4]) from a 3D conv or 2D conv be put in?
A1: It should be put in none of the groups; it will then fall into the "else" condition and be treated as "other blocks" (line 315). This treats the whole conv weight as a single block and does not partition it further.
To do this, just make sure NO keyword in the name of your conv layer occurs in the sets mlp_names, wqk_names, etc. (see the example below).

Q2: What about norm.scales?
A2: They should also be treated as "other blocks".

Q3: Can my embedding layers have both norms and MLPs, or shall I put only the MLP parts in the mlp set?
A3: Yes, you can, and the MLPs (or, rigorously speaking, the "nn.Linear()" modules) shall be put in the mlp set; the norms (or, rigorously speaking, the "LayerNorm()" modules) should be treated as "other blocks".

Q4: Can MLP layers have norms in between, or shall I put only the MLP parts in the MLP set? The MLP layers have norms in, say, the feedforward blocks.
A4: Yes, you can have norms in an MLP layer, but they should not occur in the MLP set and shall be treated as "other blocks".

Q5: Where shall I put adaptive layer normalization layers? They have an MLP to compute scale and shift.
A5: Same as above. Please separate the norm and the nn.Linear().

Q6: Say my output layer is a block composed of norms, linears, etc. Do I need to split them into different sets like mlp_names, etc.?
A6: Same as above. Please separate the norm and the nn.Linear().

In summary, you might try the following:

optimizer = Adam_mini(...)
optimizer.mlp_names = {'your keywords in nn.Linear() that do NOT occur in the norm layer or conv layer'}
optimizer.embd_names = {'your keywords in nn.Embedding() that do NOT occur in the norm layer or conv layer'}
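
For concreteness, a fuller sketch for a model that mixes convs, norms, embeddings, and MLPs might look like the following. The constructor arguments and the keyword strings are only illustrative; check them against your Adam-mini version and against the actual names printed by model.named_parameters():

import torch
from adam_mini import Adam_mini

optimizer = Adam_mini(
    named_parameters=model.named_parameters(),
    lr=1e-4,
    betas=(0.9, 0.95),
    weight_decay=0.1,
    dim=1024,
    n_heads=16,
)
# nn.Linear() weights only; these keywords must NOT match any conv or norm parameter name
optimizer.mlp_names = {"ffn.w1", "ffn.w2", "ffn.w3"}
# nn.Embedding() weights only
optimizer.embd_names = {"embed_tokens"}
# conv weights (e.g. torch.Size([1024, 3, 1, 4, 4])) and norm scales match none of the sets,
# so they fall into the "else" branch and are treated as "other blocks"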


nighting0le01 commented Dec 19, 2024

@zyushun I'm running into this error when trying with FSDP2 only in torchtitan:


[rank5]:  <code above cannot be shared> -->  optimizer.step
[rank5]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/core/optimizer.py", line 153, in step
[rank5]:     step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
[rank5]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 238, in optimizer_step
[rank5]:     return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/precision.py", line 122, in optimizer_step
[rank5]:     return optimizer.step(closure=closure, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 137, in wrapper
[rank5]:     return func.__get__(opt, opt.__class__)(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/lib/python3.11/site-packages/torch/optim/optimizer.py", line 487, in wrapper
[rank5]:     out = func(*args, **kwargs)
[rank5]:           ^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank5]:     return func(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/lib/python3.11/site-packages/adam_mini/adam_mini.py", line 293, in step
[rank5]:     state["vmean"] = torch.zeros_like(state["m"][0:state["neuron_per_gpu"], 0:1],
[rank5]:                                       ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/lib/python3.11/site-packages/torch/_compile.py", line 32, in inner
[rank5]:     return disable_fn(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank5]:     return fn(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 340, in __torch_dispatch__
[rank5]:     return DTensor._op_dispatcher.dispatch(
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 169, in dispatch
[rank5]:     self.sharding_propagator.propagate(op_info)
[rank5]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 201, in propagate
[rank5]:     OutputSharding, self.propagate_op_sharding(op_info.schema)
[rank5]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 46, in __call__
[rank5]:     return self.cache(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 450, in propagate_op_sharding_non_cached
[rank5]:     raise NotImplementedError(
[rank5]: NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered.


nighting0le01 commented Dec 19, 2024

@awgu

@nighting0le01

Would something like this work?

    neuron_per_gpu = state["m"].size(0)
    state["vmean"] = p.new_zeros((neuron_per_gpu, 1))
    # state["vmean"] = torch.zeros_like(state["m"][0:state["neuron_per_gpu"], 0:1],
    #                                   memory_format=torch.preserve_format)


nighting0le01 commented Dec 19, 2024

Also, after making this change I barely see any memory savings over AdamW, @zyushun. Any suggestions why this happens? Did you notice memory savings with torchtitan?

Edit: I saw some issues stating that Adam-mini is fp32 and AdamW is bf16?


awgu commented Dec 19, 2024

@nighting0le01 if you are up for it, you can debug the memory using the memory snapshot tool:

For example:

# Include this somewhere early in your script
import pickle

import torch

torch.cuda.memory._record_memory_history()

# After your optimizer step, run something like this
snapshot = torch.cuda.memory._snapshot()
with open("snapshot.pickle", "wb") as f:
    pickle.dump(snapshot, f)

You can upload your snapshot to https://pytorch.org/memory_viz to visualize it. You will be able to see all of the memory allocations with stack traces from where they were allocated. From this, you can see which optimizer states are taking more memory and check how that matches against your expectation.
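
As a quicker, coarser check, you could also compare allocated CUDA memory right around the first optimizer.step(), which is typically when optimizer states are lazily allocated. A rough sketch, run on a single rank:

import torch

torch.cuda.synchronize()
before = torch.cuda.memory_allocated()

optimizer.step()  # the first step lazily allocates the optimizer states

torch.cuda.synchronize()
after = torch.cuda.memory_allocated()
print(f"approx. optimizer-state memory on this rank: {(after - before) / 1024**2:.1f} MiB")

With FSDP the states are sharded, so each rank only sees its own shard; compare the same rank for AdamW and Adam-mini.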

@asahni04 (Author)

@zyushun Hi, I noticed something similar: I also don't see any difference in memory when switching between AdamW and Adam-mini. Any recommendations? I will try to profile, but have you noticed any of the above issues and errors?

    neuron_per_gpu = state["m"].size(0)
    state["vmean"] = p.new_zeros((neuron_per_gpu, 1))

zyushun (Owner) commented Dec 21, 2024

Hi @asahni04 and @nighting0le01! Thanks for the feedback.

I have three questions that require confirmation from you guys.

1. Do you still encounter errors after you make this code change?

    neuron_per_gpu = state["m"].size(0)
    state["vmean"] = p.new_zeros((neuron_per_gpu, 1))
    # state["vmean"] = torch.zeros_like(state["m"][0:state["neuron_per_gpu"], 0:1],
    #                                   memory_format=torch.preserve_format)

2. In your case, do the 2D matrices flatten into 1D vectors? You can check this by reading the log from Adam-mini. I would hope the answer is no, since we only support 2D matrices.

3. Did you double-check that both AdamW and Adam-mini use the same precision, such as float32?

@nighting0le01

Hi @zyushun, thanks for your help here.

1. I can get it to run after making the code change, but there are literally no memory savings. I read through the logs, and it seems no weight matrices are flattened.
2. Did you have any comparisons between the optimizers? Is AdamW supposed to run in bf16? I can debug and check what precision Adam-mini runs in, but do we support lower precision for Adam-mini?

zyushun (Owner) commented Dec 21, 2024

Hi @nighting0le01, thanks for the follow-up.

1. Yes, we have compared, and we see a reasonable amount of memory drop. See the wandb log below for Llama 2-13B.

[wandb screenshot: memory comparison for Llama 2-13B (memory_compare1221)]

2. I think by default in torchtitan, both optimizers run in bf16 (for weights and gradients) + float32 (for m and v). See this issue here. You can also double-check using the same methods in that issue (see the sketch below).

3. "Do we support lower precision for Adam-mini?" We have not implemented it yet. This is another ongoing research direction.

To sum up, I cannot tell why there is no memory saving so far... Perhaps the model is very small and you are using too many cards to shard it?
