目录

NLP-36CRF条件随机场-源码解读

【NLP 36、CRF条件随机场 —— 源码解读】


难只是暂时的,顶住压力后,希望能见到你心中的那束太阳

—— 25.3.11

一、CRF —— 条件随机场:

CRF是为了解决,当预测某一个字为一种实体的左边界时,则其右边不可能是其余实体的内部或右边界,我们运用另一个矩阵控制序列前后转移的概率(相关性)

CRF的 本质 是在神经网络中加入一个CRF - 转移矩阵

1.CRF - 转移矩阵

CRF - 转移矩阵: 标签数量 × 标签数量,本质上学习的是字和字之间两两标签转移的概率

https://i-blog.csdnimg.cn/direct/d3ee6a37520948a9b309aeefcf08ac58.png

https://i-blog.csdnimg.cn/direct/850e51f2ab50491792196c2e17a99dee.png

START 和 END 可以看作两个特殊的符号,标记句子的开始和句子的结束


2.发射矩阵

发射矩阵: 对于一句话中的每一个字进行四分类预测,判断其作为词的左右边界、词的内部、单字的概率。

https://i-blog.csdnimg.cn/direct/0de44f878655429cb9f7d444132d9a72.png


3.结合发射矩阵和转移矩阵

CRF - 转移矩阵可以分别学习到某个类别的字转移到其他类别字的概率,然后与 发射矩阵学习到的输入向量过神经网络预测到的两字间的概率值相加,总和进行 比较 ,对输入序列进行预测

CRF - 条件随机场输出的转移矩阵 可以与 向量经过神经网络后得出的发射矩阵 结合 使用,输出一个 更优的 预测结果

https://i-blog.csdnimg.cn/direct/0942d5f1e22147c78d921306b0d94cd4.png

转移矩阵可以影响发射矩阵的结果,相当于在神经网络结构中加入一层神经网络

作用: 规避一些不合理的序列输出


4.CRF —— Loss定义

① 输入序列 X,输出序列为 y的路径分数:A 为转移矩阵(代表前一个字向后一个字转移的概率), P 为发射矩阵(过神经网络的每个字对应的概率值), s(X, y) 代表任意一条路径的正确概率得分

s(X, y) = log(A * P) = logA + logP (这里的路径分数可以看作结合两矩阵,再做 log 运算后的)

https://i-blog.csdnimg.cn/direct/e28324f07d5343e3a51a519c3b691a36.png

② 输入序列X,预测输出序列为y的概率 :对上式做softmax,对 步骤 ① 得到的所有路径分数做归一化

https://i-blog.csdnimg.cn/direct/cd674d509b0a4e8fb70ba34b0937ebd3.png

③ 对上式取log,目标为最大化该值(方便计算,与 p(y | X) 成正比):

https://i-blog.csdnimg.cn/direct/c0d604d6fed74af9aea3d2e77bc4590a.png

依然希望这个 log (p(y | X)) 路径分数是最大的

④ 对上式取相反数做loss,目标为最小化该值:

https://i-blog.csdnimg.cn/direct/b3ce189e898b4af2bbef2c22fbf5e761.png

其他路径的总概率得分之和的 log 值 - 正确路径的总概率得分

CRF会明显拖慢训练速度,以效率的角度考虑可以不使用CRF

序列标注任务需要位置对应

而如果使用Bert模型,则做序列标注任务时,label标签在前后都需要加一个占位符,将Bert模型的CLS和SCP标识符包括

文本分类任务与序列标注任务模型结构的主要区别:pooling 归一化层


二、CRF —— 源码解读

1.初始化CRF模块

初始化CRF模块

  • 检查 num_tags 是否有效。
  • 定义 start_transitions (起始转移分数)、 end_transitions (结束转移分数)和 transitions (标签间转移分数)为可训练参数。
  • 调用 reset_parameters 初始化这些参数。

num_tags: 标签的数量,定义CRF的标签空间大小

batch_first: 输入张量的第一个维度是否为批次大小,控制输入张量的维度顺序

start_transitions: 起始状态的转移分数,大小为num_tags,表示从开始状态到每个标签的转移分数

end_transitions: 结束状态的转移分数,大小为num_tags,表示从每个标签到结束状态的转移分数

transitions: 标签之间的转移分数,大小为num_tags, num_tags,表示标签之间的转移概率

nn.Parameter(): 将张量标记为模型参数,使其在训练过程中可以被优化

参数类型描述
datatorch.Tensor要标记为参数的张量。
requires_gradbool是否计算梯度(默认 True

torch.empty(): 创建一个未初始化的张量

参数类型描述
sizetuple张量的形状。
dtypetorch.dtype张量的数据类型(可选)。
devicetorch.device张量的设备(可选)。

raise ValueError(): 抛出一个值错误异常。

参数类型描述
messagestr错误信息。
    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
        self.reset_parameters()

2.随机初始化CRF参数

随机初始化CRF的参数

  • 使用均匀分布(范围 [-0.1, 0.1] )初始化 start_transitionsend_transitionstransitions

nn.init.uniform_(): 用均匀分布初始化张量

参数类型描述
tensortorch.Tensor要初始化的张量。
afloat均匀分布的下界。
bfloat均匀分布的上界。
    def reset_parameters(self) -> None:
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

3.前向计算

Pytorch封装好的CRF模型的forward前向计算过程中,计算的是正确路径的概率,作为Loss使用需要添加负号,用概率的相反数作为损失

计算给定标签序列的对数似然

  • 验证输入张量的形状和有效性。
  • 如果 mask 为空,创建一个全 1 的掩码。
  • 如果 batch_firstTrue ,调整张量的维度顺序。
  • 计算标签序列的分数( numerator )和归一化因子( denominator )。
  • 计算对数似然( llh = numerator - denominator )。
  • 根据 reduction 参数返回结果( summeantoken_mean

emissions 发射分数张量,大小为 (seq_length, batch_size, num_tags)(batch_size, seq_length, num_tags)

tags 标签序列张量,大小为 (seq_length, batch_size)(batch_size, seq_length)

mask 掩码张量,大小为 (seq_length, batch_size)(batch_size, seq_length)

reduction 输出的缩减方式,可选 nonesummeantoken_mean

transpose(): 交换张量的两个维度

参数类型描述
dim0int第一个维度。
dim1int第二个维度。

raise ValueError(): 抛出一个值错误异常。

参数类型描述
messagestr错误信息。

torch.ones_like(): 创建一个与输入张量形状相同的全 1 张量。

参数类型描述
inputtorch.Tensor输入张量。
dtypetorch.dtype张量的数据类型(可选)。
devicetorch.device张量的设备(可选)。

sum(): 计算张量的元素和

参数类型描述
dimint沿指定维度求和(可选)。
keepdimbool是否保持维度(可选)。

float(): 将张量转换为浮点类型。

    def forward(
            self,
            emissions: torch.Tensor,
            tags: torch.LongTensor,
            mask: Optional[torch.ByteTensor] = None,
            reduction: str = 'sum',
    ) -> torch.Tensor:
        self._validate(emissions, tags=tags, mask=mask)
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        numerator = self._compute_score(emissions, tags, mask)
        denominator = self._compute_normalizer(emissions, mask)
        llh = numerator - denominator

        if reduction == 'none':
            return llh
        if reduction == 'sum':
            return llh.sum()
        if reduction == 'mean':
            return llh.mean()
        assert reduction == 'token_mean'
        return llh.sum() / mask.float().sum()

4.维特比算法解码

使用 Viterbi 算法解码最可能的标签序列

  • 验证输入张量的形状和有效性。
  • 如果 mask 为空,创建一个全 1 的掩码。
  • 如果 batch_firstTrue ,调整张量的维度顺序。
  • 调用 _viterbi_decode 解码最可能的标签序列。

emissions 发射分数张量,大小为 (seq_length, batch_size, num_tags)(batch_size, seq_length, num_tags)

mask 掩码张量,大小为 (seq_length, batch_size)(batch_size, seq_length)

transpose(): 交换张量的两个维度。

参数类型描述
dim0int第一个维度。
dim1int第二个维度。
    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
        self._validate(emissions, mask=mask)
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        return self._viterbi_decode(emissions, mask)

5.验证输入张量

验证输入张量的形状和有效性

  • 检查 emissions 的维度是否为 3。
  • 检查 emissions 的最后一个维度是否等于 num_tags
  • 检查 emissionstagsmask 的形状是否匹配。
  • 检查 mask 的第一个时间步是否全部为 1。

emissions: 发射分数张量

tags: 标签序列张量

mask: 掩码张量

raise ValueError(): 抛出一个值错误异常。

参数类型描述
messagestr错误信息。

.dim(): 返回张量的维度数。

.size(): 返回张量的形状。

参数类型描述
dimint指定维度(可选)。

tuple(): 将可迭代对象转换为元组

参数类型描述
iterableiterable要转换的可迭代对象。

.shape(): 返回张量的形状(与 .size() 相同)。

    def _validate(
            self,
            emissions: torch.Tensor,
            tags: Optional[torch.LongTensor] = None,
            mask: Optional[torch.ByteTensor] = None) -> None:
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

6.计算分数

计算给定标签序列的分数

  • 验证输入张量的形状和有效性。
  • 初始化分数为起始转移分数和第一个时间步的发射分数。
  • 遍历时间步,累加转移分数和发射分数。
  • 添加结束转移分数。
  • 返回标签序列的分数。

emissions: 发射分数张量

tags: 标签序列张量

mask: 掩码张量

.dim(): 返回张量的维度数。

.size(): 返回张量的形状。

参数类型描述
dimint指定维度(可选)。

.shape(): 返回张量的形状(与 .size() 相同)。

.float(): 将张量转换为浮点类型。

torch.arange(): 创建一个等差序列张量

参数类型描述
startint起始值(默认 0 )。
endint结束值(不包括)。
stepint步长(默认 1 )。
dtypetorch.dtype张量的数据类型(可选)。
devicetorch.device张量的设备(可选)。
    def _compute_score(
            self, emissions: torch.Tensor, tags: torch.LongTensor,
            mask: torch.ByteTensor) -> torch.Tensor:
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.float()

        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        seq_ends = mask.long().sum(dim=0) - 1
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        score += self.end_transitions[last_tags]

        return score

7.计算归一化因子

计算归一化因子(配分函数)

  • 验证输入张量的形状和有效性。
  • 初始化分数为起始转移分数和第一个时间步的发射分数。
  • 遍历时间步,计算所有可能标签序列的分数。
  • 使用 logsumexp 计算归一化因子。
  • 返回归一化因子。

emissions: 发射分数张量

mask: 掩码张量

.all(): 检查张量中的所有元素是否为真。

参数类型描述
dimint沿指定维度检查(可选)。
keepdimbool是否保持维度(可选)。

.dim(): 返回张量的维度数。

.size(): 返回张量的形状。

参数类型描述
dimint指定维度(可选)。

.shape(): 返回张量的形状(与 .size() 相同)。

torch.logsumexp(): 计算张量的对数求和指数

参数类型描述
inputtorch.Tensor输入张量。
dimint沿指定维度计算(可选)。
keepdimbool是否保持维度(可选)。

torch.where(): 根据条件选择元素。

参数类型描述
conditiontorch.Tensor条件张量。
xtorch.Tensor条件为真时的值。
ytorch.Tensor条件为假时的值。
    def _compute_normalizer(
            self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length = emissions.size(0)

        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            broadcast_score = score.unsqueeze(2)
            broadcast_emissions = emissions[i].unsqueeze(1)
            next_score = broadcast_score + self.transitions + broadcast_emissions
            next_score = torch.logsumexp(next_score, dim=1)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)

        score += self.end_transitions
        return torch.logsumexp(score, dim=1)

8.解码标签序列

使用 Viterbi 算法解码最可能的标签序列

  • 验证输入张量的形状和有效性。
  • 初始化分数为起始转移分数和第一个时间步的发射分数。
  • 遍历时间步,计算所有可能标签序列的分数,并记录最大分数的索引。
  • 添加结束转移分数。
  • 回溯 Viterbi 路径,找到最可能的标签序列。
  • 返回最可能的标签序列。

emissions: 发射分数张量

mask: 掩码张量

.dim(): 返回张量的维度数。

.size(): 返回张量的形状。

参数类型描述
dimint指定维度(可选)。

.shape(): 返回张量的形状(与 .size() 相同)。

unsqueeze(): 在指定维度上增加一个大小为 1 的维度。

参数类型描述
dimint要增加维度的位置。

max(): 返回张量的最大值,或沿指定维度的最大值及其索引。

参数类型描述
dimint沿指定维度计算最大值(可选)。
keepdimbool是否保持维度(可选)。

torch.where(): 根据条件选择元素。

参数类型描述
conditiontorch.Tensor条件张量。
xtorch.Tensor条件为真时的值。
ytorch.Tensor条件为假时的值。

long().sum(): 将张量转换为长整型并求和

参数类型描述
dimint沿指定维度求和(可选)。
keepdimbool是否保持维度(可选)。

append(): 在列表末尾添加元素

参数类型描述
itemany要添加的元素。

reverse(): 反转列表中的元素顺序。

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor) -> List[List[int]]:
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length, batch_size = mask.shape

        score = self.start_transitions + emissions[0]
        history = []

        for i in range(1, seq_length):
            broadcast_score = score.unsqueeze(2)
            broadcast_emission = emissions[i].unsqueeze(1)
            next_score = broadcast_score + self.transitions + broadcast_emission
            next_score, indices = next_score.max(dim=1)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)
            history.append(indices)

        score += self.end_transitions

        seq_ends = mask.long().sum(dim=0) - 1
        best_tags_list = []

        for idx in range(batch_size):
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list

9.源码

核心算法

  1. CRF 的对数似然计算

    • 对数似然 = 标签序列的分数 - 归一化因子。
    • 标签序列的分数 = 起始转移分数 + 发射分数 + 转移分数 + 结束转移分数。
    • 归一化因子 = 所有可能标签序列的分数之和(使用 logsumexp 计算)。
  2. Viterbi 算法

    • 动态规划算法,用于找到最可能的标签序列。
    • 在每个时间步,计算所有可能标签序列的分数,并记录最大分数的索引。
    • 回溯路径,得到最可能的标签序列。
__version__ = '0.7.2'

from typing import List, Optional

import torch
import torch.nn as nn


class CRF(nn.Module):
    """Conditional random field.

    This module implements a conditional random field [LMP01]_. The forward computation
    of this class computes the log likelihood of the given sequence of tags and
    emission score tensor. This class also has `~CRF.decode` method which finds
    the best tag sequence given an emission score tensor using `Viterbi algorithm`_.

    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a minibatch.

    Attributes:
        start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
            ``(num_tags,)``.
        end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
            ``(num_tags,)``.
        transitions (`~torch.nn.Parameter`): Transition score tensor of size
            ``(num_tags, num_tags)``.


    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
       "Conditional random fields: Probabilistic models for segmenting and
       labeling sequence data". *Proc. 18th International Conf. on Machine
       Learning*. Morgan Kaufmann. pp. 282–289.

    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))

        # 随机初始化参数
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.

        The parameters will be initialized randomly from a uniform distribution
        between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(
            self,
            emissions: torch.Tensor,
            tags: torch.LongTensor,
            mask: Optional[torch.ByteTensor] = None,
            reduction: str = 'sum',
    ) -> torch.Tensor:
        """Compute the conditional log likelihood of a sequence of tags given emission scores.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            tags (`~torch.LongTensor`): Sequence of tags tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
            reduction: Specifies  the reduction to apply to the output:
                ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
                ``sum``: the output will be summed over batches. ``mean``: the output will be
                averaged over batches. ``token_mean``: the output will be averaged over tokens.

        Returns:
            `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
            reduction is ``none``, ``()`` otherwise.
        """
        self._validate(emissions, tags=tags, mask=mask)
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # shape: (batch_size,)
        llh = numerator - denominator

        if reduction == 'none':
            return llh
        if reduction == 'sum':
            return llh.sum()
        if reduction == 'mean':
            return llh.mean()
        assert reduction == 'token_mean'
        return llh.sum() / mask.float().sum()

    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
        """Find the most likely tag sequence using Viterbi algorithm.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.

        Returns:
            List of list containing the best tag sequence for each batch.
        """
        self._validate(emissions, mask=mask)
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        return self._viterbi_decode(emissions, mask)

    def _validate(
            self,
            emissions: torch.Tensor,
            tags: Optional[torch.LongTensor] = None,
            mask: Optional[torch.ByteTensor] = None) -> None:
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(
            self, emissions: torch.Tensor, tags: torch.LongTensor,
            mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.float()

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        # End transition score
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(
            self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length = emissions.size(0)

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags, but we're in score space, so a sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far, that end in tag i
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor) -> List[List[int]]:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history = []

        # score is a tensor of size (batch_size, num_tags) where for every batch,
        # value at column j stores the score of the best tag sequence so far that ends
        # with tag j
        # history saves where the best tags candidate transitioned from; this is used
        # when we trace back the best tag sequence

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tag
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)
            history.append(indices)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Now, compute the best path for each sample

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        best_tags_list = []

        for idx in range(batch_size):
            # Find the tag which maximizes the score at the last timestep; this is our best tag
            # for the last timestep
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            # We trace back where the best last tag comes from, append that to our best tag
            # sequence, and trace it back again, and so on
            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            # Reverse the order because we start from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list