diff --git a/mmseg/models/decode_heads/dnl_head.py b/mmseg/models/decode_heads/dnl_head.py index ab53d9a..dabf154 100644 --- a/mmseg/models/decode_heads/dnl_head.py +++ b/mmseg/models/decode_heads/dnl_head.py @@ -26,8 +26,13 @@ class DisentangledNonLocal2d(NonLocal2d): pairwise_weight = torch.matmul(theta_x, phi_x) if self.use_scale: # theta_x.shape[-1] is `self.inter_channels` - pairwise_weight /= theta_x.shape[-1]**0.5 - pairwise_weight /= self.temperature + pairwise_weight /= torch.tensor( + theta_x.shape[-1], + dtype=torch.float, + device=pairwise_weight.device)**torch.tensor( + 0.5, device=pairwise_weight.device) + pairwise_weight /= torch.tensor( + self.temperature, device=pairwise_weight.device) pairwise_weight = pairwise_weight.softmax(dim=-1) return pairwise_weight