11 분 소요

도면 내 텍스트 영역 탐지에 이어 이번 포스트에서는 탐지 영역 내 텍스트를 인식하는 기술에 대해 소개드립니다.

📋 Table of Contents

  1. 텍스트 인식 배경
  2. 텍스트 인식 모델의 프레임워크 분석 및 설계
  3. 텍스트 인식 모델 구축 과정
  4. 결론

1. 텍스트 인식 배경

OCR(Optical Character Recognition)은 인쇄된 광학 문자 이미지를 기계가 읽을 수 있는 디지털 텍스트 데이터로 변환하는 기술입니다. 더 나아가 Real world에서 발생할 수 있는 다양한 노이즈와 왜곡이 섞여 있는 텍스트를 인식하는 STR(Scene Text Recognition)도 많은 발전을 하고 있습니다. 이 기술들은 디지털 서류 전환, 기밀문서 마스킹, 표지판 및 간판 인식 등 다양한 분야에서 유용하게 활용되는 기술이죠.

현재 OCR은 주로 공문서, 신분증, 영수증, 서적 등과 같이 정형화된 포맷이나 사전적 의미가 있는 텍스트를 디지털화하는 데 많이 적용되고 있습니다. 예를 들어 아래 영수증과 같이 정형화된 포맷의 경우, 특정 영역에는 날짜, 시간, 품목, 가격 등과 같은 정해진 타입의 텍스트가 작성될 것이라 예상할 수 있습니다. 또한, 사전적 의미가 있는 텍스트의 경우, 잘못 인식된 경우에도 문맥을 고려하여 오타를 수정할 수 있습니다.

하지만 설계 도면의 텍스트는 주로 특정 건설 사업에서 정의된 코드로 구성되어 있습니다. 이러한 코드들은 특수 기호들과 조합되어 최대 20자 이상의 단어로 표현될 수 있습니다. 따라서 설계 도면의 특성에 맞는 텍스트 인식 모델을 개발하는 것이 필요합니다.

OCR-Example
[ 영수증 OCR 인식 예시 ]

출처: 이미지 출처

2. 텍스트 인식 모델의 프레임워크 분석 및 설계

텍스트 인식 모델의 성능을 평가하고 비교하는 일은 그 어떤 분야에서도 쉽지 않습니다. 2019년에 발표된 논문 “What Is Wrong With Scene Text Recognition Model Comparisons? Dataset and Model Analysis” 에서는 이러한 어려움을 극복하기 위한 프레임워크를 제안하고 있습니다. 이 논문은 프레임워크의 네 가지 핵심 스테이지(Transformation, Feature Extraction, Sequence Modeling, Prediction)를 통해 텍스트 인식 모델을 분석하였습니다.
아래 그림은 기존에 제안된 텍스트 인식 모델을 네 가지 스테이지로 구분지어 정의하고 일관된 벤치마크 데이터셋으로 성능을 평가한 결과 자료입니다.

Previously-Proposed-Combinations
[ 기존 텍스트 인식 모델의 스테이지 조합 ]

출처: 이미지 출처

각 스테이지의 특성을 이해하고 이를 조합하여 적합한 모델을 개발한다면, 설계 도면 내 텍스트를 인식하는 데 더 나은 성능을 기대할 수 있습니다. 따라서 설계 도면 내 텍스트를 인식 분야에서는 어떤 stage의 조합이 최적의 조합일지 스테이지별 분석 및 조합을 시도하였습니다.

Transformation(Trans.): 변환 스테이지

이 단계에서는 Spatial Transformation Network(STN)의 변형 중 하나인 Thin-Plate Spline(TPS) 변환을 사용합니다. TPS 변환은 다양한 형태의 기하학적 왜곡을 쉽게 표현할 수 있어, 복잡한 텍스트 형태를 보정하는 데 사용됩니다. 특히 곡선 또는 기울어진 텍스트와 같은 어려운 형태의 텍스트를 정규화하여 후속 단계의 처리를 향상시킵니다.

Feature Extraction(Feat.): 특징 추출 스테이지

이 단계에서는 입력 데이터로부터 중요한 시각 정보를 추출하여 학습된 패턴을 기반으로 새로운 데이터를 분류하거나 인식하는 데 필수적입니다. 일반적인 CNN 아키텍처는 합성곱 신경망 구조를 기반이며 텍스트의 크기, 폰트, 배경 등 관련 없는 정보를 제외하고 텍스트 이미지의 정보를 분류하는 작업을 합니다. 연구 결과, ResNet 아키텍처가 VGG와 RCNN보다 우수한 성능을 보였습니다. 이는 텍스트 이미지의 정보를 분류하는 작업에서 뛰어난 효과를 발휘합니다.

Sequence Modeling (Seq.): 시퀀스 모델링 스테이지

이 단계에서는 텍스트를 캐릭터 단위로 인식하는 것이 아닌 텍스트의 시퀀스를 문맥 정보로써 인식합니다. 따라서, 모델은 추출된 특징을 시퀀스로 재구성하고, 문맥 정보를 포함하여 더 나은 시퀀스를 형성하는 데 목표를 두고 있습니다. 여기서 BiLSTM을 사용하는 경우 더 넓은 문맥을 이해하고 문자를 예측할 수 있습니다.

Prediction (Pred.): 예측 스테이지

이전 단계를 통해 얻은 특징들을 단어들의 시퀀스로 예측하는 작업을 수행합니다. 여기서 두 가지 선택 사항이 있습니다. 첫 번째로 Connectionist Temporal Classification(CTC)는 중복 문자와 공백을 제거하여 가변 길이의 출력 문자 시퀀스를 예측하는 방법입니다. 두 번째로 Attention-based Sequence Prediction (Attn) 은 주요 정보 흐름을 파악하여 출력 문자 시퀀스를 예측합니다. 특히 Attn은 문자 간의 의존성을 학습할 수 있도록 돕기 때문에 문자 일부가 가려져있거나 누락된 경우에 더 나은 결과를 보입니다.

위 네 가지 stage 분석을 토대로 TPS-ResNet-None-CTC 조합으로 텍스트 인식 모델을 설계하고, 다른 조합 텍스트 인식 모델과 비교 실험을 했습니다.

3. 텍스트 인식 모델 구축 과정

데이터 수집 및 전처리

텍스트 탐지 단계에서 언급한 것처럼, 건설 프로젝트와 설계 벤더마다 텍스트의 폰트와 크기 등 스타일이 다양합니다. 이에 대응하기 위해 가능한 많은 다양한 건설 프로젝트의 도면에서 텍스트 이미지를 수집했습니다. 총 13개의 건설 프로젝트에서 약 3,363장의 도면을 수집하고, 수집된 도면에서 추출한 텍스트 영역 이미지는 총 376,656장이었습니다. 텍스트의 캐릭터 클래스는 영문, 숫자 그리고 20개의 특수 문자로 구성하였습니다.

데이터 다양성을 확보하기 위해 다양한 폰트와 스타일로 작성된 텍스트 이미지를 생성했습니다. 이러한 가짜 이미지는 실제 수집한 데이터셋과 유사한 양을 생성하였습니다. 이 가짜 이미지는 실제 데이터와 비슷한 특성을 가지고 있어, 딥러닝 모델 학습에 활용됩니다. 이를 통해 모델은 다양한 스타일의 텍스트를 인식하는 데 더 강력해집니다.

수집한 텍스트 이미지는 길이가 한 글자부터 20자 이상까지 다양합니다. 이를 처리하기 위해 텍스트 이미지의 가로와 세로 비율을 기준으로 짧은 텍스트와 긴 텍스트로 분류했습니다. 이렇게 길이에 따라 나눈 텍스트 이미지 데이터셋은 길이에 따라 두 개의 텍스트 인식 모델을 학습할 때 활용됩니다.

모델 학습

딥러닝 모델 학습을 위해 제안하는 TPS-ResNet-None-CTC 조합의 프레임워크로 코드를 수정하여 학습을 수행합니다.

  • Transformation(Trans.): TPS
    코드 접기/펼치기
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    class TPS_SpatialTransformerNetwork(nn.Module):
        """ RARE의 Rectification Network, 즉 TPS 기반 STN"""
    
        def __init__(self, F, I_size, I_r_size, I_channel_num=1):
            """ RARE TPS를 기반으로 함
            input:
                batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
                I_size : 입력 이미지 I의 (높이, 너비)
                I_r_size : 변환된 이미지 I_r의 (높이, 너비)
                I_channel_num : 입력 이미지 I의 채널 수
            output:
                batch_I_r: 변환된 이미지 [batch_size x I_channel_num x I_r_height x I_r_width]
            """
            super(TPS_SpatialTransformerNetwork, self).__init__()
            self.F = F
            self.I_size = I_size
            self.I_r_size = I_r_size  # = (I_r_height, I_r_width)
            self.I_channel_num = I_channel_num
            self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
            self.GridGenerator = GridGenerator(self.F, self.I_r_size)
    
        def forward(self, batch_I):
            batch_C_prime = self.LocalizationNetwork(batch_I)  # batch_size x K x 2
            build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime)  # batch_size x n (= I_r_width x I_r_height) x 2
            build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])
              
            # grid_sample 함수를 사용하여 이미지를 변환합니다.
            if torch.__version__ > "1.2.0":
                batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True)
            else:
                batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border')
    
            return batch_I_r
    
    
    class LocalizationNetwork(nn.Module):
        """ RARE의 Localization Network, 입력 이미지 I에서 C' (K x 2)를 예측합니다. """
    
        def __init__(self, F, I_channel_num):
            super(LocalizationNetwork, self).__init__()
            self.F = F
            self.I_channel_num = I_channel_num
            # Convolutional layers와 Fully connected layers로 구성된 신경망입니다.
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1,
                          bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
                nn.MaxPool2d(2, 2),  # batch_size x 64 x I_height/2 x I_width/2
                nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True),
                nn.MaxPool2d(2, 2),  # batch_size x 128 x I_height/4 x I_width/4
                nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True),
                nn.MaxPool2d(2, 2),  # batch_size x 256 x I_height/8 x I_width/8
                nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True),
                nn.AdaptiveAvgPool2d(1)  # batch_size x 512
            )
            # 예측된 C' 좌표를 계산하기 위한 fully connected layers입니다.
            self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
            self.localization_fc2 = nn.Linear(256, self.F * 2)
    
            # LocalizationNetwork의 fc2 초기화
            self.localization_fc2.weight.data.fill_(0)
            """ RARE 논문의 Fig. 6 (a) 참조 """
            ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
            ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
            ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
            ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
            ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
            initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
            self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1)
    
        def forward(self, batch_I):
            batch_size = batch_I.size(0)
            features = self.conv(batch_I).view(batch_size, -1)
            batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2)
            return batch_C_prime
    
    
    class GridGenerator(nn.Module):
        """ RARE의 Grid Generator, T와 P를 곱해서 P_prime을 생성합니다. """
    
        def __init__(self, F, I_r_size):
            """ P_hat과 inv_delta_C를 미리 생성합니다. """
            super(GridGenerator, self).__init__()
            self.eps = 1e-6
            self.I_r_height, self.I_r_width = I_r_size
            self.F = F
            self.C = self._build_C(self.F)  # F x 2
            self.P = self._build_P(self.I_r_width, self.I_r_height)
            # multi-gpu를 위해 buffer로 등록합니다.
            self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float())  # F+3 x F+3
            self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float())  # n x F+3
    
        def _build_C(self, F):
            """ I_r의 fiducial points의 좌표를 반환합니다. """
            ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
            ctrl_pts_y_top = -1 * np
    


  • Feature Extraction(Feat.): ResNet
    코드 접기/펼치기
    import torch.nn as nn
    import torch.nn.functional as F
    
    class ResNet_FeatureExtractor(nn.Module):
    
        def __init__(self, input_channel, output_channel=512):
            super(ResNet_FeatureExtractor, self).__init__()
            self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3])
    
        def forward(self, input):
            return self.ConvNet(input)
    
    class BasicBlock(nn.Module):
        expansion = 1
    
        def __init__(self, inplanes, planes, stride=1, downsample=None):
            super(BasicBlock, self).__init__()
            self.conv1 = self._conv3x3(inplanes, planes, stride)  # 첫 번째 3x3 합성곱 레이어
            self.bn1 = nn.BatchNorm2d(planes)  # 배치 정규화
            self.conv2 = self._conv3x3(planes, planes)  # 두 번째 3x3 합성곱 레이어
            self.bn2 = nn.BatchNorm2d(planes)  # 배치 정규화
            self.relu = nn.ReLU(inplace=True)  # ReLU 활성화 함수
            self.downsample = downsample  # 다운 샘플링 레이어
            self.stride = stride
    
        def _conv3x3(self, in_planes, out_planes, stride=1):
            "3x3 컨볼루션 레이어"
            return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                            padding=1, bias=False)
    
        def forward(self, x):
            residual = x
    
            out = self.conv1(x)  # 첫 번째 컨볼루션 레이어 적용
            out = self.bn1(out)  # 배치 정규화
            out = self.relu(out)  # ReLU 활성화 함수
    
            out = self.conv2(out)  # 두 번째 컨볼루션 레이어 적용
            out = self.bn2(out)  # 배치 정규화
    
            if self.downsample is not None:
                residual = self.downsample(x)  # 다운 샘플링 레이어 적용
            out += residual  # 잔차 연결
            out = self.relu(out)  # ReLU 활성화 함수
    
            return out
    
    class ResNet(nn.Module):
    
        def __init__(self, input_channel, output_channel, block, layers):
            super(ResNet, self).__init__()
    
            # 각 레이어에서 출력 채널의 수를 정의
            self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
    
            # 초기 입력 채널 수 설정
            self.inplanes = int(output_channel / 8)
    
            # 초기 두 개의 컨볼루션 레이어 정의
            self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
                                    kernel_size=3, stride=1, padding=1, bias=False)
            self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))
            self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
                                    kernel_size=3, stride=1, padding=1, bias=False)
            self.bn0_2 = nn.BatchNorm2d(self.inplanes)
            self.relu = nn.ReLU(inplace=True)
    
            # 각 레이어에 대한 맥스 풀링 레이어 정의
            self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
            self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
            self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
    
            # ResNet 레이어들을 생성하는 함수 호출
            self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
            self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
            self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
            self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)
    
            # 각 레이어에 대한 컨볼루션 레이어 정의
            self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[
                                  0], kernel_size=3, stride=1, padding=1, bias=False)
            self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[
                                  1], kernel_size=3, stride=1, padding=1, bias=False)
            self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[
                                  2], kernel_size=3, stride=1, padding=1, bias=False)
            self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
                                    3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
            self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
                                    3], kernel_size=1, stride=1, padding=0, bias=False)
    
            # 각 레이어에 대한 배치 정규화 레이어 정의
            self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
            self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
            self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
            self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
            self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
    
        # 각 ResNet 레이어를 생성하는 함수
        def _make_layer(self, block, planes, blocks, stride=1):
            downsample = None
            if stride != 1 or self.inplanes != planes * block.expansion:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(planes * block.expansion),
                )
    
            layers = []
            layers.append(block(self.inplanes, planes, stride, downsample))
            self.inplanes = planes * block.expansion
            for i in range(1, blocks):
                layers.append(block(self.inplanes, planes))
    
            return nn.Sequential(*layers)
    
        # 순전파 함수 정의
        def forward(self, x):
            x = self.conv0_1(x)
            x = self.bn0_1(x)
            x = self.relu(x)
            x = self.conv0_2(x)
            x = self.bn0_2(x)
            x = self.relu(x)
    
            x = self.maxpool1(x)
            x = self.layer1(x)
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
    
            x = self.maxpool2(x)
            x = self.layer2(x)
            x = self.conv2(x)
            x = self.bn2(x)
            x = self.relu(x)
    
            x = self.maxpool3(x)
            x = self.layer3(x)
            x = self.conv3(x)
            x = self.bn3(x)
            x = self.relu(x)
    
            x = self.layer4(x)
            x = self.conv4_1(x)
            x = self.bn4_1(x)
            x = self.relu(x)
            x = self.conv4_2(x)
            x = self.bn4_2(x)
            x = self.relu(x)
    
            return x
    


  • Sequence Modeling (Seq.): None
    ▶ Sequence Modeling은 수행하지 않으므로 패스

  • Prediction (Pred.): CTC
    코드 접기/펼치기
    import torch
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    class CTCLabelConverter(object):
        """ 텍스트 라벨과 텍스트 인덱스 간의 변환을 담당하는 클래스 """
    
        def __init__(self, character):
            # character (str): 가능한 문자 집합.
            dict_character = list(character)
    
            self.dict = {}
            for i, char in enumerate(dict_character):
                # 참고: CTCLoss에서 필요한 'CTCblank' 토큰을 위해 0을 예약합니다.
                self.dict[char] = i + 1
    
            self.character = ['[CTCblank]'] + dict_character  # CTCLoss를 위한 더미 '[CTCblank]' 토큰 (인덱스 0)
    
        def encode(self, text, batch_max_length=25):
            """ 텍스트 라벨을 텍스트 인덱스로 변환합니다.
            입력:
                text: 각 이미지의 텍스트 라벨. [batch_size]
                batch_max_length: 배치에서 텍스트 라벨의 최대 길이. 기본값은 25
    
            출력:
                text: CTCLoss를 위한 텍스트 인덱스. [batch_size, batch_max_length]
                length: 각 텍스트의 길이. [batch_size]
            """
            length = [len(s) for s in text]
    
            # 0으로 패딩된 인덱스는 CTC 손실 계산에 영향을 미치지 않습니다.
            batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0)
            for i, t in enumerate(text):
                text = list(t)
                text = [self.dict[char] for char in text]
                batch_text[i][:len(text)] = torch.LongTensor(text)
            return (batch_text.to(device), torch.IntTensor(length).to(device))
    
        def decode(self, text_index, length):
            """ 텍스트 인덱스를 텍스트 라벨로 변환합니다. """
            texts = []
            for index, l in enumerate(length):
                t = text_index[index, :]
    
                char_list = []
                for i in range(l):
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):  # 반복된 문자와 블랭크를 제거합니다.
                        char_list.append(self.character[t[i]])
                text = ''.join(char_list)
    
                texts.append(text)
            return texts
    


모델 평가 결과

아래 결과 그림은 다섯 가지 텍스트 인식 모델들의 정밀도(Precision)와 재현율(Recall)을 비교하고 있습니다.

“Tesseract”는 높은 재현율을 보이지만 정확도가 낮고, “Deep Text Recognition”은 정확도가 중간 수준이지만 높은 재현율을 가지고 있습니다. 이러한 결과를 통해 각 모델의 장단점을 파악할 수 있습니다.

결론적으로 제안하는 TPS-ResNet-None-CTC 조합의 프레임워크 모델(“Drawing Text Recognition”)이 다른 딥러닝 모델들 보다 뛰어난 설계 도면 내 텍스트 인식 성능을 보이므로 설계 도면 텍스트 인식에 최적화되었다고 판단됩니다.

Result
[ 딥러닝 모델별 도면 내 텍스트 인식 결과 ]
  • Matplotlib 코드
    코드 접기/펼치기
    import matplotlib.pyplot as plt
    from matplotlib.font_manager import FontProperties
    
    # 데이터
    models = ['Tesseract', 'Rosetta', 'Star-Net', 'Deep Text Recognition', '(Our) Drawing Text Recognition']
    precision = [61.6, 89.76, 87.46, 82.94, 92.41]
    recall = [99.26, 99.31, 99.18, 99.2, 99.32]
    
    # 분산형 차트 그리기
    plt.figure(figsize=(10, 6))
    
    # 각 모델에 대한 특정 마커 및 색상 사용하여 점 그리기
    plt.scatter(recall[0], precision[0], color='gray', marker='s', label='Tesseract')  # 회색 정사각형
    plt.scatter(recall[1], precision[1], color='red', marker='^', label='Rosetta')  # 빨간색 삼각형
    plt.scatter(recall[2], precision[2], color='#f2bb05', marker='*', label='Star-Net')  # 노란색 별모양
    plt.scatter(recall[3], precision[3], color='green', marker='D', label='Deep Text Recognition')  # 초록색 다이아몬드
    plt.scatter(recall[4], precision[4], color='blue', marker='o', label='(Our) Drawing Text Recognition')  # 파란색 동그라미
    
    
    # 대각선 가이드선 추가
    plt.plot([99.15, 99.35], [60, 100], color='#707070', linestyle='--', linewidth=0.5)
    
    plt.xlabel('Recall(%)')
    plt.ylabel('Precision(%)')
    plt.title('Results of Text Recognition in Drawings by Deep Learning Models')
    
    # X 축과 Y 축의 범위 설정
    plt.xlim(99.15, 99.35)
    plt.ylim(60.0, 100.0)
    
    # 범례 추가 및 굵은 글꼴 설정
    legend = plt.legend()
    font = FontProperties(weight='bold')  # 굵은 글꼴 설정
    for text in legend.texts:
        if text.get_text() == '(Our) Drawing Text Recognition':
            text.set_font_properties(font)  # 굵은 글꼴 적용
    
    plt.grid(True)
    plt.show()
    


4. 결론

이번 포스트에서는 설계 도면 내 텍스트 영역을 탐지하고 인식하는 기술에 대해 다뤘습니다. 텍스트 인식 모델을 구축하기 위해 다양한 건설 프로젝트의 도면 데이터를 수집하고 전처리하는 과정에서 어려움을 겪었지만, 이러한 다양성을 고려한 텍스트 인식 모델의 설계와 평가를 통해 뛰어난 결과를 얻었습니다.

우리는 “Transformation(Trans.)”, “Feature Extraction(Feat.)”, “Sequence Modeling(Seq.)”, 그리고 “Prediction(Pred.)” 이렇게 네 가지 핵심 스테이지를 통해 텍스트 인식 모델을 분석하고 설계하였습니다. 특히 Transformation 단계에서는 Thin-Plate Spline(TPS) 변환을 사용하여 복잡한 텍스트 형태를 보정하고, Feature Extraction 단계에서는 ResNet 아키텍처를 활용하여 텍스트 이미지의 정보를 효과적으로 분류하였습니다. Sequence Modeling 단계에서는 문맥 정보를 포함하여 더 나은 시퀀스를 형성하고, Prediction 단계에서는 Connectionist Temporal Classification(CTC)와 Attention-based Sequence Prediction(Attn)을 비교하여 뛰어난 결과를 얻을 수 있었습니다.

포스트에서 제안한 TPS-ResNet-None-CTC 조합의 프레임워크 모델(Drawing Text Recognition)은 다른 딥러닝 모델들보다 뛰어난 설계 도면 내 텍스트 인식 성능을 보였습니다. 특히 정밀도 측면에서 우수함을 확인할 수 있었고, 이는 현장에서의 실용성과 정확성을 확보하는 데 큰 기여를 할 것으로 기대됩니다.

이러한 연구를 통해 설계 도면 내 텍스트를 인식하는 기술의 혁신과 발전에 일조할 수 있을 것이며, 미래에는 더욱 정확하고 빠른 텍스트 인식 기술이 산업 현장에서 활용될 것임을 기대합니다.

맨 위로 이동 ↑

댓글남기기