Image Segmentation
Transformers
Safetensors
actu
feature-extraction
climate
geospatial
remote-sensing
spatiotemporal
multi-modal
earth-observation
time-series
hydrology
custom_code
Instructions to use DarthReca/actu-direction-classification with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use DarthReca/actu-direction-classification with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("image-segmentation", model="DarthReca/actu-direction-classification", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("DarthReca/actu-direction-classification", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
| # Copyright (c) 2022 Seyong Kim | |
| from typing import Any, Optional, Tuple, Union | |
| import torch | |
| from torch import Tensor, nn, sigmoid, tanh | |
| class ConvGate(nn.Module): | |
| def __init__( | |
| self, | |
| in_channels: int, | |
| hidden_channels: int, | |
| kernel_size: Union[Tuple[int, int], int], | |
| padding: Union[Tuple[int, int], int], | |
| stride: Union[Tuple[int, int], int], | |
| bias: bool, | |
| ): | |
| super(ConvGate, self).__init__() | |
| self.conv_x = nn.Conv2d( | |
| in_channels=in_channels, | |
| out_channels=hidden_channels * 4, | |
| kernel_size=kernel_size, | |
| padding=padding, | |
| stride=stride, | |
| bias=bias, | |
| ) | |
| self.conv_h = nn.Conv2d( | |
| in_channels=hidden_channels, | |
| out_channels=hidden_channels * 4, | |
| kernel_size=kernel_size, | |
| padding=padding, | |
| stride=stride, | |
| bias=bias, | |
| ) | |
| self.bn2d = nn.BatchNorm2d(hidden_channels * 4) | |
| def forward(self, x, hidden_state): | |
| gated = self.conv_x(x) + self.conv_h(hidden_state) | |
| return self.bn2d(gated) | |
| class ConvLSTMCell(nn.Module): | |
| def __init__( | |
| self, in_channels, hidden_channels, kernel_size, padding, stride, bias | |
| ): | |
| super().__init__() | |
| # To check the model structure with tools such as torchinfo, need to wrap | |
| # the custom module with nn.ModuleList | |
| self.gates = nn.ModuleList( | |
| [ConvGate(in_channels, hidden_channels, kernel_size, padding, stride, bias)] | |
| ) | |
| def forward( | |
| self, x: Tensor, hidden_state: Tensor, cell_state: Tensor | |
| ) -> Tuple[Tensor, Tensor]: | |
| gated = self.gates[0](x, hidden_state) | |
| i_gated, f_gated, c_gated, o_gated = gated.chunk(4, dim=1) | |
| i_gated = sigmoid(i_gated) | |
| f_gated = sigmoid(f_gated) | |
| o_gated = sigmoid(o_gated) | |
| cell_state = f_gated.mul(cell_state) + i_gated.mul(tanh(c_gated)) | |
| hidden_state = o_gated.mul(tanh(cell_state)) | |
| return hidden_state, cell_state | |
| class ConvLSTM(nn.Module): | |
| """ConvLSTM module""" | |
| def __init__( | |
| self, | |
| in_channels, | |
| hidden_channels, | |
| kernel_size, | |
| padding, | |
| stride, | |
| bias, | |
| batch_first, | |
| bidirectional, | |
| ): | |
| super().__init__() | |
| self.in_channels = in_channels | |
| self.hidden_channels = hidden_channels | |
| self.bidirectional = bidirectional | |
| self.batch_first = batch_first | |
| # To check the model structure with tools such as torchinfo, need to wrap | |
| # the custom module with nn.ModuleList | |
| self.conv_lstm_cells = nn.ModuleList( | |
| [ | |
| ConvLSTMCell( | |
| in_channels, hidden_channels, kernel_size, padding, stride, bias | |
| ) | |
| ] | |
| ) | |
| if self.bidirectional: | |
| self.conv_lstm_cells.append( | |
| ConvLSTMCell( | |
| in_channels, hidden_channels, kernel_size, padding, stride, bias | |
| ) | |
| ) | |
| self.batch_size = None | |
| self.seq_len = None | |
| self.height = None | |
| self.width = None | |
| def forward( | |
| self, x: Tensor, state: Optional[Tuple[Tensor, Tensor]] = None | |
| ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: | |
| # size of x: B, T, C, H, W or T, B, C, H, W | |
| x = self._check_shape(x) | |
| hidden_state, cell_state, backward_hidden_state, backward_cell_state = ( | |
| self.init_state(x, state) | |
| ) | |
| output, hidden_state, cell_state = self._forward( | |
| self.conv_lstm_cells[0], x, hidden_state, cell_state | |
| ) | |
| if self.bidirectional: | |
| x = torch.flip(x, [1]) | |
| backward_output, backward_hidden_state, backward_cell_state = self._forward( | |
| self.conv_lstm_cells[1], x, backward_hidden_state, backward_cell_state | |
| ) | |
| output = torch.cat([output, backward_output], dim=-3) | |
| hidden_state = torch.cat([hidden_state, backward_hidden_state], dim=-1) | |
| cell_state = torch.cat([cell_state, backward_cell_state], dim=-1) | |
| return output, (hidden_state, cell_state) | |
| def _forward(self, lstm_cell, x, hidden_state, cell_state): | |
| outputs = [] | |
| for time_step in range(self.seq_len): | |
| x_t = x[:, time_step, :, :, :] | |
| hidden_state, cell_state = lstm_cell(x_t, hidden_state, cell_state) | |
| outputs.append(hidden_state.detach()) | |
| output = torch.stack(outputs, dim=1) | |
| return output, hidden_state, cell_state | |
| def _check_shape(self, x: Tensor) -> Tensor: | |
| if self.batch_first: | |
| batch_size, self.seq_len = x.shape[0], x.shape[1] | |
| else: | |
| batch_size, self.seq_len = x.shape[1], x.shape[0] | |
| x = x.permute(1, 0, 2, 3) | |
| x = torch.swapaxes(x, 0, 1) | |
| self.height = x.shape[-2] | |
| self.width = x.shape[-1] | |
| dim = len(x.shape) | |
| if dim == 4: | |
| x = x.unsqueeze(dim=1) # increase dimension | |
| x = x.view(batch_size, self.seq_len, -1, self.height, self.width) | |
| x = x.contiguous() # Reassign memory location | |
| elif dim <= 3: | |
| raise ValueError( | |
| f"Got {len(x.shape)} dimensional tensor. Input shape unmatched" | |
| ) | |
| return x | |
| def init_state( | |
| self, x: Tensor, state: Optional[Tuple[Tensor, Tensor]] | |
| ) -> Tuple[Union[Tensor, Any], Union[Tensor, Any], Optional[Any], Optional[Any]]: | |
| # If state doesn't enter as input, initialize state to zeros | |
| backward_hidden_state, backward_cell_state = None, None | |
| if state is None: | |
| self.batch_size = x.shape[0] | |
| hidden_state, cell_state = self._init_state(x.dtype, x.device) | |
| if self.bidirectional: | |
| backward_hidden_state, backward_cell_state = self._init_state( | |
| x.dtype, x.device | |
| ) | |
| else: | |
| if self.bidirectional: | |
| hidden_state, hidden_state_back = state[0].chunk(2, dim=-1) | |
| cell_state, cell_state_back = state[1].chunk(2, dim=-1) | |
| else: | |
| hidden_state, cell_state = state | |
| return hidden_state, cell_state, backward_hidden_state, backward_cell_state | |
| def _init_state(self, dtype, device): | |
| self.register_buffer( | |
| "hidden_state", | |
| torch.zeros( | |
| (1, self.hidden_channels, self.height, self.width), | |
| dtype=dtype, | |
| device=device, | |
| ), | |
| ) | |
| self.register_buffer( | |
| "cell_state", | |
| torch.zeros( | |
| (1, self.hidden_channels, self.height, self.width), | |
| dtype=dtype, | |
| device=device, | |
| ), | |
| ) | |
| return self.hidden_state, self.cell_state | |