๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ

์•Œ๊ณ ๋ฆฌ์ฆ˜/LEETCODE

Remove Invalid Parentheses | LeetCode 301 | Python3 ๐Ÿ

๋ฐ˜์‘ํ˜•

 

๐Ÿ“„ ๋ชฉ์ฐจ

     

     

    ์ œ๋ชฉ:  https://leetcode.com/problems/remove-invalid-parentheses/๐Ÿ

    ๐Ÿค” ๋ฌธ์ œ : Remove Invalid Parentheses | LeetCode 301

    ๋ฌธ์ œ: https://leetcode.com/problems/remove-invalid-parentheses/

    ์ฃผ์–ด์ง„ ๋ฌธ์ž์—ด์—์„œ ๊ฐ€์žฅ ๊ธด palindrome์„ ์ฐพ๋Š” ๋ฌธ์ œ์ž…๋‹ˆ๋‹ค.

     

     

    ๐Ÿ’ก ํ’€์ด

    1. ๊ตฌ๊ตฌ์ ˆ์ ˆ.. iteration

    ์ฃผ์–ด์ง„ s์˜ ๋ชจ๋“  substring์— ๋Œ€ํ•ด์„œ, valid parenthesis ์—ฌ๋ถ€๋ฅผ ํ™•์ธํ•˜๋ฉด์„œ, ๊ฐ€์žฅ ๊ธด ๊ฒƒ๋“ค์„ ์ €์žฅํ•ด์„œ ๋ฐ˜ํ™˜ํ•ด์•ผ ๊ฒ ๋‹ค.

    ๊ทธ๋Ÿผ TLE ๊ฐ€ ๋‚  ๊ฒƒ ๊ฐ™์œผ๋‹ˆ, (์™€ )์˜ ๊ฐฏ์ˆ˜๋ฅผ ๋Œ€์กฐํ•ด์„œ ๋” ๋งŽ์€๊ฑธ ๋งŽ์€ ๊ฐฏ์ˆ˜๋งŒํผ ์‚ญ์ œํ•œ substring๋จผ์ € ํ™•์ธํ•˜๊ณ .. ๊ฑฐ๊ธฐ์„œ ์—†์œผ๋ฉด (์™€ )๋ฅผ ํ•˜๋‚˜์”ฉ ๋” ์‚ญ์ œํ•˜๊ณ ..

     

    ๋“ฑ๋“ฑ์˜ ๋กœ์ง์„ ๋‹ด์€ ์ฝ”๋“œ๋ฅผ ์ž‘์„ฑํ•˜์˜€๊ณ  ์–ด์ฐŒ์–ด์ฐŒ ํ†ต๊ณผ๋Š” ๋˜์—ˆ๋„ค์š”.

    ํ•˜์ง€๋งŒ ๋‹ค์‹œ๋Š” ์ฝ๊ณ ์‹ถ์ง€ ์•Š์€ ๊ธธ์ด์˜ ์ฝ”๋“œ๊ฐ€ ๋˜์—ˆ์œผ๋‹ˆ.. ๋ญ”๊ฐ€ ๋‹จ๋‹จํžˆ ์ž˜๋ชป๋œ ๊ฒƒ ๊ฐ™์ฃ ..

    from typing import List
    
    # s = "()())()"
    # => ["(())()","()()()"]
    class Solution:
        def removeInvalidParentheses(self, s: str) -> List[str]:
            # ์ œ๊ฑฐํ•˜๋Š” ๋ชจ๋“  ์กฐํ•ฉ์„ ๋งŒ๋“ค์–ด์„œ valid ์ฒดํฌ
    
            # valid ์—ฌ๋ถ€ ํ™•์ธ, ๋ฌธ์ž์—ด์—์„œ ํฌํ•จ ์•ˆํ•  idx ์ง€์ • - O(N)
            def isParenthesisSkip(s: str, skipOrg: List[int]) -> bool:
                skip = skipOrg[:]
    
                # ๊ฑด๋„ˆ๋›ธ๊ฒŒ ๋ฌธ์ž๋ฉด pass
                chars_to_skip = [ (s[idx] == '(' or s[idx] == ')') for idx in skip]
                if not all(chars_to_skip):
                    return False
    
                # ๊ฑด๋„ˆ๋›ธ๊ฒŒ ๋ฌธ์ž๋Š” ํฌํ•จํ•˜์ง€ ์•Š๊ณ  parenthesis check
                stack = []
                toskip = skip.pop(0) if len(skip) > 0 else -1
    
                for idx, c in enumerate(s):
                    if idx == toskip:
                        toskip = skip.pop(0) if len(skip) > 0 else -1
                        continue
    
                    # print('stack: ', stack, 'skip idx: ', toskip)
                    if c == '(':
                        stack.append(c)
                    elif c == ')':
                        if len(stack) == 0 or stack[-1] != '(':
                            return False
                        stack.pop()
                if len(stack) == 0:
                    return True
                return False
            
            def trim(s: str) -> str:
                # ๋ฌธ์ž์—ด ์•ž๋ถ€๋ถ„์˜ ), ๋’ท๋ถ€๋ถ„์˜ ( ์‚ญ์ œ
                idx_front = -1
                idx_back = -1
                for idx, c in enumerate(s):
                    if c != ')':
                        idx_front = idx
                        break
                
                for idx, c in enumerate(reversed(s)):
                    if c != '(':
                        idx_back = idx
                        break
                return s[idx_front: len(s)-idx_back]
    
            def powerset(s):
                x = len(s)
                masks = [1 << i for i in range(x)]
                for i in range(1 << x):
                    yield [ss for mask, ss in zip(masks, s) if i & mask]
    
            def make_string(s, idx_to_remove_org):
                # array deep copy
                idx_to_remove = idx_to_remove_org[:]
    
                if len(idx_to_remove) == 0:
                    return s
                
                ret = ""
                idx_to_remove_next = idx_to_remove.pop(0)
                for idx, char in enumerate(s):
                    if idx == idx_to_remove_next:
                        idx_to_remove_next = idx_to_remove.pop(0) if len(idx_to_remove) > 0 else -1
                    else:
                        ret += char
                return ret
    
            trimed_string = trim(s)
    
            # ( ์™€ )์˜ ์œ„์น˜๋“ค์„ ์ €์žฅ
            position_op = []
            position_cp = []
            for idx, c in enumerate(trimed_string):
                if c =='(':
                    position_op.append(idx)
                elif c == ')':
                    position_cp.append(idx)
            
            # ( ์™€ )์˜ ๊ฐฏ์ˆ˜๋ฅผ ๋น„๊ตํ•˜์—ฌ ์ฐจ์ด๋ฅผ ๊ณ„์‚ฐ
            diff = len(position_op) - len(position_cp)
            skip_num_op = max(diff, 0)
            skip_num_cp = max(-diff, 0)
            
            # possible skip options
            postion_list_op = list(powerset(position_op))
            postion_list_cp = list(powerset(position_cp))
    
            result_found = False
            result = []
            while skip_num_op <= len(position_op) and skip_num_cp <= len(position_cp) and not result_found:
                op_skip_list = list(filter(lambda x: len(x) == skip_num_op, postion_list_op))
                cp_skip_list = list(filter(lambda x: len(x) == skip_num_cp, postion_list_cp))
    
                if len(op_skip_list) == 0:
                    for cp_skip in cp_skip_list:
                        skip_options =sorted(cp_skip)
    
                        if isParenthesisSkip(trimed_string, skip_options):
                            result_found = True
                            result.append(make_string(trimed_string, skip_options))
                elif len(cp_skip_list) == 0:
                    for op_skip in op_skip_list:
                        skip_options =sorted(op_skip)
    
                        if isParenthesisSkip(trimed_string, skip_options):
                            result_found = True
                            result.append(make_string(trimed_string, skip_options))
    
                else:
                    for op_skip in op_skip_list:
                        for cp_skip in cp_skip_list:
                            skip_options =sorted(op_skip+cp_skip)
    
                            if isParenthesisSkip(trimed_string, skip_options):
                                result_found = True
                                result.append(make_string(trimed_string, skip_options))
    
                skip_num_op += 1
                skip_num_cp += 1
    
            return list(set(result))

     

    2. DFS

    ์–ด์ฉ์ง€ ์ฝ”๋“œ๊ฐ€ ๊ตฌ๊ตฌ์ ˆ์ ˆํ•˜๋‹ค ์‹ถ์œผ๋ฉด..  DFS๋‚˜ BFS๋กœ ํ’€์–ด์•ผ ํ•  ๊ฑธ iteration์œผ๋กœ ํ’€์–ด์•ผ ํ•œ ๊ฒฝ์šฐ๊ฐ€ ๋งŽ์•˜์Šต๋‹ˆ๋‹ค. 

    ์ด๋ฒˆ์—๋„ ์—ญ์‹œ๋‚˜ discussion์„ ํ™•์ธํ•ด๋ณด๋‹ˆ DFSํ’€์ด๊ฐ€ ๋งŽ์•˜์Šต๋‹ˆ๋‹ค.

    DFS๋กœ ๋‹ค์‹œ ์ฝ”๋“œ๋ฅผ ์ž‘์„ฑํ•ด๋ณด์•˜์Šต๋‹ˆ๋‹ค.

    from typing import List
    
    # s = "()())()"
    # => ["(())()","()()()"]
    class Solution:
        def removeInvalidParentheses(self, s: str) -> List[str]:
            self.longest_str_size = -1
            self.result = []
    
            self.isValidParenthesis(s, len(s), '', 0, 0, 0)
            return list(set(self.result)) 
    
        def isValidParenthesis(self, s: str, lens: int, currs: str, idx: int, num_o, num_c):
                if num_c > num_o:
                    return False
    
                # ๋งˆ์ง€๋ง‰๊นŒ์ง€ ๋‹ค ๋ดค์Œ
                if lens == idx:
                    # print('isValidParenthesis', idx, currs, num_c, num_o, self.longest_str_size)
    
                    if num_c != num_o:
                        return False
                    else:
                        if len(currs) > self.longest_str_size:
                            self.longest_str_size = len(currs)
                            self.result = [currs]
                            return True
                        elif len(currs) == self.longest_str_size:
                            self.result.append(currs)
                            return True
                        else:
                            return False
    
                self.isValidParenthesis(s, lens, currs+s[idx], idx+1, num_o + (1 if s[idx] == "(" else 0), num_c+ (1 if s[idx] == ")" else 0))
                self.isValidParenthesis(s, lens, currs, idx+1, num_o, num_c)

     

    ํ•œ๊ฒฐ ๊น”๋”ํ•ด์กŒ์ง€๋งŒ ๊ฒฐ๊ณผ๋Š” TLE. 

    ๋ฌธ์ž์—ด์˜ ๊ธธ์ด๊ฐ€ N์ด๋ผ๊ณ  ํ• ๋•Œ DFS๋Š” O(2^N) ์ด์ฃ .

     

     

    3.  DFS + pruning

    DFS์—์„œ ๋ถˆ๊ฐ€๋Šฅํ•œ ๊ฐ€์ง€๋“ค์„ ์‚ฌ์ „์— pruningํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ ์ตœ์ข… ๋‹ต์„ ์ œ์ถœํ•˜์˜€์Šต๋‹ˆ๋‹ค.

     

    ์•„๋ž˜ ์ฝ”๋“œ์˜ [2], [4]์— ๋Œ€ํ•œ ๋ถ€์—ฐ์„ค๋ช…์ž…๋‹ˆ๋‹ค.

     

    [2] skip๊ฐฏ์ˆ˜ ๊ธฐ๋ฐ˜ pruning

    valid parentheses ์ค‘ ๊ฐ€์žฅ ๊ธด ๊ฒƒ์„ ์ฐพ๋Š” ๋ฌธ์ œ์ž…๋‹ˆ๋‹ค. 

    ๋งŒ์•ฝ ์ด๋ฏธ ์ฐพ์€ valid parentheses๊ฐ€ ์‚ญ์ œํ•œ ๋ฌธ์ž๋Š” 1๊ฐœ์ธ๋ฐ ์ง€๊ธˆ ํ™•์ธํ•˜๋Š” ๋ฌธ์ž์—ด์€ ์ด๋ฏธ 2๊ฐœ๋ฅผ ์‚ญ์ œํ–ˆ๋‹ค๋ฉด, ์ด ๋ฌธ์ž์—ด์ด validํ•œ๋“ค ๊ฐ€์žฅ ๊ธด valid parentheses๋Š” ๋  ์ˆ˜ ์—†์œผ๋ฏ€๋กœ ์ •๋‹ต์ด ์•„๋‹™๋‹ˆ๋‹ค. 

     

    [3] valid parentheses๊ฐ€ ๋  ๊ฐ€๋Šฅ์„ฑ์ด ์žˆ๋Š”์ง€ ํ™•์ธ

    valid parentheses๋Š” ( ์™€ )์˜ ๊ฐฏ์ˆ˜๊ฐ€ ๊ฐ™๋‹ค๋Š” ํŠน์ง•์ด ์žˆ์Šต๋‹ˆ๋‹ค.

    ๋งŒ์•ฝ ํ˜„์žฌ ๋ฌธ์ž์—ด์— (๊ฐ€ 10๊ฐœ์ธ๋ฐ ๋‚จ์€ ) ๋ฅผ ๋‹ค ์“ด๋‹ค๊ณ  ํ•ด๋„ 5๊ฐœ๋ฐ–์— ์•ˆ๋œ๋‹ค๋ฉด ์ด ๋ฌธ์ž์—ด์€ valid parentheses๊ฐ€ ๋  ๊ฐ€๋Šฅ์„ฑ์ด ์—†์Šต๋‹ˆ๋‹ค. 

     

    ')'์˜ ๊ฐฏ์ˆ˜๊ฐ€ '('๋ณด๋‹ค ๋งŽ์€ ์‹œ์ ์ด ํ•œ๋ฒˆ์ด๋ผ๋„ ์žˆ๋‹ค๋ฉด ์ด๋ฏธ [1]์—์„œ invad parentheses๋กœ False๋ฅผ ๋ฆฌํ„ดํ•  ๊ฒƒ์ด๊ธฐ ๋•Œ๋ฌธ์— ํ•ด๋‹น ์ผ€์ด์Šค๋Š” ๊ณ ๋ คํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

     

    ๋Š” (์™€ )์˜ ๊ฐฏ์ˆ˜๊ฐ€ ๊ฐ™๋‹ค๋Š” ํŠน์ง•์ด ์žˆ์Šต๋‹ˆ๋‹ค. 

    from typing import List
    
    class Solution:
        def removeInvalidParentheses(self, s: str) -> List[str]:
            self.longest_str_size = -1
            self.result = []
    
            self.dfs(s, len(s), '', 0, 0, 0, 0, s.count('('), s.count(')'))
            return list(set(self.result))
    
        def dfs(self, s: str, lens: int, currs: str, idx: int, num_o, num_c, num_skip, remaining_o, remaining_c):
                # [1] valid parenthesis์ธ์ง€ ํ™•์ธ
                if num_c > num_o:
                    return False
                
                # [2] skip๊ฐฏ์ˆ˜ ๊ธฐ๋ฐ˜ pruning
                if num_skip > lens-self.longest_str_size:
                    return False
                
                # [3] valid parentheses๊ฐ€ ๋  ๊ฐ€๋Šฅ์„ฑ์ด ์žˆ๋Š”์ง€ ํ™•์ธ
                if num_o > num_c + remaining_c:
                    return False
    
                # ๋งˆ์ง€๋ง‰ ์ธ๋ฑ์Šค์ธ ๊ฒฝ์šฐ validity check
                if lens == idx:
                    # print('isValidParenthesis', idx, currs, num_c, num_o, self.longest_str_size)
    
                    if num_c != num_o:
                        return False
                    else:
                        if len(currs) > self.longest_str_size:
                            self.longest_str_size = len(currs)
                            self.result = [currs]
                            return True
                        elif len(currs) == self.longest_str_size:
                            self.result.append(currs)
                            return True
                        else:
                            return False
    
                self.dfs(s, lens, currs+s[idx], idx+1, num_o + (1 if s[idx] == "(" else 0), num_c+ (1 if s[idx] == ")" else 0), num_skip, remaining_o-(1 if s[idx] == "(" else 0), remaining_c-(1 if s[idx] == ")" else 0))
                self.dfs(s, lens, currs, idx+1, num_o, num_c, num_skip+1, remaining_o-(1 if s[idx] == "(" else 0), remaining_c-(1 if s[idx] == ")" else 0))

     

    ๋ฐ˜์‘ํ˜•