package ic.cg;


import ic.ast.Program;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.Charset;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * A CodeGenerator generates the assembly code for a program,
 * after each MethodDecl has been annotated with TACLists.  There is
 * some basic functionality for dealing with string constants and
 * their labels.
 * <p>
 * Feel free to change this code in any way that you find useful,
 * or to ignore it completely and roll your own code generator from
 * scratch.
 * <p>
 * This code assumes that your top-level AST node is called
 * Program.  You may need to change this to match your own AST
 * hierarchy.  Usage:
 *
 * <pre>
 *   CodeGenerator cg = new CodeGenerator("file.ic", program);
 *   cg.generate();
 * </pre>
 *
 * The constructor opens the output file and {@link #generate()} closes
 * it, so each instance generates code exactly once.
 */
public class CodeGenerator {

    /** Charset used to turn string constants into the bytes placed in the data segment. */
    private static final Charset UTF8 = Charset.forName("UTF-8");

    /** The name of the IC source file being compiled */
    protected final String icFile;

    /** The name of the assembly file being generated */
    protected final String asmFile;

    /** The writer used to print the assembly file */
    protected final PrintWriter out;

    /** The program being translated into assembly */
    protected final Program program;

    /**
     * Given a file name "file.ic" and a program AST, construct a
     * code generator to print "file.s", the assembly code version
     * of the program.  Only the extension of the last path component
     * is replaced, so "dir.v2/file.ic" yields "dir.v2/file.s"; a name
     * without an extension simply gets ".s" appended.
     *
     * @param icFileName path of the IC source file
     * @param program    the AST to translate
     * @throws IOException if the assembly file cannot be opened for writing
     */
    public CodeGenerator(String icFileName, Program program) throws IOException {
        this.icFile = icFileName;
        this.asmFile = asmFileNameFor(icFileName);
        this.out = new PrintWriter(new FileWriter(asmFile));
        this.program = program;
    }

    /**
     * Derive the assembly file name by replacing the extension of the
     * last path component of {@code icFileName} with ".s".
     */
    private static String asmFileNameFor(String icFileName) {
        // A dot only starts an extension when it lies inside the last path
        // component and is not that component's first character (a hidden
        // file such as ".foo" has no extension).
        int lastSeparator = Math.max(icFileName.lastIndexOf('/'),
                                     icFileName.lastIndexOf(File.separatorChar));
        int lastDot = icFileName.lastIndexOf('.');
        String base = (lastDot > lastSeparator + 1)
            ? icFileName.substring(0, lastDot)
            : icFileName;
        return base + ".s";
    }


    /**
     * The main method to call when you wish to perform the
     * translation.  Writes the whole assembly file and always closes
     * it, even if code generation fails part-way through.
     *
     * @throws IllegalStateException if an I/O error occurred while writing
     *         the assembly file (PrintWriter itself never throws on I/O errors)
     */
    public void generate() {
        try {
            out.println("# Compiled from " + icFile);
            out.println();
            out.println();

            generateVTables();
            generateCode();
            generateErrorHandlers();

            // TODO: Provide the real information for main.

            String classForMain = "Simple";
            int sizeOfMainObject = 4;
            int offsetOfMainInDV = 0;
            generateMain(classForMain, sizeOfMainObject, offsetOfMainInDV);

            generateStringConstants();
        } finally {
            out.close();
        }
        // PrintWriter records I/O failures instead of throwing; surface them
        // so a truncated .s file is not mistaken for a successful compile.
        if (out.checkError()) {
            throw new IllegalStateException("I/O error while writing " + asmFile);
        }
    }

    /**
     * Print the VTable for all the classes.  You'll need to change this.
     *
     */
    protected void generateVTables() {

        // TODO: You must implement this method

        out.println("# ----------------------------");
        out.println("# VTables");

        out.println();
        out.println(".data");
        out.println(".align 4");
        out.println();

        out.println("_Simple_DV:");
        out.println("      .long _Simple_main");
        out.println("");
        out.println("");
        out.println("");
    }


    /**
     * Generate Simple.main method.  You'll need to change this.
     */
    protected void generateCode() {

        // TODO: You must implement this method

        out.println("# ----------------------------");
        out.println("# Code.  Just the Simple main method for now...");
        out.println("");
        out.println(".text");
        out.println(".align 4");
        out.println("_Simple_main:");
        out.println("");
        out.println("     # prologue");
        out.println("     pushl %ebp                    ");
        out.println("     mov %esp, %ebp                ");
        out.println("     subl $4, %esp                ");
        out.println("");
        out.println("     # store 1234 into a local variable.");
        out.println("     movl $1234, %ecx                     ");
        out.println("     movl %ecx, -4(%ebp)");
        out.println("");
        out.println("     # print that local var...");
        out.println("     pushl -4(%ebp)");
        out.println("     call __LIB_printi");
        out.println("     addl $4, %esp");
        out.println("");
        out.println("     # ... and a new line");
        out.printf ("     pushl $%s\n",
                    labelForStringConstant("\na string \n containing \t\t quotes and escape characters: \"moo\"\n"));
        out.println("     call __LIB_println");
        out.println("     addl $4, %esp        ");
        out.println("        ");
        out.println("     # epilogue");
        out.println("     mov %ebp, %esp                ");
        out.println("     pop %ebp                      ");
        out.println("     ret                           ");
        out.println("");
        out.println("");
        out.println("");
    }

    /**
     * Print out the assembly code to print run-time errors and
     * exit gracefully.  You should jump to these labels on
     * run-time check failure.  You should not need to change this
     * method.
     */
    protected void generateErrorHandlers() {
        out.println("# ----------------------------");
        out.println("# Error handling.  Jump to these procedures when a run-time check fails.");
        out.println("");
        out.println(".data");
        // Lengths are computed from the message text, so they cannot drift
        // out of sync when a message is edited.
        emitLengthPrefixedString("strNullPtrError", "Null pointer violation.");
        emitLengthPrefixedString("strArrayBoundsError", "Array bounds violation.");
        emitLengthPrefixedString("strArraySizeError", "Array size violation.");
        out.println("");

        out.println(".text");
        out.println(".align 4");
        out.println("labelNullPtrError:");
        out.println("    push $strNullPtrError");
        out.println("    call __LIB_println");
        out.println("    push $1");
        out.println("    call __LIB_exit");
        out.println("");
        out.println(".align 4");
        out.println("labelArrayBoundsError:");
        out.println("    push $strArrayBoundsError");
        out.println("    call __LIB_println");
        out.println("    push $1");
        out.println("    call __LIB_exit");
        out.println("");
        out.println(".align 4");
        out.println("labelArraySizeError:");
        out.println("    push $strArraySizeError");
        out.println("    call __LIB_println");
        out.println("    push $1");
        out.println("    call __LIB_exit");
        out.println("");
        out.println("");
        out.println("");
    }

    /**
     * Generate the __ic_main stub that creates and calls main on
     * the right object.  You should not need to change this
     * method.
     * @param className                name of the class containing main
     * @param objectSize               size of objects of that class
     * @param indexOfMainInVTable      index on that class's vtable for main
     */
    protected void generateMain(String className, int objectSize, int indexOfMainInVTable) {
        out.println("# The main entry point.  Allocate object and invoke main on it.");
        out.println("");
        out.println(".text");
        out.println(".align 4");
        out.println(".globl __ic_main");
        out.println("__ic_main:");
        out.println("       pushl %ebp                        # prologue");
        out.println("       movl %esp,%ebp                ");
        out.println("");
        out.printf ("       pushl $%-4d                       # o = new %s\n",
                    objectSize, className);
        out.println("       call __LIB_allocateObject   ");
        out.println("       addl $4, %esp                ");
        out.printf ("       movl $_%s_DV, (%%eax)       \n",
                    className);
        out.println("       pushl 8(%ebp)                     # o.main(args)");
        out.println("       pushl %eax                   ");
        out.println("       movl (%eax), %eax            ");
        out.printf ("       call *%d(%%eax)                     # main is at offset %d in vtable\n",
                    indexOfMainInVTable * 4, indexOfMainInVTable);
        out.println("       addl $8, %esp                ");
        out.println("       movl $0, %eax                     # __ic_main always returns 0");
        out.println("");
        out.println("       movl %ebp,%esp                    # epilogue");
        out.println("       popl %ebp                    ");
        out.println("       ret                         ");
        out.println("");
        out.println("");
        out.println("");
    }


    /********************** String Constants *********************/


    /** A map from string constant to the assembly code label in
     * the data segment where that constant is stored.  See the
     * labelForStringConstant method.  A LinkedHashMap is used so that
     * constants are emitted in first-use order, which keeps the
     * generated assembly deterministic from run to run.
     */
    protected final HashMap<String,String> stringConstantsToLabel = new LinkedHashMap<String,String>();

    /**
     * Return a unique label for a string constant.  After
     * translating all code and getting labels for all string
     * constants, the code generator will print out a data segment
     * containing the labels and string constants.  Asking twice for
     * the same string returns the same label.  The string may contain
     * any characters; they are escaped when the data segment is printed.
     */
    protected String labelForStringConstant(String stringConstant) {
        String label = stringConstantsToLabel.get(stringConstant);
        if (label == null) {
            label = "_str" + stringConstantsToLabel.size();
            stringConstantsToLabel.put(stringConstant, label);
        }
        return label;
    }


    /**
     * Iterate over all used string constants and print them into
     * a data segment.  You should not need to change this method.
     */
    protected void generateStringConstants() {
        out.println("# ----------------------------");
        out.println("# String Constants");

        out.println();
        out.println(".data");
        out.println(".align 4");
        out.println();

        for (Map.Entry<String,String> entry : stringConstantsToLabel.entrySet()) {
            emitLengthPrefixedString(entry.getValue(), entry.getKey());
        }
        out.println("");
        out.println("");
        out.println("");
    }

    /**
     * Print a length-prefixed string into the current section: an aligned
     * 4-byte length word, immediately followed by {@code label} on the
     * first byte of the string.  The length is the UTF-8 byte count, which
     * matches the number of bytes the {@code .ascii} directive emits.
     */
    private void emitLengthPrefixedString(String label, String value) {
        byte[] bytes = value.getBytes(UTF8);
        // Padding goes before the length word, so label - 4 is still the length.
        out.println(".align 4");
        out.printf(".long %d\n", bytes.length);
        out.printf("%s:\t.ascii \"%s\"\n", label, escapeForAssembler(bytes));
    }

    /**
     * Escape raw bytes for use inside a GNU assembler {@code .ascii "..."}
     * literal.  Printable ASCII is copied as-is.  Quote, backslash, newline,
     * tab and carriage return use their C-style escapes.  Every other byte
     * (control characters and non-ASCII UTF-8 bytes) is written as a
     * three-digit octal escape.
     */
    private static String escapeForAssembler(byte[] bytes) {
        StringBuilder sb = new StringBuilder(bytes.length);
        for (byte b : bytes) {
            int c = b & 0xFF;
            switch (c) {
                case '\n': sb.append("\\n");  break;
                case '\r': sb.append("\\r");  break;
                case '\t': sb.append("\\t");  break;
                case '"':  sb.append("\\\""); break;
                case '\\': sb.append("\\\\"); break;
                default:
                    if (c >= 0x20 && c < 0x7F) {
                        sb.append((char) c);
                    } else {
                        // Always exactly three octal digits, so a following
                        // digit character is never absorbed into the escape.
                        sb.append(String.format("\\%03o", c));
                    }
            }
        }
        return sb.toString();
    }
}

