FSDP and tensor-parallel compatibility #36
@asahni04 Hi! Thanks for the support!

Q1: A weight like this from a 3D conv or 2D conv needs to be put in which group? torch.Size([1024, 3, 1, 4, 4])
Q2: What about norm.scales?
Q3: Can my embedding layers have both norms and MLPs, or should I put them in the MLP set for the MLP parts only?
Q4: Can MLP layers have norms in between, or should I put only the MLP parts in the MLP set? The MLP layers have norms in, say, the feedforward blocks.
Q5: Where should I put adaptive layer normalization layers? They have an MLP to compute scale and shift.
Q6: Say my output layer is a block composed of norms, linears, etc. Do I need to split them into different sets like mlp_names, etc.?

In summary, you might try the following:
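A minimal sketch of one possible grouping, assuming the name-set attributes (`embd_names`, `output_names`, `mlp_names`, ...) described in the Adam-mini README; the toy module names below are placeholders, so adapt the keywords to your own parameter names and verify the attributes against your installed adam_mini version:

```python
import torch.nn as nn
from adam_mini import Adam_mini  # assumes the pip package `adam-mini`

# Toy stand-in for a model with a conv patch embedding, feedforward blocks,
# an adaptive-LayerNorm modulation MLP, and a norm + linear output head.
model = nn.ModuleDict({
    "patch_embed": nn.Conv3d(3, 1024, kernel_size=(1, 4, 4)),  # weight: [1024, 3, 1, 4, 4]
    "feedforward": nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)),
    "adaLN_mlp":   nn.Linear(1024, 2 * 1024),                  # produces scale and shift
    "final_norm":  nn.LayerNorm(1024),
    "final_layer": nn.Linear(1024, 1000),
})

optimizer = Adam_mini(
    named_parameters=model.named_parameters(),
    lr=1e-4,
    betas=(0.9, 0.95),
    weight_decay=0.1,
    dim=1024,    # hidden size of the Transformer blocks
    n_heads=16,  # number of attention heads
)

# Route non-Transformer layers by keyword match against their parameter names.
optimizer.embd_names.add("patch_embed")    # conv embedding -> embedding set
optimizer.output_names.add("final_layer")  # output head -> output set
optimizer.mlp_names.add("feedforward")     # MLP weights inside feedforward blocks
optimizer.mlp_names.add("adaLN_mlp")       # MLPs that compute adaptive-norm scale/shift
# Norm scales/biases and other 1D parameters can usually be left to the default handling.
```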
@zyushun I am running into this error when trying with FSDP2 only in torchtitan.
Would something like this work?
Also, after making this change I barely see any memory savings over AdamW. @zyushun, any suggestions why this happens? Did you notice memory savings with torchtitan? Edit: I saw some issues stating that Adam-mini keeps its states in fp32 while AdamW's are in bf16?
@nighting0le01 If you are up for it, you can debug the memory using the PyTorch memory snapshot tool. For example:
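A minimal sketch, assuming the PyTorch memory-snapshot utilities `torch.cuda.memory._record_memory_history` and `torch.cuda.memory._dump_snapshot` (underscore-prefixed, so exact signatures may vary between PyTorch versions):

```python
import torch

# Start recording CUDA allocation history (call this before building the model/optimizer).
torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run a few training steps with AdamW and, separately, with Adam-mini ...

# Dump the recorded history; the file can be opened at https://pytorch.org/memory_viz
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# Stop recording.
torch.cuda.memory._record_memory_history(enabled=None)
```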
You can upload your snapshot to https://pytorch.org/memory_viz to visualize it. You will be able to see all of the memory allocations with stack traces from where they were allocated. From this, you can see which optimizer states are taking more memory and check how that matches against your expectation.
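As a rough back-of-envelope for that expectation (my own sketch, not from the thread, assuming fp32 optimizer states): AdamW stores two states per parameter (m and v), while Adam-mini keeps m and reduces v to roughly one scalar per block, so the optimizer-state saving should approach ~50% of the optimizer memory; per-GPU savings shrink further once the states are sharded across many ranks.

```python
def optimizer_state_gib(n_params: int, bytes_per_state: int = 4, shards: int = 1) -> dict:
    """Rough optimizer-state footprint per GPU, ignoring per-block scalars and allocator overhead."""
    gib = 1024 ** 3
    adamw = 2 * n_params * bytes_per_state / shards / gib      # m and v, one value per parameter
    adam_mini = 1 * n_params * bytes_per_state / shards / gib  # m only; v is ~one scalar per block
    return {"adamw_GiB": round(adamw, 2), "adam_mini_GiB": round(adam_mini, 2)}

# Example: a 1B-parameter model with fp32 optimizer states sharded over 8 GPUs.
print(optimizer_state_gib(1_000_000_000, bytes_per_state=4, shards=8))
# -> {'adamw_GiB': 0.93, 'adam_mini_GiB': 0.47}
```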
@zyushun Hi, I noticed something similar: I also don't see any difference in memory when switching between AdamW and Adam-mini. Any recommendations? I will try to profile, but have you noticed any of the above issues and errors?
Hi @asahni04 and @nighting0le01! Thanks for the feedback. I have three questions that require confirmation from you both.
Hi @zyushun, thanks for your help here.
Hi @nighting0le01, thanks for the follow-up.
To sum up, I cannot tell why there is no memory saving so far... Perhaps the model is very small and you are using too many cards to shard it?
Hi, @zyushun
Adam-mini looks promising, but in regard to #14, could you please help me debug and set up the following, @zyushun?
A weight like this from a 3D conv or 2D conv needs to be put in which group?
torch.Size([1024, 3, 1, 4, 4])
Also, what about norm.scales?
Can my embedding layers have both norms and MLPs, or should I put them in the MLP set for the MLP parts only?
Can MLP layers have norms in between, or should I put only the MLP parts in the MLP set? The MLP layers have norms in, say, the feedforward blocks.
Where should I put adaptive layer normalization layers? They have an MLP to compute scale and shift.
Say my output layer is a block composed of norms, linears, etc. Do I need to split them into different sets like mlp_names, etc.?
Also, I get the error below, raised from File "/opt/conda/lib/python3.11/site-packages/adam_mini/adam_mini.py", line 320, in step:
[rank4]: super().optimizer_step(*args, **kwargs)
[rank4]: File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/core/module.py", line 1306, in optimizer_step
[rank4]: optimizer.step(closure=optimizer_closure)
[rank4]: File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/core/optimizer.py", line 153, in step
[rank4]: step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 238, in optimizer_step
[rank4]: return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/precision.py", line 122, in optimizer_step
[rank4]: return optimizer.step(closure=closure, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 137, in wrapper
[rank4]: return func.__get__(opt, opt.__class__)(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/optim/optimizer.py", line 487, in wrapper
[rank4]: out = func(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank4]: return func(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/adam_mini/adam_mini.py", line 320, in step
[rank4]: stepsize = ((1 / bias_correction_1) / h).view(neuron_per_gpu, 1)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/_compile.py", line 32, in inner
[rank4]: return disable_fn(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank4]: return fn(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 340, in torch_dispatch
[rank4]: return DTensor._op_dispatcher.dispatch(
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 169, in dispatch
[rank4]: self.sharding_propagator.propagate(op_info)
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 201, in propagate
[rank4]: OutputSharding, self.propagate_op_sharding(op_info.schema)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 46, in call
[rank4]: return self.cache(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 214, in propagate_op_sharding_non_cached
[rank4]: out_tensor_meta = self._propagate_tensor_meta(op_schema)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 124, in _propagate_tensor_meta
[rank4]: fake_out = op_schema.op(*fake_args, **fake_kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/_ops.py", line 716, in call
[rank4]: return self._op(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/utils/_stats.py", line 21, in wrapper
[rank4]: return fn(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1238, in torch_dispatch
[rank4]: return self.dispatch(func, types, args, kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1692, in dispatch
[rank4]: return self._cached_dispatch_impl(func, types, args, kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1339, in _cached_dispatch_impl
[rank4]: output = self._dispatch_impl(func, types, args, kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2013, in _dispatch_impl
[rank4]: r = func(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/opt/conda/lib/python3.11/site-packages/torch/_ops.py", line 716, in call
[rank4]: return self._op(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: RuntimeError: shape '[1024, 1]' is invalid for input of size 16384