zoukankan      html  css  js  c++  java
  • EI328 Final Project Review (III)

      In the final part of this project, we were required to implement a basic classifier ourselves and use it in the Min-Max Modular Neural Network. We implemented a logistic regression classifier trained with naive batch gradient descent.

      The discriminant function is the sigmoid function:

          $y(\vec x)=\left\{1+\exp\left[-(\vec w^T\cdot\vec x+b)\right]\right\}^{-1}$

      The objective function is the negative log-likelihood:

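          $-\ln L=\sum_{n=1}^{l}\ln\left\{1+\exp\left[-t_n\cdot(\vec w^T\cdot\vec x_n+b)\right]\right\}$, where $t_n\in\{-1,+1\}$ is the label of the $n$-th training sample.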

      The gradient can be calculated as follows:

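          $\nabla_{\vec w}(-\ln L)=\sum_{n=1}^{l}\left\{y(\vec x_n)-\frac{t_n+1}{2}\right\}\cdot\vec x_n,\qquad\frac{\partial(-\ln L)}{\partial b}=\sum_{n=1}^{l}\left\{y(\vec x_n)-\frac{t_n+1}{2}\right\}$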

      Here is the source code of ./src/DIY.java:

      1 /**
      2 * Argument Format:
      3 * <task # = 6-2 or 6-3> <number of threads> <min group size> <max group size> <group size step>
      4 **/
      5 
      6 import java.util.concurrent.atomic.*;
      7 import java.util.concurrent.*;
      8 import java.util.*;
      9 import java.io.*;
     10 
     /**
      * One non-zero entry of a sparse feature vector.
      */
     class Node {
         /** Feature index (1-based: models read weight w[idx-1]). */
         public int idx;
         /** Feature value stored at {@link #idx}. */
         public double val;

         /**
          * Builds a sparse (index, value) entry.
          *
          * @param index feature index
          * @param value feature value
          */
         public Node(int index,double value) {
             val = value;
             idx = index;
         }
     }
     20 
     21 class DataSet {
     22     /** This class mimic Problem class in LIBLINEAR **/
     23     public int l, n;
     24     public List<List<Node>> x;
     25     public List<Integer> y;
     26     
     27     public DataSet(int l,int n) {
     28         x = new ArrayList<List<Node>>();
     29         y = new ArrayList<Integer>();
     30         this.l = l;
     31         this.n = n;
     32     }
     33     public DataSet(String path) {
     34         x = new ArrayList<List<Node>>();
     35         y = new ArrayList<Integer>();
     36         StringTokenizer tok = null;
     37         StringTokenizer tok1 = null;
     38         Scanner in = null;
     39         try {
     40             in = new Scanner(new File(path));
     41             for (;in.hasNextLine();l++) {
     42                 List<Node> tmp = new ArrayList<Node>();
     43                 tok = new StringTokenizer(in.nextLine());
     44                 y.add(Integer.valueOf(tok.nextToken()));
     45                 while (tok.hasMoreTokens()) {
     46                     tok1 = new StringTokenizer(tok.nextToken(),":");
     47                     int idx = Integer.parseInt(tok1.nextToken());
     48                     double val = Double.parseDouble(tok1.nextToken());
     49                     tmp.add(new Node(idx,val));
     50                     if (idx>n) {
     51                         n = idx;
     52                     }
     53                 }
     54                 x.add(tmp);
     55             }
     56             in.close();
     57         } catch (Exception e) {
     58             System.err.println("DATASET Error: "+e);
     59         }
     60     }
     61     public int getY(int idx) {
     62         if (idx<0||idx>=l) {
     63             throw new RuntimeException("DataSet: IndexOutOfBound");
     64         }
     65         return y.get(idx).intValue();
     66     }
     67 }
     68 
     69 class Augur {
     70     /** This class mimic Model class in LIBLINEAR **/
     71     private static int id = 0;
     72     private final double lamb = 0;
     73     private DataSet prob;
     74     private double[] w;
     75     private double b;
     76     private int n;
     77     
     78     public Augur() {}
     79     public Augur(DataSet prob) {
     80         this.prob = prob;
     81         n = prob.n;
     82         w = new double[n];
     83         double[] g = new double[n+1];
     84         for (int itr=0;itr<1000;itr++) {
     85             calGrad(g);
     86             double len = norm(g);
     87             if (itr%20==0) {
     88                 System.out.println("itr "+itr+": norm(g) = "+len);
     89             }
     90             //double step = optStep(g,len);
     91             double step = len;
     92             for (int i=0;i<n;i++) {
     93                 w[i] -= step*g[i];
     94             }
     95             b -= step*g[n];
     96         }
     97         System.out.println("	Group "+(++id)+" Done");
     98     }
     99     public int predict(List<Node> x) {
    100         return (sigmoid(x)>0.5)? 1:-1;
    101     }
    102     private double optStep(double[] g,double len) {
    103         double[] val = new double[2];
    104         double p = 0, r = len;
    105         double q1 = .382*r-.618*p;
    106         double q2 = .618*r+.382*p;
    107         val[0] = error(q1,g);
    108         val[1] = error(q2,g);
    109         for (int i=0;i<5;i++) {
    110             if (val[0]<val[1]) {
    111                 r = q2;
    112                 q2 = q1;
    113                 q1 = .382*r-.618*p;
    114                 val[1] = val[0];
    115                 val[0] = error(q1,g);
    116             } else {
    117                 p = q1;
    118                 q1 = q2;
    119                 q2 = .618*r+.382*p;
    120                 val[0] = val[1];
    121                 val[1] = error(q2,g);
    122             }
    123         }
    124         return (val[0]<val[1])? q1:q2;
    125     }
    126     private double sigmoid(List<Node> x) {
    127         double val = b;
    128         for (Node node: x) {
    129             if (node.idx<=n) {
    130                 val += w[node.idx-1]*node.val;
    131             }
    132         }
    133         return 1./(1+Math.exp(-val));
    134     }
    135     private double error(double step,double[] g) {
    136         double val = .5*lamb*square(w,b);
    137         for (int i=0;i<prob.l;i++) {
    138             double tmp = b-step*g[n];
    139             for (Node node:prob.x.get(i)) {
    140                 tmp += (w[node.idx-1]-step*g[node.idx-1])*node.val;
    141             }
    142             val += Math.log(1+Math.exp(-prob.getY(i)*tmp));
    143         }
    144         return val;
    145     }
    146     private void calGrad(double[] g) {
    147         for (int i=0;i<n;i++) {
    148             g[i] *= w[i]*lamb;
    149         }
    150         g[n] = b*lamb;
    151         for (int i=0;i<prob.l;i++) {
    152             double delt = sigmoid(prob.x.get(i))-(prob.getY(i)+1)/2;
    153             for (Node node:prob.x.get(i)) {
    154                 g[node.idx-1] += node.val*delt;
    155             }
    156             g[prob.n] += delt;
    157         }
    158     }
    159     private double square(double[] vect,double cst) {
    160         double val = cst*cst;
    161         for (int i=0;i<vect.length;i++) {
    162             val += vect[i]*vect[i];
    163         }
    164         return Math.sqrt(val);
    165     }
    166     private double norm(double[] vect) {
    167         return Math.sqrt(square(vect,0));
    168     }
    169 }
    170 
    171 class AbstractDIY extends Thread {
    172     protected static DataSet train, test;
    173     protected static PrintWriter out;
    174     protected static long start, end;
    175     
    176     protected static void init() {
    177         try {
    178             System.out.println();
    179             if (!(new File("./data/new_train.txt")).exists()) {
    180                 System.out.println("Preprocessing: please wait a minute!");
    181                 preproc("./data/","train.txt","");
    182                 preproc("./data/","test.txt","5001:1.00");
    183                 System.out.println("Ready!");
    184             }
    185             System.out.println("Reading the files ...	Wait please ~ ~");
    186             train = new DataSet("./data/new_train.txt");
    187             test = new DataSet("./data/new_test.txt");
    188         } catch (Exception e) {
    189             System.out.println("INIT Error: "+e);
    190         }
    191     }
    192     protected static void print(String str) {
    193         try {
    194             System.out.print(str);
    195             out.print(str);
    196         } catch (Exception e) {
    197             System.err.println("PRINT Error: "+e);
    198         }
    199     }
    200     protected static void println(String str) {
    201         try {
    202             System.out.println(str);
    203             out.println(str);
    204         } catch (Exception e) {
    205             System.err.println("PRINTLN Error: "+e);
    206         }
    207     }
    208     protected static void printTime(long time,boolean train) {
    209         try {
    210             if (train) {
    211                 println("	Training:	"+time+"ms elapsed");
    212             } else {
    213                 println("	Testing:	"+time+"ms elapsed");
    214             }
    215         } catch (Exception e) {
    216             System.err.println("PRINTTIME Error: "+e);
    217         }
    218     }
    219     protected static void stats(int[] res) {
    220         try {
    221             int truePos=0, falsePos=0, falseNeg=0, trueNeg=0;
    222             for (int i=0;i<test.l;i++) {
    223                 if (test.getY(i)>0) {
    224                     if (res[i]>0) {
    225                         truePos++;
    226                     } else {
    227                         falseNeg++;
    228                     }
    229                 } else {
    230                     if (res[i]>0) {
    231                         falsePos++;
    232                     } else {
    233                         trueNeg++;
    234                     }
    235                 }
    236             }
    237             System.out.println("	"+truePos+"	"+falsePos+"	"+falseNeg+"	"+trueNeg);
    238             double acc = (truePos+trueNeg+.0)/test.l;
    239             println("	acc	= "+acc);
    240             double p = (truePos+.0)/(truePos+falsePos);
    241             double r = (truePos+.0)/(truePos+falseNeg);
    242             double f1 = 2*r*p/(r+p);
    243             println("	F1	= "+f1);
    244             double tpr = (truePos+.0)/(truePos+falseNeg);
    245             double fpr = (falsePos+.0)/(falsePos+trueNeg);
    246             println("	TPR	= "+tpr);
    247             println("	FPR	= "+fpr);
    248             println("");
    249         } catch (Exception e) {
    250             System.err.println("STATS Error: "+e);
    251         }
    252     }
    253     protected static void preproc(String dir,String filename,String tail) throws IOException {
    254         Scanner fin = new Scanner(new FileInputStream(dir+filename));
    255         PrintWriter fout = new PrintWriter(new FileOutputStream(dir+"new_"+filename));
    256         int cnt=0, total=0;
    257         while (fin.hasNextLine()) {
    258             StringTokenizer tok = new StringTokenizer(fin.nextLine());
    259             if (tok.nextToken().charAt(0)=='A') {
    260                 fout.print("1 ");
    261                 cnt++;
    262             } else {
    263                 fout.print("-1 ");
    264             }
    265             while (tok.hasMoreTokens()) {
    266                 fout.print(tok.nextToken()+" ");
    267             }
    268             fout.println(tail);
    269             total++;
    270         }
    271         System.out.println(filename+":	"+cnt+"/"+total);
    272         fout.close();
    273         fin.close();
    274     }
    275 }
    276 
    277 
    278 public class DIY extends AbstractDIY {
    279     private static int NUM;
    280     private static int NUM_OF_THREAD;
    281     
    282     private static Semaphore waitTask;
    283     private static Semaphore nextTask;
    284     private static Semaphore timeMutex;
    285     private static Subprob probSentinel;
    286     private static Augur modSentinel;
    287     
    288     private static BlockingQueue<Subprob> probBuf;
    289     private static BlockingQueue<Augur> modBuf;
    290     private static BlockingQueue<Integer> minIdx;
    291     
    292     private static AtomicInteger[][] minResult;
    293     
    294     static {
    295         try {
    296             init();
    297             waitTask = new Semaphore(0);
    298             nextTask = new Semaphore(0);
    299             timeMutex = new Semaphore(1);
    300             probSentinel = new Subprob(-1);
    301             modSentinel = new Augur();
    302             probBuf = new LinkedBlockingQueue<Subprob>();
    303             modBuf = new LinkedBlockingQueue<Augur>();
    304             minIdx = new LinkedBlockingQueue<Integer>();
    305         } catch (Exception e) {
    306             System.err.println("STATIC Error: "+e);
    307         }
    308     }
    309     private static void separate(List<Integer> poslst,List<Integer> neglst,boolean prior) {
    310         // List the positive samples and negative samples:
    311         // Precondition: poslst and neglst are non-null empty lists
    312         // Postcondition: poslst and neglst are filled with indices of positive
    313         //            and negative training data respectively
    314         if (!prior) {
    315             for (int i=0;i<train.l;i++) {
    316                 if (train.getY(i)>0) {
    317                     poslst.add(new Integer(i));
    318                 } else {
    319                     neglst.add(new Integer(i));
    320                 }
    321             }
    322             Random rand = new Random();
    323             Collections.shuffle(poslst,rand);
    324             Collections.shuffle(neglst,rand);
    325         } else {
    326             try {
    327                 System.out.println("Gathering Prior Knowledge ... ");
    328                 BufferedReader pin = new BufferedReader(new FileReader("./data/train.txt"));
    329                 List<DataItem> posData = new ArrayList<DataItem>();
    330                 List<DataItem> negData = new ArrayList<DataItem>();
    331                 for (int i=0;i<train.l;i++) {
    332                     String line = pin.readLine();
    333                     if (train.getY(i)>0) {
    334                         posData.add(new DataItem(i,line.substring(0,3)));
    335                     } else {
    336                         negData.add(new DataItem(i,line.substring(0,3)));
    337                     }
    338                 }
    339                 pin.close();
    340                 //System.in.read();
    341                 Collections.sort(posData);
    342                 Collections.sort(negData);
    343                 for (int i=0;i<posData.size();i++) {
    344                     poslst.add(posData.get(i).getValue());
    345                 }
    346                 for (int i=0;i<negData.size();i++) {
    347                     neglst.add(negData.get(i).getValue());
    348                 }
    349                 System.out.println("Ready!");
    350             } catch (Exception e) {
    351                 System.err.println("SEPARATE Error: "+e);
    352             }
    353         }
    354     }
    355     private static void distribute(boolean prior) {
    356         List<Integer> poslst = new ArrayList<Integer>();
    357         List<Integer> neglst = new ArrayList<Integer>();
    358         separate(poslst,neglst,prior);
    359         // Group the positive samples and negative samples:
    360         int posGrpNum = (poslst.size()+NUM-1)/NUM;
    361         int negGrpNum = (neglst.size()+NUM-1)/NUM;
    362         int[] posGrps = new int[posGrpNum+1];
    363         int[] negGrps = new int[negGrpNum+1];
    364         System.out.println("	Totally "+posGrpNum*negGrpNum+" groups:");
    365         for (int i=0;i<posGrpNum;i++) {
    366             posGrps[i+1] = posGrps[i]+(poslst.size()+i)/posGrpNum;
    367         }
    368         for (int i=0;i<negGrpNum;i++) {
    369             negGrps[i+1] = negGrps[i]+(neglst.size()+i)/negGrpNum;
    370         }
    371         try {    // Add tasks to the buffer:
    372             for (int i=0;i<posGrpNum;i++) {
    373                 for (int j=0;j<negGrpNum;j++) {
    374                     Subprob sub = new Subprob(i);
    375                     for (int k=posGrps[i];k<posGrps[i+1];k++) {
    376                         sub.add(poslst.get(k));
    377                     }
    378                     for (int k=negGrps[j];k<negGrps[j+1];k++) {
    379                         sub.add(neglst.get(k));
    380                     }
    381                     probBuf.put(sub);
    382                 }
    383             }
    384         } catch (Exception e) {
    385             System.err.println("DISTRIBUTE Error 1: "+e);
    386         }
    387         try {    // Prepare for the MIN modules:
    388             minResult = new AtomicInteger[posGrpNum][test.l];
    389             for (int i=0;i<posGrpNum;i++) {
    390                 for (int j=0;j<test.l;j++) {
    391                     minResult[i][j] = new AtomicInteger(1);
    392                 }
    393             }
    394             probBuf.put(probSentinel);    // "STOP TRAINING" signal
    395         } catch (Exception e) {
    396             System.err.println("DISTRIBUTE Error 2: "+e);
    397         }
    398     }
    399     private static void main_thread(boolean prior) {
    400         try {
    401             start = System.currentTimeMillis();        // start training
    402             end = -1;                                // still pending
    403             DIY[] tsks = new DIY[NUM_OF_THREAD];
    404             for (int i=0;i<NUM_OF_THREAD;i++) {
    405                 tsks[i] = new DIY();
    406                 tsks[i].start();
    407             }
    408             distribute(prior);
    409             for (int i=0;i<NUM_OF_THREAD;i++) {
    410                 waitTask.acquire();                    // awaiting sub-threads
    411             }
    412             for (int i=0;i<NUM_OF_THREAD;i++) {
    413                 nextTask.release();                    // testing enabled
    414             }
    415             start = System.currentTimeMillis();        // start testing
    416             for (int i=0;i<NUM_OF_THREAD;i++) {
    417                 tsks[i].join();
    418             }
    419             int[] res = new int[test.l];
    420             Arrays.fill(res,-1);
    421             for (int i=0;i<minResult.length;i++) {
    422                 for (int j=0;j<test.l;j++) {
    423                     if (minResult[i][j].get()>0) {
    424                         res[j] = 1;                    // MAX Module
    425                     }
    426                 }
    427             }
    428             end = System.currentTimeMillis();        // finish testing
    429             printTime(end-start,false);
    430             stats(res);
    431         } catch (Exception e) {
    432             System.err.println("MAIN_THRAD Error: "+e);
    433         }
    434     }
    435     public static void main(String[] args) {
    436         try {
    437             out = new PrintWriter(new FileWriter("./result/task"+args[0]+"_result.out"));
    438             NUM_OF_THREAD = Integer.parseInt(args[1]);
    439             int numMin = Integer.parseInt(args[2]);
    440             int numMax = Integer.parseInt(args[3]);
    441             int numStep = Integer.parseInt(args[4]);
    442             for (NUM=numMin;NUM<=numMax;NUM+=numStep) {
    443                 print("Group Size = "+NUM+",	");
    444                 println(NUM_OF_THREAD+" Threads");
    445                 probBuf.clear();
    446                 modBuf.clear();
    447                 minIdx.clear();
    448                 if (args[0].equals("6-2")) {
    449                     main_thread(false);
    450                 } else {
    451                     main_thread(true);
    452                 }
    453             }
    454             out.close();
    455         } catch (Exception e) {
    456             System.err.println("MAIN Error: "+e);
    457         }
    458     }
    459     
    460     public void run() {
    461         try {
    462             train();
    463             waitTask.release();
    464             nextTask.acquire();
    465             test();
    466         } catch (Exception e) {
    467             System.err.println("RUN Error: "+e);
    468         }
    469     }
    470     private void train() {
    471         Subprob sub = null;
    472         try {
    473             while (true) {
    474                 sub = probBuf.take();
    475                 if (sub==probSentinel) {    // signal of termination
    476                     timeMutex.acquire();
    477                     if (end<0) {            // finish training
    478                         end = System.currentTimeMillis();
    479                         printTime(end-start,true);
    480                     }
    481                     timeMutex.release();
    482                     probBuf.put(sub);
    483                     break;
    484                 }
    485                 DataSet prob = new DataSet(sub.size(),train.n);
    486                 for (int i=0;i<sub.size();i++) {
    487                     int pos = sub.getItem(i).intValue();
    488                     prob.x.add(train.x.get(pos));
    489                     prob.y.add(train.y.get(pos));
    490                 }
    491                 modBuf.put(new Augur(prob));
    492                 minIdx.put(sub.getIndex());
    493             }
    494             modBuf.put(modSentinel);
    495         } catch (Exception e) {
    496             System.err.println("TRAIN Error: "+e);
    497         }
    498     }
    499     private void test() {
    500         Augur model = null;
    501         int idx = -1;
    502         try {
    503             while (true) {
    504                 model = modBuf.take();
    505                 if (model==modSentinel) {    // signal of termination
    506                     modBuf.put(model);
    507                     break;
    508                 }
    509                 idx = minIdx.poll().intValue();
    510                 for (int i=0;i<test.l;i++) {
    511                     if (model.predict(test.x.get(i))<0) {
    512                         minResult[idx][i].getAndSet(-1);    // MIN Modules
    513                     }
    514                 }
    515             }
    516         } catch (Exception e) {
    517             System.err.println("TEST Error:"+e);
    518         }
    519     }
    520 }

      The final results are quite reassuring in terms of F1 score and accuracy, whereas the running time leaves much to be desired:

  • 相关阅读:
    读完此文让你了解各个queue的原理
    借汇编之力窥探String背后的数据结构奥秘
    汇编高手带你玩转字符串,快上车!
    语雀调研
    产品技能一:抽象能力
    我所认知的敏捷开发
    产品经理需要的技能,我有吗?
    孙正义采访:接下来的30年,一切将被重新定义
    5G小白鼠
    goto语句为啥不受待见
  • 原文地址:https://www.cnblogs.com/DevinZ/p/4496014.html
Copyright © 2011-2022 走看看