/*************************************************** Assignment 3-part1 - Spring 2005 Course: CS 182 CS182, Assignment 3: Backpropagation TesterPart1.java **************************************************/ /** * Class for testing backpropagation networks. * Holds the shared two-input bit patterns, one target table per boolean function, and factory methods that build a Net from them. */ public class TesterPart1 { /* The four possible two-bit input patterns; every factory method below passes this same table to Net. */ static int[][] data = {{0,0},{0,1},{1,0},{1,1}}; /* Expected single-bit output for each row of data: AND, OR, and SAME (1 when both inputs are equal). */ static int[][] andTargets = {{0},{0},{0},{1}}; static int[][] orTargets = {{0},{1},{1},{1}}; static int[][] sameTargets = {{1},{0},{0},{1}}; /** * Create a network for learning the AND function. * * @return a new Net constructed from data and andTargets */ static public Net createAnd () { return new Net(data, andTargets); } // Create a network for learning the OR function; returns a new Net constructed from data and orTargets. static public Net createOr () { return new Net(data, orTargets); } /** * Create a network for learning the SAME function. */ //static public Net createSame (){ // return new Net(5, 2, 1, data, sameTargets); //} /** * Create a network w/o hidden layers that won't learn the SAME function. * * @return a new Net constructed from data and sameTargets */ static public Net createBadSame (){ return new Net(data, sameTargets); } /** * Create a network for learning an N-input, M-hidden-node auto encoder. */ /* static public Net createAutoCoder(int inputs, int hidden) { int patterns[][] = new int[inputs] [inputs]; for (int i=0; i= outUnit.inWeights[0] && Net.MIN_WEIGHT <= outUnit.inWeights[1] && Net.MAX_WEIGHT >= outUnit.inWeights[1] && Net.MIN_WEIGHT <= outUnit.inWeights[2] && Net.MAX_WEIGHT >= outUnit.inWeights[2] ) { System.out.print("PASSED \n"); } else { System.out.print("FAILED \n"); } System.out.print(" computeActivation() Test: "); outUnit.inWeights[0] = 0.5; outUnit.inWeights[1] = 0.5; outUnit.inWeights[2] = 0.5; inUnit1.activation = 1.0; inUnit2.activation = 1.0; outUnit.computeActivation(); if( outUnit.activation >= 0.8175 && outUnit.activation <= 0.8176){//1.5 / (1.0 + Math.exp(1.5) ) ) { System.out.print("PASSED \n"); } else { System.out.print("FAILED with activation " + outUnit.activation + " expecting activation 0.8175 \n"); } System.out.print(" computeError() Test: "); outUnit.computeError(0); if( 
outUnit.error >= -0.8176 && outUnit.error <= -0.8175){//1.5 / (1.0 + Math.exp(1.5) ) ) { System.out.print("PASSED \n"); } else { System.out.print("FAILED with error " + outUnit.error + " expecting error -0.8175 \n"); } System.out.print(" computeWeightChange() Test: "); outUnit.computeWeightChange(); if( outUnit.weightChange[0] >= -0.01220 && outUnit.weightChange[0] <= -0.01219 && outUnit.weightChange[1] >= -0.01220 && outUnit.weightChange[1] <= -0.01219 && outUnit.weightChange[2] >= -0.01220 && outUnit.weightChange[2] <= -0.01219) { System.out.print("PASSED \n"); } else { System.out.print("FAILED \n"); } System.out.print(" updateWeights() Test: "); outUnit.updateWeights(); if( outUnit.inWeights[0] >= 0.48780 && outUnit.inWeights[0] <= 0.48781 && outUnit.inWeights[1] >= 0.48780 && outUnit.inWeights[1] <= 0.48781 && outUnit.inWeights[2] >= 0.48780 && outUnit.inWeights[2] <= 0.48781) { System.out.print("PASSED \n"); } else { System.out.print("FAILED \n"); } System.out.print("\nTesting Net.java Class: \n"); System.out.print(" Net constructor Test: "); Net n; n = createAnd(); boolean netCreated = false; if(n.outUnit.in.contains(n.inUnit1) && n.outUnit.in.contains(n.inUnit2) && n.outUnit.in.contains(n.Bias)) { netCreated = true; } if(netCreated) { System.out.print("PASSED \n"); } else { System.out.print("FAILED \n"); } System.out.print(" feedforward() Test: "); n.outUnit.inWeights[0] = 0.5; n.outUnit.inWeights[1] = 0.5; n.outUnit.inWeights[2] = 0.5; n.feedforward( data[1] ); if( outUnit.activation >= 0.8175 && outUnit.activation <= 0.8176){//1.5 / (1.0 + Math.exp(1.5) ) ) { System.out.print("PASSED \n"); } else { System.out.print("FAILED with activation " + outUnit.activation + " expecting activation 0.8175 \n"); } System.out.print(" computeError() Test1: "); n.outUnit.inWeights[0] = 20; n.outUnit.inWeights[1] = 20; n.outUnit.inWeights[2] = -30; double err = n.computeError(); if(err > 0.0 && err < 0.0001) { System.out.print("PASSED \n"); } else { System.out.print("FAILED 
\n"); } System.out.print(" Test2: "); n.outUnit.inWeights[0] = -20; n.outUnit.inWeights[1] = -20; n.outUnit.inWeights[2] = 30; err = n.computeError(); if(err > 1.99 && err < 2.0) { System.out.print("PASSED \n"); } else { System.out.print("FAILED \n"); } System.out.print("\n AND train() Test: "); n.train(); if(n.outUnit.inWeights[0] + n.outUnit.inWeights[2] < 0 && n.outUnit.inWeights[1] + n.outUnit.inWeights[2] < 0 && n.outUnit.inWeights[0] + n.outUnit.inWeights[1] + n.outUnit.inWeights[2] > 0) { System.out.print("PASSED \n\n"); } else { System.out.print(n.toString()); System.out.print("FAILED \n\n"); } System.out.print(" OR train() Test: "); n = createOr(); n.train(); if(n.outUnit.inWeights[0] + n.outUnit.inWeights[2] > 0 && n.outUnit.inWeights[1] + n.outUnit.inWeights[2] > 0 && n.outUnit.inWeights[2] < 0) { System.out.print("PASSED \n\n"); } else { System.out.print(n.toString()); System.out.print("FAILED \n\n"); } } else if (argv.length == 5) { Net n; try { if (argv[0].equalsIgnoreCase("AND")) { n = createAnd(); } else if (argv[0].equalsIgnoreCase("OR")) { n = createOr(); } else if (argv[0].equalsIgnoreCase("SAME")) { n = createBadSame(); } else if (argv[0].equalsIgnoreCase("BADSAME")) { n = createBadSame(); //} else if (argv[0].equalsIgnoreCase("AUTO8")) { // n = createAutoCoder(8,3); // } else if (argv[0].equalsIgnoreCase("BADAUTO9")) { // n = createAutoCoder(9,3); // } else if (argv[0].equalsIgnoreCase("GOODAUTO9")) { // n = createAutoCoder(9,5); //} else if (argv[0].equalsIgnoreCase("STRESS")) { // n = createAutoCoder(256,10); } else { System.err.println("Unknown function name"); return; }; } catch (Exception e) { System.out.println(">>>>>>>>>>>> Failed to create Net: " + argv[0] + ": " + e.getMessage()); return; }; try { n.setTrainingParameters(Integer.parseInt(argv[1]), Double.parseDouble(argv[2]), Double.parseDouble(argv[3]), Double.parseDouble(argv[4])); } catch (Exception e) { System.out.println(">>>>>>>>>>>>>>>> Failed to set Parameters: " + argv[0] + ": 
" + e.getMessage()); return; }; try { n.train(); } catch (Exception e) { System.out.println(">>>>>>>>>>>>>>>>> Failed to train: " + argv[0] + ":" + e.getMessage()); return; } System.out.print(" Weights for " + argv[0] + ": \n"); System.out.print(n.toString()); } else { System.err.println("Invalid argument count"); return; } // System.out.print(" Weights for " + argv[0] + ": \n"); /* // dump the activations for each input for (int r = 0; r