All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.patterns.surface.PatternsForEachTokenDB Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.patterns.surface;

import edu.stanford.nlp.patterns.Pattern;
import edu.stanford.nlp.patterns.SQLConnection;
import edu.stanford.nlp.util.ArgumentParser;
import edu.stanford.nlp.util.logging.Redwood;

import java.io.*;
import java.sql.*;
import java.util.*;

/**
 * Created by sonalg on 10/22/14.
 */
public class PatternsForEachTokenDB extends PatternsForEachToken{


  @ArgumentParser.Option(name = "createTable")
  boolean createTable = false;

  @ArgumentParser.Option(name = "deleteExisting")
  boolean deleteExisting = false;

  @ArgumentParser.Option(name = "tableName")
  String tableName = null;

  @ArgumentParser.Option(name = "patternindicesTable")
  String patternindicesTable = "patternindices";

  @ArgumentParser.Option(name="deleteDBResourcesOnExit")
  boolean deleteDBResourcesOnExit = true;

  public PatternsForEachTokenDB(Properties props, Map>> pats){

    ArgumentParser.fillOptions(this, props);


      ArgumentParser.fillOptions(SQLConnection.class, props);

      assert tableName != null : "tableName property is null!";
      tableName = tableName.toLowerCase();
      if (createTable && !deleteExisting)
        throw new RuntimeException("Cannot have createTable as true and deleteExisting as false!");
      if (createTable){
        createTable();
        createUpsertFunction();
      }else{
        assert DBTableExists() : "Table " + tableName + " does not exists. Pass createTable=true to create a new table";
      }


    if(pats != null)
      addPatterns(pats);
  }

  public PatternsForEachTokenDB(Properties props) {
    this(props, null);
  }

  void createTable() {
    String query ="";
    try {
      Connection conn = SQLConnection.getConnection();
      if(DBTableExists()){
        if (deleteExisting) {
          System.out.println("deleting table " + tableName);
          Statement stmt = conn.createStatement();
          query = "drop table " + tableName;
          stmt.execute(query);
          stmt.close();
          Statement stmtindex = conn.createStatement();
          query = "DROP INDEX IF EXISTS " + tableName+"_index";
          stmtindex.execute(query);
          stmtindex.close();
        }
      }
      System.out.println("creating table " + tableName);
      Statement stmt = conn.createStatement();
      //query = "create table  IF NOT EXISTS " + tableName + " (\"sentid\" text, \"tokenid\" int, \"patterns\" bytea); ";
      query = "create table IF NOT EXISTS " + tableName + " (sentid text, patterns bytea); ";
      stmt.execute(query);
      stmt.close();
      conn.close();} catch (SQLException e) {
      throw new RuntimeException("Error executing query " + query + "\n" + e);
    }
  }

  @Override
  public void addPatterns(Map>> pats){
    try {
      Connection conn = null;
      PreparedStatement pstmt = null;


      conn = SQLConnection.getConnection();
      pstmt = getPreparedStmt(conn);


      for (Map.Entry>> en : pats.entrySet()) {
        addPattern(en.getKey(), en.getValue(), pstmt);

        pstmt.addBatch();
      }


      pstmt.executeBatch();
      conn.commit();
      pstmt.close();
      conn.close();
    }catch(SQLException e){
      throw new RuntimeException(e);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }


  public void addPatterns(String id, Map> p){
    try {
    PreparedStatement pstmt = null;
    Connection conn= null;


      conn = SQLConnection.getConnection();
      pstmt = getPreparedStmt(conn);

    addPattern(id, p, pstmt);


      pstmt.execute();
      conn.commit();

    pstmt.close();
      conn.close();
    } catch (SQLException e) {
      throw new RuntimeException(e);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /*
  public void addPatterns(String id, Map> p, PreparedStatement pstmt) throws IOException, SQLException {
    for (Map.Entry> en2 : p.entrySet()) {
      addPattern(id, en2.getKey(), en2.getValue(), pstmt);
      if(useDBForTokenPatterns)
        pstmt.addBatch();
    }
  }
*/

/*
  public void addPatterns(String sentId, int tokenId, Set patterns) throws SQLException, IOException{
    PreparedStatement pstmt = null;
    Connection conn= null;
    if(useDBForTokenPatterns) {
      conn = SQLConnection.getConnection();
      pstmt = getPreparedStmt(conn);
    }

    addPattern(sentId, tokenId, patterns, pstmt);

    if(useDBForTokenPatterns){
      pstmt.execute();
      conn.commit();
      pstmt.close();
      conn.close();
    }
  }
  */

  /*
  private void addPattern(String sentId, int tokenId, Set patterns, PreparedStatement pstmt) throws SQLException, IOException {

    if(pstmt != null){
//      ByteArrayOutputStream baos = new ByteArrayOutputStream();
//      ObjectOutputStream oos = new ObjectOutputStream(baos);
//      oos.writeObject(patterns);
//      byte[] patsAsBytes = baos.toByteArray();
//      ByteArrayInputStream bais = new ByteArrayInputStream(patsAsBytes);
//      pstmt.setBinaryStream(1, bais, patsAsBytes.length);
//      pstmt.setObject(2, sentId);
//      pstmt.setInt(3, tokenId);
//      pstmt.setString(4,sentId);
//      pstmt.setInt(5, tokenId);
//      ByteArrayOutputStream baos2 = new ByteArrayOutputStream();
//      ObjectOutputStream oos2 = new ObjectOutputStream(baos2);
//      oos2.writeObject(patterns);
//      byte[] patsAsBytes2 = baos2.toByteArray();
//      ByteArrayInputStream bais2 = new ByteArrayInputStream(patsAsBytes2);
//      pstmt.setBinaryStream(6, bais2, patsAsBytes2.length);
//      pstmt.setString(7,sentId);
//      pstmt.setInt(8, tokenId);

      ByteArrayOutputStream baos = new ByteArrayOutputStream();
      ObjectOutputStream oos = new ObjectOutputStream(baos);
      oos.writeObject(patterns);
      byte[] patsAsBytes = baos.toByteArray();
      ByteArrayInputStream bais = new ByteArrayInputStream(patsAsBytes);
      pstmt.setBinaryStream(3, bais, patsAsBytes.length);
      pstmt.setObject(1, sentId);
      pstmt.setInt(2, tokenId);


    } else{
      if(!patternsForEachToken.containsKey(sentId))
        patternsForEachToken.put(sentId, new ConcurrentHashMap>());
      patternsForEachToken.get(sentId).put(tokenId, patterns);
    }
  }*/


  private void addPattern(String sentId, Map> patterns, PreparedStatement pstmt) throws SQLException, IOException {

    if(pstmt != null){
      ByteArrayOutputStream baos = new ByteArrayOutputStream();
      ObjectOutputStream oos = new ObjectOutputStream(baos);
      oos.writeObject(patterns);
      byte[] patsAsBytes = baos.toByteArray();
      ByteArrayInputStream bais = new ByteArrayInputStream(patsAsBytes);
      pstmt.setBinaryStream(2, bais, patsAsBytes.length);
      pstmt.setObject(1, sentId);
      //pstmt.setInt(2, tokenId);


    }
  }


  public void createUpsertFunction() {
    try{
    Connection conn = SQLConnection.getConnection();
    String s = "CREATE OR REPLACE FUNCTION upsert_patterns(sentid1 text, pats1 bytea) RETURNS VOID AS $$\n" +
      "DECLARE\n" +
      "BEGIN\n" +
      "    UPDATE " + tableName+ " SET patterns = pats1 WHERE sentid = sentid1;\n" +
      "    IF NOT FOUND THEN\n" +
      "    INSERT INTO " + tableName + "  values (sentid1, pats1);\n" +
      "    END IF;\n" +
      "END;\n" +
      "$$ LANGUAGE 'plpgsql';\n";
    Statement st = conn.createStatement();
    st.execute(s);
    conn.close();}catch(SQLException e){
      throw new RuntimeException(e);
    }
  }

  public void createUpsertFunctionPatternIndex() throws SQLException {
    Connection conn = SQLConnection.getConnection();
    String s = "CREATE OR REPLACE FUNCTION upsert_patternindex(tablename1 text, index1 bytea) RETURNS VOID AS $$\n" +
      "DECLARE\n" +
      "BEGIN\n" +
      "    UPDATE " + patternindicesTable + " SET index = index1 WHERE  tablename = tablename1;\n" +
      "    IF NOT FOUND THEN\n" +
      "    INSERT INTO " + patternindicesTable + "  values (tablename1, index1);\n" +
      "    END IF;\n" +
      "END;\n" +
      "$$ LANGUAGE 'plpgsql';\n";
    Statement st = conn.createStatement();
    st.execute(s);
    conn.close();
  }






  private PreparedStatement getPreparedStmt(Connection conn) throws SQLException {
    conn.setAutoCommit(false);
    //return conn.prepareStatement("UPDATE " + tableName + " SET patterns = ? WHERE sentid = ? and tokenid = ?; " +
    //  "INSERT INTO " + tableName + " (sentid, tokenid, patterns) (SELECT ?,?,? WHERE NOT EXISTS (SELECT sentid FROM " + tableName + " WHERE sentid  =? and tokenid=?));");
    //  return conn.prepareStatement("INSERT INTO " + tableName + " (sentid, tokenid, patterns) (SELECT ?,?,? WHERE NOT EXISTS (SELECT sentid FROM " + tableName + " WHERE sentid  =? and tokenid=?))");
    return conn.prepareStatement("select upsert_patterns(?,?)");
  }




/*
  public Set getPatterns(String sentId, Integer tokenId) throws SQLException, IOException, ClassNotFoundException {
    if(useDBForTokenPatterns){
      Connection conn = SQLConnection.getConnection();

      String query = "Select patterns from " + tableName + " where sentid=\'" + sentId + "\' and tokenid = " + tokenId;
      Statement stmt = conn.createStatement();
      ResultSet rs = stmt.executeQuery(query);
      Set pats = null;
      if(rs.next()){
        byte[] st = (byte[]) rs.getObject(1);
        ByteArrayInputStream baip = new ByteArrayInputStream(st);
        ObjectInputStream ois = new ObjectInputStream(baip);
        pats = (Set) ois.readObject();

      }
      conn.close();
      return pats;
    }
    else
      return patternsForEachToken.get(sentId).get(tokenId);
  }*/



  @Override
  public Map> getPatternsForAllTokens(String sentId){
  try{
      Connection conn = SQLConnection.getConnection();
      //Map> pats = new ConcurrentHashMap>();
      String query = "Select patterns from " + tableName + " where sentid=\'" + sentId + "\'";
      Statement stmt = conn.createStatement();
      ResultSet rs = stmt.executeQuery(query);
      Map> patsToken = new HashMap<>();
      if(rs.next()){
        byte[] st = (byte[]) rs.getObject(1);
        ByteArrayInputStream baip = new ByteArrayInputStream(st);
        ObjectInputStream ois = new ObjectInputStream(baip);
        patsToken = (Map>) ois.readObject();
        //pats.put(rs.getInt("tokenid"), patsToken);
      }
      conn.close();
      return patsToken;
  }catch(SQLException e){
    throw new RuntimeException(e);
  } catch (ClassNotFoundException e) {
    throw new RuntimeException(e);
  } catch (IOException e) {
    throw new RuntimeException(e);
  }
  }

  @Override
  public boolean save(String dir) {
    //nothing to do
    return false;
  }

  @Override
  public void setupSearch() {
    //nothing to do
  }

  public boolean containsSentId(String sentId){
      try {
        Connection conn = SQLConnection.getConnection();
        String query = "Select tokenid from " + tableName + " where sentid=\'" + sentId + "\' limit 1";
        Statement stmt = conn.createStatement();
        ResultSet rs = stmt.executeQuery(query);

        boolean contains = false;

        while (rs.next()) {
          contains = true;
          break;
        }

        conn.close();
        return contains;
      }catch(SQLException e){
        throw new RuntimeException(e);
      }
  }

  @Override
  public void createIndexIfUsingDBAndNotExists(){
      try {
        Redwood.log(Redwood.DBG, "Creating index for " + tableName);
        Connection conn = SQLConnection.getConnection();
        Statement stmt = conn.createStatement();
        boolean doesnotexist = false;

        //check if the index already exists
        try{
          Statement stmt2 = conn.createStatement();
          String query = "SELECT '"+tableName+"_index'::regclass";
          stmt2.execute(query);
        }catch (SQLException e){
          doesnotexist = true;
        }

        if(doesnotexist){
          String indexquery ="create index CONCURRENTLY " + tableName +"_index on " + tableName+ " using hash(\"sentid\") ";
          stmt.execute(indexquery);
          Redwood.log(Redwood.DBG, "Done creating index for " + tableName);
        }
      } catch (SQLException e) {
        throw new RuntimeException(e);
      }

  }

//  /**
//   * not yet supported if backed by DB
//   * @return
//   */
//  public Set>>> entrySet() {
//    if(!useDBForTokenPatterns)
//      return patternsForEachToken.entrySet();
//    else
//      //not yet supported if backed by DB
//      throw new UnsupportedOperationException();
//  }

  public boolean DBTableExists() {
    try {
      Connection conn = null;

      conn = SQLConnection.getConnection();

      DatabaseMetaData dbm = conn.getMetaData();
      ResultSet tables = dbm.getTables(null, null, tableName, null);
      if (tables.next()) {
        System.out.println("Found table " + tableName);
        conn.close();
        return true;
      }
      conn.close();
      return false;
    }catch(SQLException e){
      throw new RuntimeException(e);

    }
  }
//
//  @Override
//  public ConcurrentHashIndex readPatternIndex(String dir){
//    //dir parameter is not used!
//    try{
//      Connection conn = SQLConnection.getConnection();
//      //Map> pats = new ConcurrentHashMap>();
//      String query = "Select index from " + patternindicesTable + " where tablename=\'" + tableName + "\'";
//      Statement stmt = conn.createStatement();
//      ResultSet rs = stmt.executeQuery(query);
//      ConcurrentHashIndex index = null;
//      if(rs.next()){
//        byte[] st = (byte[]) rs.getObject(1);
//        ByteArrayInputStream baip = new ByteArrayInputStream(st);
//        ObjectInputStream ois = new ObjectInputStream(baip);
//        index  = (ConcurrentHashIndex) ois.readObject();
//      }
//      assert index != null;
//      return index;
//    }catch(SQLException e){
//      throw new RuntimeException(e);
//    } catch (ClassNotFoundException e) {
//      throw new RuntimeException(e);
//    } catch (IOException e) {
//      throw new RuntimeException(e);
//    }
//  }
//
//  @Override
//  public void savePatternIndex(ConcurrentHashIndex index, String file) {
//    try {
//      createUpsertFunctionPatternIndex();
//      Connection conn = SQLConnection.getConnection();
//      PreparedStatement  st = conn.prepareStatement("select upsert_patternindex(?,?)");
//      st.setString(1,tableName);
//      ByteArrayOutputStream baos = new ByteArrayOutputStream();
//      ObjectOutputStream oos = new ObjectOutputStream(baos);
//      oos.writeObject(index);
//      byte[] patsAsBytes = baos.toByteArray();
//      ByteArrayInputStream bais = new ByteArrayInputStream(patsAsBytes);
//      st.setBinaryStream(2, bais, patsAsBytes.length);
//      st.execute();
//      st.close();
//      conn.close();
//      System.out.println("Saved the pattern hash index for " + tableName + " in DB table " + patternindicesTable);
//    }catch (SQLException e){
//      throw new RuntimeException(e);
//    } catch (IOException e) {
//      throw new RuntimeException(e);
//    }
//  }

  //batch processing below is copied from Java Ranch
  public static final int SINGLE_BATCH = 1;
  public static final int SMALL_BATCH = 4;
  public static final int MEDIUM_BATCH = 11;
  public static final int LARGE_BATCH = 51;

  //TODO: make this into an iterator!!
  @Override
  public Map>> getPatternsForAllTokens(Collection sampledSentIds) {
    try{
      Map>> pats = new HashMap<>();
      Connection conn = SQLConnection.getConnection();
      Iterator iter = sampledSentIds.iterator();
      int totalNumberOfValuesLeftToBatch = sampledSentIds.size();
      while ( totalNumberOfValuesLeftToBatch > 0 ) {

        int batchSize = SINGLE_BATCH;
        if (totalNumberOfValuesLeftToBatch >= LARGE_BATCH) {
          batchSize = LARGE_BATCH;
        } else if (totalNumberOfValuesLeftToBatch >= MEDIUM_BATCH) {
          batchSize = MEDIUM_BATCH;
        } else if (totalNumberOfValuesLeftToBatch >= SMALL_BATCH) {
          batchSize = SMALL_BATCH;
        }
        totalNumberOfValuesLeftToBatch -= batchSize;


        StringBuilder inClause = new StringBuilder();

        for (int i = 0; i < batchSize; i++) {
          inClause.append('?');
          if (i != batchSize - 1) {
            inClause.append(',');
          }
        }
        PreparedStatement stmt = conn.prepareStatement(
          "select sentid, patterns from " + tableName + " where sentid in (" + inClause.toString() + ")");
        for (int i=0; i < batchSize && iter.hasNext(); i++) {
          stmt.setString(i+1, iter.next());  // or whatever values you are trying to query by
        }
        stmt.execute();
        ResultSet rs = stmt.getResultSet();
        while(rs.next()){
          String sentid = rs.getString(1);
          byte[] st = (byte[]) rs.getObject(2);
          ByteArrayInputStream baip = new ByteArrayInputStream(st);
          ObjectInputStream ois = new ObjectInputStream(baip);
          pats.put(sentid, (Map>) ois.readObject());
        }

      }
      conn.close();
      return pats;
    }catch(SQLException e){
      throw new RuntimeException(e);
    } catch (ClassNotFoundException e) {
      throw new RuntimeException(e);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  @Override
  public void close() {
    //nothing to do
  }

  @Override
  public void load(String allPatternsDir) {
    //nothing to do
  }

  @Override
  public int size(){
    //TODO: NOT IMPLEMENTED
    return Integer.MAX_VALUE;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy