본문 바로가기
Algorithm/백준

[Java/Python] 백준 14426번 - 접두사 찾기 (실버 1)

by 애기 개발자 2025. 6. 5.
반응형
혼자 힘으로 풀었는가? X

알고리즘 분류
 - 자료 구조
 - 문자열
 - 트리
 - 이분 탐색
 - 트라이

 

문제

문자열 S의 접두사란 S의 가장 앞에서부터 부분 문자열을 의미한다. 예를 들어, S = "codeplus"의 접두사는 "code", "co", "codepl", "codeplus"가 있고, "plus", "s", "cude", "crud"는 접두사가 아니다.

총 N개의 문자열로 이루어진 집합 S가 주어진다.

입력으로 주어지는 M개의 문자열 중에서 집합 S에 포함되어 있는 문자열 중 적어도 하나의 접두사인 것의 개수를 구하는 프로그램을 작성하시오.

입력

첫째 줄에 문자열의 개수 N과 M (1 ≤ N ≤ 10,000, 1 ≤ M ≤ 10,000)이 주어진다.

다음 N개의 줄에는 집합 S에 포함되어 있는 문자열이 주어진다.

다음 M개의 줄에는 검사해야 하는 문자열이 주어진다.

입력으로 주어지는 문자열은 알파벳 소문자로만 이루어져 있으며, 길이는 500을 넘지 않는다. 집합 S에 같은 문자열이 여러 번 주어지는 경우는 없다.

출력

첫째 줄에 M개의 문자열 중에 총 몇 개가 포함되어 있는 문자열 중 적어도 하나의 접두사인지 출력한다.


 

처음엔 단순히 주어진 n개의 문자열 배열과 m개의 문자열 배열을 주어진 순서대로 하나하나 잘라서 비교했다.

 

당연히 시간 초과가 발생했다. (n, m은 10,000 이하 -> 최악의 경우 10,000 * 10,000)

 

알고리즘 분류를 보는데 이분탐색과 트라이 라는 알고리즘이 사용되었다는데 트라이는 내가 모르고 이분탐색으로 해결해보고 싶엇는데 방법을 도무지 떠올리지 못해서 검색해봤다.

 

이 문제를 해결하는데 아주 간단한 방법이 있었다.

 

  1. 주어진 문자열 배열을 정렬한다.
  2. 각 문자열 배열을 순차적으로 비교하며 비교 횟수를 최소화 한다.

이 방법은 내가 맨처음 시도한 방법과는 속도차이가 현저히 발생하게 된다.

 

맨 처음방법은 단순히 n * m 의 비교를 하지만

두 번째 방법은 문자열의 정렬된 순서로 비교하고, 비교기준이 되는 문자열보다 정렬순서가 낮은 문자열은 비교하지 않기 때문에 비교횟수가 줄어드는 것이다.

 

Java코드

import java.io.*;
import java.util.*;
public class Main {
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
int n = Integer.parseInt(st.nextToken());
int m = Integer.parseInt(st.nextToken());
String [] set = new String[n];
String [] check = new String[m];
for(int i=0; i<n; i++) {
set[i] = br.readLine();
}
for(int i=0; i<m; i++) {
check[i] = br.readLine();
}
Arrays.sort(set);
Arrays.sort(check);
int i=0;
int j=0;
int cnt = 0;
while(i < n && j < m) {
String a = set[i].substring(0, check[j].length());
String b = check[j];
if(a.equals(b)) {
cnt++;
j++;
} else if(set[i].compareTo(check[j]) > 0) {
j++;
} else {
i++;
}
}
System.out.println(cnt);
}
}

 

 

Python 코드

import sys
input = sys.stdin.readline
n, m = map(int, input().split())
s_set = []
s_check = []
for i in range(n):
s_set.append(input().strip())
s_set.sort()
for i in range(m):
s_check.append(input().strip())
s_check.sort()
i, j = 0, 0
cnt = 0
while i < n and j < m:
if s_set[i][:len(s_check[j])] == s_check[j]:
cnt += 1
j += 1
elif s_set[i] > s_check[j]:
j += 1
else:
i += 1
print(cnt)
반응형

댓글