Running SDy Passes on Non-Solved Graphs in tt-mlir
This article describes a problem encountered while uplifting Shardy in Tenstorrent's tt-mlir project, and the proposed fix: running the SDy (Shardy) passes only on graphs that have not already been solved. The goal is to keep the compiler's existing behavior intact while the new Shardy version is integrated.
Background and Context
The team is uplifting Shardy, OpenXLA's MLIR-based tensor partitioning system that tt-mlir depends on, to a newer version. A recent change to Shardy's insert-explicit-reshard pass makes it insert a reshard operation between a function argument and the sdy.manual_computation op that consumes it whenever the argument's sharding does not match the op's in_shardings. To illustrate, consider the following example:
Example Scenario
Before the uplift, the code might look like this:
func.func public @main(%arg0: tensor<1x1024x128x1024xf32> {ttcore.shard_status = #ttcore.shard_status<unsharded>}) -> (tensor<1x1024x128x1024xf32> {jax.result_info = "", sdy.sharding = #sdy.sharding<@mesh, [{?}, {"x", ?}, {?}, {?}]>, ttcore.shard_status = #ttcore.shard_status<unsharded>}) {
  %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{}, {"x"}, {}, {}]>] out_shardings=[<@mesh, [{}, {"x"}, {}, {}]>] manual_axes={"x", "y"} (%arg1: tensor<1x512x128x1024xf32>) {
    %1 = stablehlo.negate %arg1 : tensor<1x512x128x1024xf32>
    sdy.return %1 : tensor<1x512x128x1024xf32>
  } : (tensor<1x1024x128x1024xf32>) -> tensor<1x1024x128x1024xf32>
  return %0 : tensor<1x1024x128x1024xf32>
}
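The shapes in this listing follow a simple rule: inside manual_computation, each dimension of the global tensor is divided by the product of the sizes of the manual mesh axes it is sharded over. The Python sketch below checks that arithmetic for the example. The mesh sizes are assumptions (the IR implies "x" has size 2 because 1024 becomes 512, while the size of "y" is not visible and is a placeholder), and `local_shape` is a hypothetical helper, not a Shardy API.

```python
from math import prod

# Assumed mesh sizes: the listing implies |x| = 2 (1024 -> 512).
# The size of "y" is not visible in the IR; 2 is a placeholder.
MESH = {"x": 2, "y": 2}


def local_shape(global_shape, dim_axes, mesh):
    """Per-device shape seen inside a manual_computation body.

    dim_axes[i] lists the mesh axes dimension i is split over, so the MLIR
    sharding [{}, {"x"}, {}, {}] becomes [[], ["x"], [], []].
    Axes that shard no dimension (here "y") leave the shape unchanged:
    the data is replicated along them.
    """
    if len(global_shape) != len(dim_axes):
        raise ValueError("sharding rank does not match tensor rank")
    shape = []
    for size, axes in zip(global_shape, dim_axes):
        factor = prod(mesh[axis] for axis in axes)  # empty product is 1
        if size % factor:
            raise ValueError(f"dimension of size {size} is not divisible by {factor}")
        shape.append(size // factor)
    return tuple(shape)


print(local_shape((1, 1024, 128, 1024), [[], ["x"], [], []], MESH))
# prints (1, 512, 128, 1024)
```

This matches the block argument type tensor<1x512x128x1024xf32> in the listing.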
Notice that %arg0 in @main carries no sdy.sharding annotation (only ttcore.shard_status), whereas the manual_computation op that consumes it declares in_shardings=[<@mesh, [{}, {"x"}, {}, {}]>]. After applying the insert-explicit-reshard pass, the code transforms to:
func.func public @main(%arg0: tensor<1x1024x128x1024xf32> {ttcore.shard_status = #ttcore.shard_status<unsharded>}) -> (tensor<1x1024x128x1024xf32> {jax.result_info = "", sdy.sharding = #sdy.sharding<@mesh, [{?}, {"x", ?}, {?}, {?}]>, ttcore.shard_status = #ttcore.shard_status<unsharded>}) {
  %0 = sdy.reshard %arg0 <@mesh, [{}, {"x"}, {}, {}]> : tensor<1x1024x128x1024xf32>
  %1 = sdy.manual_computation(%0) in_shardings=[<@mesh, [{}, {"x"}, {}, {}]>] out_shardings=[<@mesh, [{}, {"x"}, {}, {}]>] manual_axes={"x", "y"} (%arg1: tensor<1x512x128x1024xf32>) {
    %2 = stablehlo.negate %arg1 : tensor<1x512x128x1024xf32>
    sdy.return %2 : tensor<1x512x128x1024xf32>
  } : (tensor<1x1024x128x1024xf32>) -> tensor<1x1024x128x1024xf32>
  return %1 : tensor<1x1024x128x1024xf32>
}
The pass inserts an explicit sdy.reshard on %arg0 so that the value entering manual_computation matches its in_shardings. The new behavior is meant to make sharding transitions explicit and consistent, but it introduces a problem.
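The boundary check described above can be modeled in a few lines. The Python below is a toy model for illustration only, not Shardy's implementation: `ManualComputation`, `fmt` and `insert_explicit_reshards` are hypothetical names, and it assumes an argument without sdy.sharding is treated as fully replicated, which is what the example implies. It compares each argument's sharding with the consuming op's in_shardings and emits an sdy.reshard when they differ.

```python
from dataclasses import dataclass

# Per-dimension mesh axes: ((), ("x",), (), ()) models <@mesh, [{}, {"x"}, {}, {}]>.
Sharding = tuple[tuple[str, ...], ...]


@dataclass
class ManualComputation:
    operand: str           # SSA value it consumes, e.g. "%arg0"
    in_sharding: Sharding  # sharding required by its in_shardings


def fmt(sharding: Sharding) -> str:
    """Render a sharding in the MLIR syntax used in the listings."""
    dims = ("{" + ", ".join(f'"{axis}"' for axis in axes) + "}" for axes in sharding)
    return "<@mesh, [" + ", ".join(dims) + "]>"


def insert_explicit_reshards(arg_shardings: dict, ops: list) -> list:
    """Return the sdy.reshard lines to add (tensor type suffix omitted)
    and rewire each mismatched op to consume the resharded value.

    Assumption: an argument missing from arg_shardings is fully replicated.
    """
    inserted = []
    for op in ops:
        current = arg_shardings.get(op.operand)
        if current is None:
            current = tuple(() for _ in op.in_sharding)  # replicated
        if current != op.in_sharding:
            new_value = f"%{len(inserted)}"  # toy SSA numbering
            inserted.append(f"{new_value} = sdy.reshard {op.operand} {fmt(op.in_sharding)}")
            op.operand = new_value
    return inserted


mc = ManualComputation("%arg0", ((), ("x",), (), ()))
for line in insert_explicit_reshards({}, [mc]):
    print(line)  # %0 = sdy.reshard %arg0 <@mesh, [{}, {"x"}, {}, {}]>
print(mc.operand)  # %0
```

If %arg0 already carried the sharding [{}, {"x"}, {}, {}], the check would find no mismatch and insert nothing.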
The Problem: Impact on Solved Graphs
The core issue is that the inserted reshards break the compiler's handling of solved graphs: graphs that have already been optimized and for which a sharding strategy has already been determined. In these graphs the sharding is meant to be final, so introducing reshards at this stage can invalidate prior optimizations and lead to unexpected behavior.
Because the sharding strategy is a critical part of optimizing performance on Tenstorrent's hardware, this interference cannot simply be accepted. A solution is needed that leaves solved graphs intact while still allowing the Shardy uplift to proceed.
Proposed Solution: Selective Pass Execution
To address this issue, the proposed solution wraps the Shardy passes in a Tenstorrent-specific pass. The wrapper acts as a gatekeeper: it runs the wrapped passes only on non-solved graphs and skips them on solved ones. This lets the team continue the Shardy uplift while preserving the existing behavior for graphs that are already solved.
Defining Solved Graphs
Currently, a graph is considered "solved" if it contains a manual_computation operation with manual axes. This criterion is defined in the ShardyUtils.cpp file within the tt-mlir repository. Essentially, the presence of a manual computation operation with specified manual axes indicates that a sharding strategy has been explicitly defined for that graph. Therefore, any further modifications to the sharding, such as the insertion of explicit reshards, should be carefully controlled.
In simpler terms, if the sharding across the hardware has already been decided (indicated by the presence of a manual_computation with manual axes), we don't want the Shardy passes to change it. That is why the Shardy passes run only on graphs that haven't been solved yet.
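The criterion amounts to a walk over the IR looking for one kind of op. The sketch below models it in Python on a hypothetical `Op` tree; the real check lives in ShardyUtils.cpp and operates on MLIR operations, and its name and signature are not reproduced here. `Op`, `walk` and `is_solved` are illustrative names only.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    name: str                                         # e.g. "func.func", "sdy.manual_computation"
    manual_axes: tuple[str, ...] = ()                 # only set on sdy.manual_computation
    body: list["Op"] = field(default_factory=list)    # nested ops (regions flattened)


def walk(ops):
    """Pre-order walk over ops and everything nested inside them."""
    for op in ops:
        yield op
        yield from walk(op.body)


def is_solved(module_ops) -> bool:
    """Solved if any sdy.manual_computation declares at least one manual axis."""
    return any(
        op.name == "sdy.manual_computation" and len(op.manual_axes) > 0
        for op in walk(module_ops)
    )


solved = [Op("func.func", body=[
    Op("sdy.manual_computation", manual_axes=("x", "y"),
       body=[Op("stablehlo.negate"), Op("sdy.return")]),
])]
unsolved = [Op("func.func", body=[Op("stablehlo.negate")])]
print(is_solved(solved), is_solved(unsolved))  # True False
```

Walking nested bodies matters because manual_computation ops appear inside functions rather than at the top level of the module.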
Implementation Details
The implementation of this solution involves creating a new pass that encapsulates the existing Shardy passes. This new pass will then include a check to determine whether the input graph is considered "solved" based on the criteria mentioned above. If the graph is deemed solved, the Shardy passes will be skipped. Otherwise, they will be executed as intended. This ensures that the new resharding logic is applied only to graphs that haven't yet undergone sharding optimization.
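The control flow of such a wrapper is small. The Python below models it with plain callables; `make_gated_pipeline`, `has_manual_axes`, the stand-in pass function and the dict-based module are all hypothetical. In tt-mlir itself the wrapper would be an MLIR pass, for example one that builds a nested OpPassManager holding the Shardy passes and runs it only when the solved check fails; that C++ is not shown here.

```python
from typing import Callable, List

Module = dict                      # toy stand-in for an MLIR module
Pass = Callable[[Module], None]


def make_gated_pipeline(is_solved: Callable[[Module], bool],
                        passes: List[Pass]) -> Callable[[Module], bool]:
    """Wrap `passes` so they run only on non-solved modules.

    The returned callable reports whether the wrapped passes ran.
    """
    def run(module: Module) -> bool:
        if is_solved(module):
            return False  # solved graph: keep the existing sharding untouched
        for p in passes:
            p(module)
        return True
    return run


# Hypothetical stand-ins for the real solved check and a Shardy pass.
def has_manual_axes(module: Module) -> bool:
    return module["has_manual_axes"]


def insert_explicit_reshard(module: Module) -> None:
    module["log"].append("insert-explicit-reshard")


pipeline = make_gated_pipeline(has_manual_axes, [insert_explicit_reshard])
solved = {"has_manual_axes": True, "log": []}
fresh = {"has_manual_axes": False, "log": []}
print(pipeline(solved), solved["log"])  # False []
print(pipeline(fresh), fresh["log"])    # True ['insert-explicit-reshard']
```

Returning a flag instead of silently skipping makes it easy to log or test which graphs the wrapper left alone.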
Benefits of this Approach
This approach offers several key benefits:
- Preserves Existing Functionality: By selectively running the Shardy passes, the existing behavior for solved graphs is maintained. This prevents any unexpected performance regressions or functional issues.
- Enables Shardy Uplift: The team can continue with the Shardy uplift without being blocked by the solved graph issue. This allows for the integration of new features and improvements to the compiler.
- Provides a Stop-Gap Solution: This approach serves as a temporary solution while a more comprehensive approach for handling explicit reshards in solved graphs is developed.
Think of it like this: it's a temporary fix that allows us to keep moving forward while we figure out the best long-term solution. We're not sacrificing current functionality for new features; we're carefully balancing both.
Long-Term Solution
While the selective pass execution provides an immediate solution, a long-term strategy is necessary to fully address the handling of explicit reshards in solved graphs. The team plans to modify the compiler to correctly handle these reshards, ensuring that they are properly integrated into the optimization process. This will involve analyzing the impact of reshards on existing sharding strategies and developing algorithms to optimize their placement and execution.
Future Enhancements
Future enhancements may include:
- Reshard Optimization: Developing techniques to optimize the placement and execution of reshard operations, minimizing their overhead.
- Sharding Strategy Adaptation: Implementing mechanisms to adapt existing sharding strategies in response to the introduction of reshards.
- Graph Analysis: Enhancing graph analysis capabilities to better understand the impact of reshards on performance.
Ultimately, the goal is to seamlessly integrate explicit reshards into the compilation process without disrupting the performance of solved graphs. This will require a deeper understanding of the interactions between reshards and sharding strategies, as well as the development of sophisticated optimization techniques.
Alternatives Considered
Other alternatives were considered but judged less suitable. One option would have been to revert the change to the insert-explicit-reshard pass, but that would have blocked the Shardy uplift and prevented the integration of new features. Another would have been to implement a comprehensive solution for handling reshards in solved graphs right away, but that would have required significant development effort and delayed the uplift.
In short, the team weighed the pros and cons of various approaches and determined that selectively running the Shardy passes was the most pragmatic solution for the current situation.
Conclusion
The challenge of running SDy passes only on non-solved graphs in Tenstorrent's tt-mlir highlights the care needed when integrating upstream changes into a compiler. The proposed solution, which wraps the Shardy passes and executes them only when a graph is not yet solved, mitigates the immediate issue while allowing the Shardy uplift to proceed. In the long term, the team plans a more comprehensive approach that integrates explicit reshards into the optimization process without hurting the performance of solved graphs.