Skip to content

Extract subclass

Introduction

Extract subclass refactoring

ExtractSubClassRefactoringListener (JavaParserLabeledListener)

To implement extract subclass refactoring based on its actors.

Creates a new class and moves the selected fields and methods from the old class to the new one.

Source code in codart\refactorings\extract_subclass.py
class ExtractSubClassRefactoringListener(JavaParserLabeledListener):
    """
    Listener that implements the extract-subclass refactoring.

    While the parse tree of the file declaring ``source_class`` is walked, the
    listener accumulates in ``self.code`` the text of a new class ``new_class``
    that extends ``source_class``.  The requested fields and methods (and any
    constructor that assigns a moved field or calls a moved method) are copied
    into that text and deleted from the source class through a
    ``TokenStreamRewriter``.  When the compilation unit is left, the text is
    written to ``<output_path>/<new_class>.java``.
    """

    def __init__(
            self, common_token_stream: CommonTokenStream = None,
            source_class: str = None, new_class: str = None,
            moved_fields=None, moved_methods=None,
            output_path: str = ""):
        """
        Args:
            common_token_stream: Token stream of the parsed source file; it is
                wrapped in a ``TokenStreamRewriter`` used for the deletions.
            source_class: Name of the class members are extracted from.
            new_class: Name of the subclass to generate.
            moved_fields: Names of the fields to move (defaults to no fields).
            moved_methods: Names of the methods to move (defaults to no methods).
            output_path: Directory in which ``<new_class>.java`` is written.

        Raises:
            ValueError: If ``common_token_stream``, ``source_class`` or
                ``new_class`` is ``None``.
        """
        self.moved_methods = [] if moved_methods is None else moved_methods
        self.moved_fields = [] if moved_fields is None else moved_fields

        if common_token_stream is None:
            raise ValueError('common_token_stream is None')
        self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)

        if source_class is None:
            raise ValueError("source_class is None")
        self.source_class = source_class

        if new_class is None:
            raise ValueError("new_class is None")
        self.new_class = new_class

        self.output_path = output_path

        # True while the walker is directly inside the body of the source class.
        self.is_source_class = False
        self.detected_field = None
        self.detected_method = None
        self.TAB = "\t"
        self.NEW_LINE = "\n"
        # Text of the generated subclass, built incrementally during the walk.
        self.code = ""
        # Constructor bookkeeping; created here so the attributes always exist.
        self.is_in_constructor = False
        self.fields_in_constructor = []
        self.methods_in_constructor = []
        self.constructor_body = None

    def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        """
        Start the text of the new class when the source class is entered.

        For the source class the header comment, the ``class ... extends ...``
        line, the opening brace and an empty public constructor are appended to
        ``self.code``.  For any other class ``is_source_class`` is cleared.
        """
        class_identifier = ctx.IDENTIFIER().getText()
        if class_identifier == self.source_class:
            self.is_source_class = True
            self.code += self.NEW_LINE * 2
            self.code += f"// New class({self.new_class}) generated by CodART" + self.NEW_LINE
            self.code += f"class {self.new_class} extends {self.source_class}{self.NEW_LINE}" + "{" + self.NEW_LINE
            self.code += f"public {self.new_class}()" + "{ }" + self.NEW_LINE
        else:
            self.is_source_class = False

    def exitClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        """
        Close the generated class, or resume processing after a nested class.

        The closing brace is keyed on the class name rather than on
        ``is_source_class``: entering a nested class clears that flag, and the
        generated class would otherwise never be closed.
        """
        if ctx.IDENTIFIER().getText() == self.source_class:
            self.code += "}"
            self.is_source_class = False
            return

        # Leaving some other class: members that follow belong to the source
        # class again if the nearest enclosing class declaration is the source.
        enclosing = ctx.parentCtx
        while enclosing is not None and not isinstance(
                enclosing, JavaParserLabeled.ClassDeclarationContext):
            enclosing = enclosing.parentCtx
        self.is_source_class = (
            enclosing is not None
            and enclosing.IDENTIFIER().getText() == self.source_class
        )

    def exitCompilationUnit(self, ctx: JavaParserLabeled.CompilationUnitContext):
        """
        Write ``self.code`` to ``<output_path>/<new_class>.java``.

        The file is written as UTF-8 so that non-ASCII text copied from the
        source does not depend on the platform's default encoding.
        """
        child_file_name = self.new_class + ".java"
        target = os.path.join(self.output_path, child_file_name)
        with open(target, "w", encoding="utf-8") as f:
            f.write(self.code.replace('\r\n', '\n'))

    def enterVariableDeclaratorId(self, ctx: JavaParserLabeled.VariableDeclaratorIdContext):
        """
        Remember the declared name as the detected field if it is one of the
        moved fields and the walker is inside the source class.
        """
        if not self.is_source_class:
            return None
        field_identifier = ctx.IDENTIFIER().getText()
        if field_identifier in self.moved_fields:
            self.detected_field = field_identifier

    def exitFieldDeclaration(self, ctx: JavaParserLabeled.FieldDeclarationContext):
        """
        Move a detected field declaration into the new class.

        The declaration is matched on its first declarator.  Because the whole
        member declaration is deleted from the source class, the whole
        declaration is copied too: every modifier, the type and all declarators
        with their initialisers, so nothing is silently lost.
        """
        if not self.is_source_class:
            return None

        declarators_ctx = ctx.variableDeclarators()
        field_identifier = (
            declarators_ctx.variableDeclarator(0).variableDeclaratorId().IDENTIFIER().getText()
        )
        if self.detected_field != field_identifier:
            return None

        rewriter = self.token_stream_rewriter
        program = rewriter.DEFAULT_PROGRAM_NAME
        # Grandparent is the member declaration carrying the modifiers.
        member_ctx = ctx.parentCtx.parentCtx

        modifiers = " ".join(m.getText() for m in (member_ctx.modifier() or []))
        field_type = ctx.typeType().getText()
        # Taken from the token stream so initialiser whitespace is preserved.
        declarators = rewriter.getText(
            program_name=program,
            start=declarators_ctx.start.tokenIndex,
            stop=declarators_ctx.stop.tokenIndex
        )
        declaration = " ".join(part for part in (modifiers, field_type, declarators) if part)
        self.code += f"{self.TAB}{declaration};{self.NEW_LINE}"

        # Delete the field (modifiers included) from the source class.
        rewriter.delete(
            program_name=program,
            from_idx=member_ctx.start.tokenIndex,
            to_idx=member_ctx.stop.tokenIndex
        )

        self.detected_field = None

    def enterMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext):
        """
        Remember the method name as the detected method if it is one of the
        moved methods and the walker is inside the source class.
        """
        if not self.is_source_class:
            return None
        method_identifier = ctx.IDENTIFIER().getText()
        if method_identifier in self.moved_methods:
            self.detected_method = method_identifier

    def exitMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext):
        """
        Move the detected method into the new class.

        The text from the start of the enclosing member declaration (so the
        modifiers are included) to the end of the method is appended to
        ``self.code`` and deleted from the source class.
        """
        if not self.is_source_class:
            return None
        method_identifier = ctx.IDENTIFIER().getText()
        if self.detected_method == method_identifier:
            rewriter = self.token_stream_rewriter
            program = rewriter.DEFAULT_PROGRAM_NAME
            start_index = ctx.parentCtx.parentCtx.start.tokenIndex
            stop_index = ctx.stop.tokenIndex
            method_text = rewriter.getText(
                program_name=program,
                start=start_index,
                stop=stop_index
            )
            self.code += (self.NEW_LINE + self.TAB + method_text + self.NEW_LINE)
            # Delete the method from the source class.
            rewriter.delete(
                program_name=program,
                from_idx=start_index,
                to_idx=stop_index
            )
            self.detected_method = None

    def enterConstructorDeclaration(self, ctx: JavaParserLabeled.ConstructorDeclarationContext):
        """
        Start collecting the fields assigned and methods called inside a
        constructor of the source class.
        """
        if self.is_source_class:
            self.is_in_constructor = True
            self.fields_in_constructor = []
            self.methods_in_constructor = []
            self.constructor_body = ctx.block()

    def exitConstructorDeclaration(self, ctx: JavaParserLabeled.ConstructorDeclarationContext):
        """
        Move the constructor into the new class if it uses a moved member.

        A constructor is moved when it assigns one of the moved fields or calls
        one of the moved methods.  It is re-emitted under the new class name
        with the same modifiers, parameters and body, then deleted from the
        source class.
        """
        if self.is_source_class and self.is_in_constructor:
            uses_moved_member = (
                any(field in self.moved_fields for field in self.fields_in_constructor)
                or any(method in self.moved_methods for method in self.methods_in_constructor)
            )

            if uses_moved_member:
                rewriter = self.token_stream_rewriter
                program = rewriter.DEFAULT_PROGRAM_NAME
                member_ctx = ctx.parentCtx.parentCtx

                parameter_list = ctx.formalParameters().formalParameterList()
                # Children alternate parameter / ',' — keep the parameters only.
                constructor_parameters = parameter_list.children[::2] if parameter_list else []

                modifiers = ''.join(m.getText() + ' ' for m in member_ctx.modifier())
                parameters = ', '.join(
                    parameter.typeType().getText() + ' ' + parameter.variableDeclaratorId().getText()
                    for parameter in constructor_parameters
                )
                # Body text without the surrounding braces of the block.
                body = rewriter.getText(
                    program_name=program,
                    start=ctx.block().start.tokenIndex + 1,
                    stop=ctx.block().stop.tokenIndex - 1
                )
                self.code += (
                    modifiers + self.new_class + ' ( ' + parameters + ')\n\t{' + body + '}\n'
                )

                rewriter.delete(
                    program_name=program,
                    from_idx=member_ctx.start.tokenIndex,
                    to_idx=member_ctx.stop.tokenIndex
                )

        self.is_in_constructor = False

    def enterExpression21(self, ctx: JavaParserLabeled.Expression21Context):
        """
        Record the name taken from the first child of the expression while
        inside a source-class constructor.

        NOTE(review): Expression21 is presumably the assignment alternative of
        the labeled grammar, making this the assigned field — confirm.
        """
        if self.is_source_class and self.is_in_constructor:
            target = ctx.children[0]
            if len(target.children) == 1:
                self.fields_in_constructor.append(target.getText())
            else:
                # Qualified target: keep only its last child's text.
                self.fields_in_constructor.append(target.children[-1].getText())

    def enterMethodCall0(self, ctx: JavaParserLabeled.MethodCall0Context):
        """
        Record the name of a method called inside a source-class constructor.

        The identifier's text is stored (not the terminal node) so that it can
        be compared with the strings in ``moved_methods``.
        """
        if self.is_source_class and self.is_in_constructor:
            self.methods_in_constructor.append(ctx.IDENTIFIER().getText())

enterClassDeclaration(self, ctx)

If the class being entered is the source class, it generates the declaration of the new class by appending text to self.code.

Source code in codart\refactorings\extract_subclass.py
def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
    """
    Begin the text of the new class when the walker enters the source class.

    ``is_source_class`` is set according to whether the entered class is the
    source class.  When it is, the header comment, the ``class ... extends ...``
    line, the opening brace and an empty public constructor are appended to
    ``self.code``.
    """
    self.is_source_class = ctx.IDENTIFIER().getText() == self.source_class
    if not self.is_source_class:
        return

    line_break = self.NEW_LINE
    header_parts = (
        line_break * 2,
        f"// New class({self.new_class}) generated by CodART" + line_break,
        f"class {self.new_class} extends {self.source_class}{line_break}" + "{" + line_break,
        f"public {self.new_class}()" + "{ }" + line_break,
    )
    for part in header_parts:
        self.code += part

enterMethodDeclaration(self, ctx)

It sets the detected method to the method if it is one of the moved methods.

Source code in codart\refactorings\extract_subclass.py
def enterMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext):
    """
    Record the entered method as the detected method when it is one of the
    moved methods; nothing happens outside the source class.
    """
    if self.is_source_class:
        name = ctx.IDENTIFIER().getText()
        if name in self.moved_methods:
            self.detected_method = name
    return None

enterVariableDeclaratorId(self, ctx)

It sets the detected field to the field if it is one of the moved fields.

Source code in codart\refactorings\extract_subclass.py
def enterVariableDeclaratorId(self, ctx: JavaParserLabeled.VariableDeclaratorIdContext):
    """
    Record the declared name as the detected field when it is one of the moved
    fields; nothing happens outside the source class.
    """
    if self.is_source_class:
        name = ctx.IDENTIFIER().getText()
        if name in self.moved_fields:
            self.detected_field = name
    return None

exitClassDeclaration(self, ctx)

It closes the opened curly bracket if it is the source class.

Source code in codart\refactorings\extract_subclass.py
def exitClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
    """
    Close the generated class, or resume processing after a nested class.

    When the class being left is the source class, the closing curly bracket
    is appended to ``self.code``.  The check uses the class name rather than
    ``is_source_class``: entering a nested class clears that flag, and the
    generated class would otherwise be left without its closing bracket.

    When any other class is left, ``is_source_class`` is restored to whether
    the nearest enclosing class declaration is the source class, so members
    declared after a nested class are still processed.
    """
    if ctx.IDENTIFIER().getText() == self.source_class:
        self.code += "}"
        self.is_source_class = False
        return

    # Walk up to the nearest enclosing class declaration, if any.
    enclosing = ctx.parentCtx
    while enclosing is not None and not isinstance(
            enclosing, JavaParserLabeled.ClassDeclarationContext):
        enclosing = enclosing.parentCtx
    self.is_source_class = (
        enclosing is not None
        and enclosing.IDENTIFIER().getText() == self.source_class
    )

exitCompilationUnit(self, ctx)

It writes self.code in the output path.

Source code in codart\refactorings\extract_subclass.py
def exitCompilationUnit(self, ctx: JavaParserLabeled.CompilationUnitContext):
    """
    Write ``self.code`` to ``<output_path>/<new_class>.java``.

    The file is written as UTF-8 rather than in the platform's default
    encoding, so non-ASCII text copied from the source is written the same
    way on every platform.  Windows line endings in the generated text are
    normalised to ``\\n`` before writing.
    """
    child_file_name = self.new_class + ".java"
    target = os.path.join(self.output_path, child_file_name)
    # Plain "w" is enough: the file is only written, never read back.
    with open(target, "w", encoding="utf-8") as f:
        f.write(self.code.replace('\r\n', '\n'))

exitFieldDeclaration(self, ctx)

It gets the field name; if the field is one of the moved fields, it is moved to the new class and deleted from the source program.

Source code in codart\refactorings\extract_subclass.py
def exitFieldDeclaration(self, ctx: JavaParserLabeled.FieldDeclarationContext):
    """
    Move a detected field declaration into the new class.

    The declaration is matched on its first declarator.  Because the whole
    member declaration is deleted from the source program, the whole
    declaration is copied too: every modifier (not only the first), the type,
    and all declarators with their initialisers, so nothing is silently lost.
    """
    if not self.is_source_class:
        return None

    declarators_ctx = ctx.variableDeclarators()
    field_identifier = (
        declarators_ctx.variableDeclarator(0).variableDeclaratorId().IDENTIFIER().getText()
    )
    if self.detected_field != field_identifier:
        return None

    rewriter = self.token_stream_rewriter
    program = rewriter.DEFAULT_PROGRAM_NAME
    # Grandparent is the member declaration carrying the modifiers.
    member_ctx = ctx.parentCtx.parentCtx

    modifiers = " ".join(m.getText() for m in (member_ctx.modifier() or []))
    field_type = ctx.typeType().getText()
    # Taken from the token stream so initialiser whitespace is preserved.
    declarators = rewriter.getText(
        program_name=program,
        start=declarators_ctx.start.tokenIndex,
        stop=declarators_ctx.stop.tokenIndex
    )
    declaration = " ".join(part for part in (modifiers, field_type, declarators) if part)
    self.code += f"{self.TAB}{declaration};{self.NEW_LINE}"

    # Delete the field (modifiers included) from the source class.
    rewriter.delete(
        program_name=program,
        from_idx=member_ctx.start.tokenIndex,
        to_idx=member_ctx.stop.tokenIndex
    )

    self.detected_field = None

exitMethodDeclaration(self, ctx)

It gets the method name; if the method is one of the moved methods, it is moved to the subclass and deleted from the source program.

Source code in codart\refactorings\extract_subclass.py
def exitMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext):
    """
    Move the detected method into the subclass.

    If the method being left is the detected method, its text — from the start
    of the enclosing member declaration to the end of the method — is appended
    to ``self.code`` and the same token range is deleted from the source
    program.  Nothing happens outside the source class.
    """
    if not self.is_source_class or self.detected_method != ctx.IDENTIFIER().getText():
        return None

    rewriter = self.token_stream_rewriter
    program = rewriter.DEFAULT_PROGRAM_NAME
    first_token = ctx.parentCtx.parentCtx.start.tokenIndex
    last_token = ctx.stop.tokenIndex

    # Copy the method text first, then remove it from the source class.
    copied = rewriter.getText(program_name=program, start=first_token, stop=last_token)
    self.code += (self.NEW_LINE + self.TAB + copied + self.NEW_LINE)
    rewriter.delete(program_name=program, from_idx=first_token, to_idx=last_token)

    self.detected_method = None

main()

It builds the parse tree and walks it with the corresponding listener so that our overridden methods run.

Source code in codart\refactorings\extract_subclass.py
def _parse_java_file(file_path):
    """Lex and parse a Java file; return its ``(token_stream, parse_tree)``."""
    stream = FileStream(file_path, encoding='utf8', errors='ignore')
    lexer = JavaLexer(stream)
    token_stream = CommonTokenStream(lexer)
    parser = JavaParserLabeled(token_stream)
    return token_stream, parser.compilationUnit()


def _write_rewritten_source(file_path, listener):
    """Overwrite ``file_path`` with the text produced by the listener's rewriter."""
    # Written as UTF-8 to match the encoding the file was read with.
    with open(file_path, mode='w', newline='', encoding='utf8') as f:
        f.write(listener.token_stream_rewriter.getDefaultText())


def main():
    """
    Run the extract-subclass refactoring on a hard-coded sample project.

    It builds the parse tree of the source file and walks it with
    ``ExtractSubClassRefactoringListener`` so that the overridden listener
    methods run, then rewrites the source file.  Afterwards every file in
    ``files_to_refactor`` is processed by ``FindUsagesListener`` and then by
    ``PropagationListener``.
    """
    import os

    source_class = "GodClass"
    moved_methods = ['method1', 'method3']
    moved_fields = ['field1', 'field2']
    father_path_file = "/data/Dev/JavaSample/src/GodClass.java"
    father_path_directory = "/data/Dev/JavaSample/src"
    path_to_refactor = "/data/Dev/JavaSample/src"

    # The listener writes the subclass to <output_path>/<new_class>.java, so
    # the path is derived from the same name instead of being spelled out
    # separately (the two had drifted apart).
    new_class = source_class + "extracted"
    new_class_file = os.path.join(father_path_directory, new_class + ".java")

    listener_arguments = dict(
        source_class=source_class,
        new_class=new_class,
        moved_fields=moved_fields, moved_methods=moved_methods,
        output_path=father_path_directory,
    )

    token_stream, parse_tree = _parse_java_file(father_path_file)
    my_listener = ExtractSubClassRefactoringListener(common_token_stream=token_stream,
                                                     **listener_arguments)
    walker = ParseTreeWalker()
    walker.walk(t=parse_tree, listener=my_listener)
    _write_rewritten_source(father_path_file, my_listener)

    extractJavaFilesAndProcess(path_to_refactor, father_path_file, new_class_file)

    # NOTE(review): files_to_refactor is not defined in this function; it is
    # presumably a module-level list filled by extractJavaFilesAndProcess —
    # confirm.
    for file in files_to_refactor:
        token_stream, parse_tree = _parse_java_file(file)
        my_listener = FindUsagesListener(common_token_stream=token_stream,
                                         **listener_arguments)
        walker = ParseTreeWalker()
        walker.walk(t=parse_tree, listener=my_listener)

        tmp_aul = my_listener.aul
        _write_rewritten_source(file, my_listener)

        # After the usages are found, propagate the change in the same file.
        # This stage stays best-effort, but only ordinary errors are caught so
        # that Ctrl-C and interpreter exit are no longer swallowed.
        try:
            token_stream, parse_tree = _parse_java_file(file)
            my_listener = PropagationListener(common_token_stream=token_stream,
                                              aul=tmp_aul,
                                              **listener_arguments)
            walker = ParseTreeWalker()
            walker.walk(t=parse_tree, listener=my_listener)
            _write_rewritten_source(file, my_listener)
        except Exception as exc:
            print(f"not utf8: {file} ({exc!r})")