Running SDy Passes Only on Non-Solved Graphs in tt-mlir


This article discusses a specific challenge encountered while uplifting Shardy within the Tenstorrent tt-mlir project and the proposed solution. Specifically, it addresses the issue of running SDy passes only on non-solved graphs. This is crucial for maintaining the functionality of the compiler while integrating new features.

Background and Context

Let's start with the context. The team is uplifting Shardy, OpenXLA's MLIR-based tensor partitioning system that tt-mlir builds on, to a newer version. A recent change to Shardy's insert-explicit-reshard pass makes it insert a reshard operation between a function argument and the sdy.manual_computation op that consumes it whenever the argument's sharding does not match the op's in_shardings. To illustrate, consider the following example:

Example Scenario

Before insert-explicit-reshard runs, the module might look like this:

```mlir
func.func public @main(%arg0: tensor<1x1024x128x1024xf32> {ttcore.shard_status = #ttcore.shard_status<unsharded>}) -> (tensor<1x1024x128x1024xf32> {jax.result_info = "", sdy.sharding = #sdy.sharding<@mesh, [{?}, {"x", ?}, {?}, {?}]>, ttcore.shard_status = #ttcore.shard_status<unsharded>}) {
  %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{}, {"x"}, {}, {}]>] out_shardings=[<@mesh, [{}, {"x"}, {}, {}]>] manual_axes={"x", "y"} (%arg1: tensor<1x512x128x1024xf32>) {
    %1 = stablehlo.negate %arg1 : tensor<1x512x128x1024xf32>
    sdy.return %1 : tensor<1x512x128x1024xf32>
  } : (tensor<1x1024x128x1024xf32>) -> tensor<1x1024x128x1024xf32>
  return %0 : tensor<1x1024x128x1024xf32>
}
```

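Notice that %arg0 in @main carries no sdy.sharding annotation, while the manual_computation op's in_shardings expect its operand to be sharded along mesh axis "x" on dimension 1 (which is why the block argument %arg1 inside the region has a dimension-1 size of 512 rather than 1024). After applying the insert-explicit-reshard pass, the code becomes: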

```mlir
func.func public @main(%arg0: tensor<1x1024x128x1024xf32> {ttcore.shard_status = #ttcore.shard_status<unsharded>}) -> (tensor<1x1024x128x1024xf32> {jax.result_info = "", sdy.sharding = #sdy.sharding<@mesh, [{?}, {"x", ?}, {?}, {?}]>, ttcore.shard_status = #ttcore.shard_status<unsharded>}) {
  %0 = sdy.reshard %arg0 <@mesh, [{}, {"x"}, {}, {}]> : tensor<1x1024x128x1024xf32>
  %1 = sdy.manual_computation(%0) in_shardings=[<@mesh, [{}, {"x"}, {}, {}]>] out_shardings=[<@mesh, [{}, {"x"}, {}, {}]>] manual_axes={"x", "y"} (%arg1: tensor<1x512x128x1024xf32>) {
    %2 = stablehlo.negate %arg1 : tensor<1x512x128x1024xf32>
    sdy.return %2 : tensor<1x512x128x1024xf32>
  } : (tensor<1x1024x128x1024xf32>) -> tensor<1x1024x128x1024xf32>
  return %1 : tensor<1x1024x128x1024xf32>
}
```

An explicit reshard operation is inserted for %arg0 in @main to align with the in_shardings of the manual_computation op. This new behavior, while intended to improve sharding consistency, introduces a problem.
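The rule the pass now applies at the boundary of a manual_computation can be sketched with a toy model. This is a simplified C++ illustration, not Shardy's implementation: the types `FuncArg` and `Reshard` and the function `reshardForOperand` are hypothetical, a sharding is reduced to one list of mesh axes per dimension, open dimensions (`?`) are ignored, and an argument with no `sdy.sharding` attribute is assumed to behave as fully replicated.

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Toy model (hypothetical, not Shardy's data structures): a tensor sharding
// is one list of mesh axes per dimension, so <@mesh, [{}, {"x"}, {}, {}]>
// becomes {{}, {"x"}, {}, {}}. Open dimensions ("?") are not modeled.
using DimSharding = std::vector<std::string>;
using TensorSharding = std::vector<DimSharding>;

// Hypothetical stand-in for a func.func argument that feeds a
// manual_computation. nullopt means "no sdy.sharding attribute".
struct FuncArg {
  std::string name;
  std::optional<TensorSharding> sharding;
};

// Hypothetical record of a reshard to insert: which value, to what sharding.
struct Reshard {
  std::string operand;
  TensorSharding target;
};

// Assumption of this sketch: an unannotated argument is treated as fully
// replicated, i.e. no mesh axes on any dimension.
TensorSharding effectiveSharding(const FuncArg& arg, std::size_t rank) {
  return arg.sharding.value_or(TensorSharding(rank));
}

// Returns the reshard needed so that `arg` matches the manual_computation's
// in_sharding for this operand, or nullopt if the shardings already agree.
std::optional<Reshard> reshardForOperand(const FuncArg& arg,
                                         const TensorSharding& inSharding) {
  if (effectiveSharding(arg, inSharding.size()) == inSharding) {
    return std::nullopt;
  }
  return Reshard{arg.name, inSharding};
}
```

For the example above, `%arg0` has no annotation, so its effective sharding is `{{}, {}, {}, {}}`, which differs from the in_sharding `{{}, {"x"}, {}, {}}`; the sketch therefore reports a reshard of `%arg0` to that target, matching the `sdy.reshard` in the rewritten IR.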

The Problem: Impact on Solved Graphs

The core issue is that the new reshards break tt-mlir's existing handling of solved graphs: graphs for which a sharding strategy has already been determined. In these graphs the sharding is treated as final, so a reshard inserted at this stage can invalidate earlier decisions and lead to unexpected behavior.

Because the sharding strategy is a critical part of optimizing performance on Tenstorrent's hardware, the fix needs to keep this interference out of solved graphs without blocking the Shardy uplift.

Proposed Solution: Selective Pass Execution

To address this issue, the proposed solution involves wrapping the Shardy passes within a Tenstorrent-specific pass. This wrapper will act as a gatekeeper, ensuring that these passes are executed only on non-solved graphs. This approach allows the team to continue with the Shardy uplift while preserving the existing functionality for solved graphs. By selectively applying these passes, the risk of disrupting already optimized graphs is minimized.

Defining Solved Graphs

Currently, a graph is considered "solved" if it contains a manual_computation operation with manual axes. This criterion is defined in the ShardyUtils.cpp file within the tt-mlir repository. Essentially, the presence of a manual computation operation with specified manual axes indicates that a sharding strategy has been explicitly defined for that graph. Therefore, any further modifications to the sharding, such as the insertion of explicit reshards, should be carefully controlled.
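The solved-graph criterion is simple enough to express as a predicate over the ops in a module. The following is a hedged, self-contained model of that idea rather than the code in ShardyUtils.cpp: `Op` and `isSolvedGraph` are hypothetical names, and the module is flattened into a list of ops in walk order so that nested ops are included. In real MLIR code the same check would more likely walk the module and inspect each `sdy.manual_computation` op's manual axes.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Toy IR model (hypothetical): each op has a name and, for
// sdy.manual_computation, the manual axes it declares. The module is
// flattened into its ops in walk order, so nested ops are included.
struct Op {
  std::string name;
  std::vector<std::string> manualAxes;
};

// Models the criterion described above: a graph is "solved" if it contains
// an sdy.manual_computation op with at least one manual axis.
bool isSolvedGraph(const std::vector<Op>& opsInWalkOrder) {
  return std::any_of(opsInWalkOrder.begin(), opsInWalkOrder.end(),
                     [](const Op& op) {
                       return op.name == "sdy.manual_computation" &&
                              !op.manualAxes.empty();
                     });
}
```

Under this model the example module is solved, because its manual_computation declares `manual_axes={"x", "y"}`; a manual_computation with no manual axes would not count.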

In simpler terms, if a sharding strategy is already fixed for the graph (indicated by an sdy.manual_computation with manual axes), we don't want the new pass to touch it. This is why the Shardy passes run only on graphs that haven't been solved yet.

Implementation Details

The implementation of this solution involves creating a new pass that encapsulates the existing Shardy passes. This new pass will then include a check to determine whether the input graph is considered "solved" based on the criteria mentioned above. If the graph is deemed solved, the Shardy passes will be skipped. Otherwise, they will be executed as intended. This ensures that the new resharding logic is applied only to graphs that haven't yet undergone sharding optimization.
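The gating wrapper can be modeled in a few lines as well. This is an illustrative sketch under stated assumptions, not tt-mlir's actual pass: `GatedShardyPipeline`, `Module`, and `Pass` are hypothetical stand-ins, the inner Shardy passes are plain callbacks, and the solved check reimplements the criterion described in the previous section on a toy IR.

```cpp
#include <algorithm>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Toy IR model (hypothetical): ops in walk order stand in for an MLIR module.
struct Op {
  std::string name;
  std::vector<std::string> manualAxes;
};
using Module = std::vector<Op>;
// Hypothetical stand-in for one Shardy pass (e.g. insert-explicit-reshard).
using Pass = std::function<void(Module&)>;

// Models the "solved" criterion described in the article
// (an sdy.manual_computation with manual axes); not the ShardyUtils.cpp code.
bool isSolvedGraph(const Module& module) {
  return std::any_of(module.begin(), module.end(), [](const Op& op) {
    return op.name == "sdy.manual_computation" && !op.manualAxes.empty();
  });
}

// Hypothetical Tenstorrent-side wrapper: runs the wrapped Shardy passes only
// on non-solved graphs. run() reports whether the inner passes executed.
class GatedShardyPipeline {
 public:
  explicit GatedShardyPipeline(std::vector<Pass> innerPasses)
      : innerPasses_(std::move(innerPasses)) {}

  bool run(Module& module) const {
    if (isSolvedGraph(module)) {
      return false;  // Solved graph: leave its sharding untouched.
    }
    for (const Pass& pass : innerPasses_) {
      pass(module);
    }
    return true;
  }

 private:
  std::vector<Pass> innerPasses_;
};
```

In tt-mlir itself the wrapper would presumably be an MLIR pass that either runs a nested pipeline of the Shardy passes or returns early. Keeping the check in a single wrapper means it can be removed in one place once solved graphs can handle explicit reshards.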

Benefits of this Approach

This approach offers several key benefits:

  • Preserves Existing Functionality: By selectively running the Shardy passes, the existing behavior for solved graphs is maintained, so the new reshards cannot cause performance regressions or functional issues in those graphs.
  • Enables Shardy Uplift: The team can continue with the Shardy uplift without being blocked by the solved graph issue. This allows for the integration of new features and improvements to the compiler.
  • Provides a Stop-Gap Solution: This approach serves as a temporary solution while a more comprehensive approach for handling explicit reshards in solved graphs is developed.

Think of it like this: it's a temporary fix that allows us to keep moving forward while we figure out the best long-term solution. We're not sacrificing current functionality for new features; we're carefully balancing both.

Long-Term Solution

While the selective pass execution provides an immediate solution, a long-term strategy is necessary to fully address the handling of explicit reshards in solved graphs. The team plans to modify the compiler to correctly handle these reshards, ensuring that they are properly integrated into the optimization process. This will involve analyzing the impact of reshards on existing sharding strategies and developing algorithms to optimize their placement and execution.

Future Enhancements

Future enhancements may include:

  • Reshard Optimization: Developing techniques to optimize the placement and execution of reshard operations, minimizing their overhead.
  • Sharding Strategy Adaptation: Implementing mechanisms to adapt existing sharding strategies in response to the introduction of reshards.
  • Graph Analysis: Enhancing graph analysis capabilities to better understand the impact of reshards on performance.

Ultimately, the goal is to seamlessly integrate explicit reshards into the compilation process without disrupting the performance of solved graphs. This will require a deeper understanding of the interactions between reshards and sharding strategies, as well as the development of sophisticated optimization techniques.

Alternatives Considered

While the proposed solution is the preferred approach, other alternatives were considered. However, they were deemed less suitable for various reasons. One alternative might have been to revert the changes to the insert-explicit-reshard pass. However, this would have blocked the Shardy uplift and prevented the integration of new features. Another alternative could have been to immediately implement a comprehensive solution for handling reshards in solved graphs. However, this would have required a significant amount of development effort and delayed the uplift.

In short, the team weighed the pros and cons of various approaches and determined that selectively running the Shardy passes was the most pragmatic solution for the current situation.

Conclusion

The challenge of running SDy passes only on non-solved graphs in Tenstorrent's tt-mlir shows how an upstream change can ripple into a downstream compiler, and why new features need careful integration. The proposed solution, which wraps the Shardy passes and executes them based on the graph's solved status, mitigates the immediate issue while allowing the Shardy uplift to proceed. In the long term, the team plans to develop a more comprehensive solution for handling explicit reshards, so that they integrate cleanly with solved graphs without hurting performance.