All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.python.ServerUtils Maven / Gradle / Ivy

Go to download

Integration with CPython for Weka. Python version 2.7.x or higher is required. Also requires the following packages to be installed in python: numpy, pandas, matplotlib and scikit-learn. This package provides a wrapper classifier and clusterer that, between them, cover 60+ scikit-learn algorithms. It also provides a general scripting step for the Knowlege Flow along with scripting plugin environments for the Explorer and Knowledge Flow.

The newest version!
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see .
 */

/*
 *    ServerUtils.java
 *    Copyright (C) 2015 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.python;

import java.awt.image.BufferedImage;
import java.io.BufferedInputStream;
import java.io.BufferedWriter;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.StringReader;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.imageio.ImageIO;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.module.paramnames.ParameterNamesModule;
import org.apache.commons.codec.binary.Base64;

import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.WekaException;
import weka.core.converters.CSVSaver;
import weka.gui.Logger;

import static com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility.ANY;
import static com.fasterxml.jackson.annotation.PropertyAccessor.FIELD;

/**
 * Contains routines for getting data in and out of python.
 *
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: $
 */
public class ServerUtils {

  public static final ObjectMapper MAPPER = new ObjectMapper() {
    {
      this.registerModule(new ParameterNamesModule());
      this.setVisibility(FIELD, ANY);
      this.setSerializationInclusion(JsonInclude.Include.NON_NULL);
      this.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
      this.configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, false);
      this.configure(SerializationFeature.WRITE_DURATIONS_AS_TIMESTAMPS, false);
    }
  };

  /**
   * Create a simple header definition to transfer as json to the server
   *
   * @param header the Instances header to convert
   * @param frameName the name of the pandas data frame that this header will
   *          refer to
   * @return A map with key values that define the header
   */
  protected static Map createSimpleHeader(Instances header,
    String frameName) {
    Map result = new HashMap();
    result.put("relation_name", header.relationName());
    result.put("frame_name", frameName);
    result.put("class_index", header.classIndex());
    if (header.classIndex() >= 0) {
      result.put("class_name", header.classAttribute().name());
      if (header.classAttribute().isNominal()) {
        List classVals = new ArrayList();
        for (int i = 0; i < header.classAttribute().numValues(); i++) {
          classVals.add(header.classAttribute().value(i));
        }
        result.put("class_values", classVals);
      }
    }

    if (header.checkForAttributeType(Attribute.DATE)) {
      List dateAtts = new ArrayList();
      for (int i = 0; i < header.numAttributes(); i++) {
        if (header.attribute(i).isDate()) {
          dateAtts.add(i);
        }
      }
      result.put("date_atts", dateAtts);
    }

    return result;
  }

  /**
   * Converts a header definition read and decoded from json into a structure
   * only set of Instances
   *
   * @param header the map containing the definition of the header
   * @return an Instances object
   */
  @SuppressWarnings("unchecked")
  protected static Instances jsonToInstancesHeader(Map header) {
    String relationName = header.get("relation_name").toString();
    List> attributes =
      (List>) header.get("attributes");
    if (attributes == null) {
      throw new IllegalStateException("No attributes in header map!");
    }

    ArrayList atts = new ArrayList();
    for (Map a : attributes) {
      String attName = a.get("name").toString();
      String type = a.get("type").toString();
      String format = null;
      List values = null;
      if (a.get("format") != null) {
        format = a.get("format").toString();
      }
      if (a.get("values") != null) {
        values = (List) a.get("values");
      }

      if (type.equals("NUMERIC")) {
        atts.add(new Attribute(attName));
      } else if (type.equals("DATE")) {
        atts.add(new Attribute(attName, format));
      } else if (type.equals("NOMINAL")) {
        atts.add(new Attribute(attName, values));
      } else if (type.equals("STRING")) {
        atts.add(new Attribute(attName, (List) null));
      }
    }

    return new Instances(relationName, atts, 0);
  }

  /**
   * Send a shutdown command to the micro server
   *
   * @param outputStream the output stream to write the command to
   * @throws WekaException if a problem occurs
   */
  protected static void sendServerShutdown(OutputStream outputStream)
    throws WekaException {
    Map command = new HashMap();
    command.put("command", "shutdown");
    try {
      byte[] bytes = MAPPER.writeValueAsBytes(command);
      // write the command
      writeDelimitedToOutputStream(bytes, outputStream);
    } catch (IOException ex) {
      throw new WekaException(ex);
    }
  }

  /**
   * Execute a script on the server
   * 
   * @param script the script to execute
   * @param outputStream the output stream to write data to the server
   * @param inputStream the input stream to read responses from
   * @param log optional log to write to
   * @param debug true to output debugging info from both java and the server
   * @return a two element list that contains the sys out and sys error from the
   *         script execution
   * @throws WekaException if a problem occurs
   */
  @SuppressWarnings("unchecked")
  protected static List executeUserScript(String script,
    OutputStream outputStream, InputStream inputStream, Logger log,
    boolean debug) throws WekaException {
    if (!script.endsWith("\n")) {
      script += "\n";
    }
    List outAndErr = new ArrayList();

    Map command = new HashMap();
    command.put("command", "execute_script");
    command.put("script", script);
    command.put("debug", debug);
    if (inputStream != null && outputStream != null) {
      try {
        byte[] bytes = MAPPER.writeValueAsBytes(command);
        if (debug) {
          outputCommandDebug(command, log);
        }
        writeDelimitedToOutputStream(bytes, outputStream);

        // get the result of execution
        bytes = readDelimitedFromInputStream(inputStream);
        Map ack =
          MAPPER.readValue(bytes, new TypeReference>() {
          });
        // mapper.readValue(bytes, Map.class);
        if (!ack.get("response").toString().equals("ok")) {
          // fatal error
          throw new WekaException(ack.get("error_message").toString());
        }
        // get the script out and err
        outAndErr.add(ack.get("script_out").toString());
        outAndErr.add(ack.get("script_error").toString());
        if (debug) {
          if (log != null) {
            log.logMessage("Script output:\n" + outAndErr.get(0));
            log.logMessage("\nScript error:\n" + outAndErr.get(1));
          } else {
            System.err.println("Script output:\n" + outAndErr.get(0));
            System.err.println("\nScript error:\n" + outAndErr.get(1));
          }
        }

        if (outAndErr.get(1).contains("Warning:")) {
          // clear warnings - we really just want to know if there
          // are major errors
          outAndErr.set(1, "");
        }
      } catch (IOException ex) {
        throw new WekaException(ex);
      }
    } else if (debug) {
      outputCommandDebug(command, log);
    }

    return outAndErr;
  }

  /**
   * Sends instances to a pandas dataframe in python. Assumes data has been
   * binarized (as scikit-learn algorithms take only numeric data) and have had
   * missing values replaced. Creates up to two numpy arrays in python called X
   * and Y: input columns and target (if class is set) respectively
   *
   * @param instances the instances to transfer
   * @param frameName the name of the pandas dataframe
   * @param outputStream the output stream to write to
   * @param inputStream the input stream to listen for server response on
   * @param log the log (if any) to use
   * @param debug true if debugging info is to be output
   * @throws WekaException if a problem occurs
   */
  protected static void sendInstancesScikitLearn(Instances instances,
    String frameName, OutputStream outputStream, InputStream inputStream,
    Logger log, boolean debug) throws WekaException {
    // iris.iloc[:,[0,2,4]] (slice columns to array)
    // pd.get_dummies(iris) (binarize/one hot)

    // Assumes that data has had nominals (except the class) converted
    // to binary indicators and all missing values replaced
    // ObjectMapper mapper = JsonFactory.create();
    Map simpleHeader = createSimpleHeader(instances, frameName);
    Map command = new HashMap();
    command.put("command", "put_instances");
    command.put("num_instances", instances.numInstances());
    command.put("header", simpleHeader);
    command.put("debug", debug);
    if (inputStream != null && outputStream != null) {
      try {
        byte[] bytes = MAPPER.writeValueAsBytes(command);
        if (debug) {
          outputCommandDebug(command, log);
        }
        // write the command
        writeDelimitedToOutputStream(bytes, outputStream);
        if (instances.numInstances() > 0) {
          StringBuilder builder = new StringBuilder();

          // header row
          for (int i = 0; i < instances.numAttributes(); i++) {
            builder.append(Utils.quote(instances.attribute(i).name()));
            if (i < instances.numAttributes() - 1) {
              builder.append(",");
            } else {
              builder.append("\n");
            }
          }

          // instances
          for (int i = 0; i < instances.numInstances(); i++) {
            Instance current = instances.instance(i);
            for (int j = 0; j < instances.numAttributes(); j++) {
              builder.append(current.value(j));
              if (j < instances.numAttributes() - 1) {
                builder.append(",");
              } else {
                builder.append("\n");
              }
            }
          }
          if (debug) {
            System.err.println(builder.toString());
          }
          ByteArrayOutputStream bos = new ByteArrayOutputStream();
          BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(bos));
          bw.write(builder.toString());
          bw.flush();
          bytes = bos.toByteArray();
          writeDelimitedToOutputStream(bytes, outputStream);

          String serverAck = receiveServerAck(inputStream);
          if (serverAck != null) {
            throw new WekaException("Transfer of instances failed: "
              + serverAck);
          }

          // execute script to create X (and Y) arrays
          int classIndex = instances.classIndex();
          builder = new StringBuilder();
          for (int i = 0; i < instances.numAttributes(); i++) {
            if (i != classIndex) {
              builder.append(i).append(",");
            }
          }
          String xList = builder.substring(0, builder.length() - 1);
          builder = new StringBuilder();
          builder.append("X = " + frameName + ".iloc[:,[" + xList
            + "]].values\n");
          if (classIndex >= 0) {
            builder.append("Y = " + frameName + ".iloc[:,[" + classIndex
              + "]].values\n");
          }

          if (debug) {
            if (log != null) {
              log.logMessage("Executing python script:\n\n"
                + builder.toString());
            } else {
              System.err.println("Executing python script:\n\n"
                + builder.toString());
            }
          }

          List outErr =
            executeUserScript(builder.toString(), outputStream, inputStream,
              log, debug);

          if (outErr.size() == 2 && outErr.get(1).length() > 0) {
            throw new WekaException(outErr.get(1));
          }
        }
      } catch (IOException e) {
        throw new WekaException(e);
      }
    } else if (debug) {
      outputCommandDebug(command, log);
    }
  }

  /**
   * Sends instances to a pandas dataframe in python.
   *
   * @param instances the instances to transfer
   * @param frameName the name of the data frame to create in python
   * @param outputStream the output stream to write to the server
   * @param inputStream the input stream to get server responses from
   * @param log optional log
   * @param debug true if debugging info is to be output
   * @throws WekaException if a problem occurs
   */
  protected static void sendInstances(Instances instances, String frameName,
    OutputStream outputStream, InputStream inputStream, Logger log,
    boolean debug) throws WekaException {

    Map simpleHeader = createSimpleHeader(instances, frameName);
    Map command = new HashMap();
    command.put("command", "put_instances");
    command.put("num_instances", instances.numInstances());
    command.put("header", simpleHeader);
    command.put("debug", debug);

    if (instances.checkForAttributeType(Attribute.DATE)) {
      // ensure a single, consistent date format
      ArrayList newAtts = new ArrayList();
      for (int i = 0; i < instances.numAttributes(); i++) {
        if (!instances.attribute(i).isDate()) {
          newAtts.add((Attribute) instances.attribute(i).copy());
        } else {
          Attribute newDate =
            new Attribute(instances.attribute(i).name(), "yyyy-MM-dd HH:mm:ss");
          newAtts.add(newDate);
        }
      }
      Instances newInsts =
        new Instances(instances.relationName(), newAtts,
          instances.numInstances());
      for (int i = 0; i < instances.numInstances(); i++) {
        newInsts.add(instances.instance(i));
      }
      newInsts.setClassIndex(instances.classIndex());
      instances = newInsts;
    }

    if (inputStream != null && outputStream != null) {
      try {
        byte[] bytes = MAPPER.writeValueAsBytes(command);
        if (debug) {
          outputCommandDebug(command, log);
        }

        // write the command
        writeDelimitedToOutputStream(bytes, outputStream);

        // write instances as CSV
        if (instances.numInstances() > 0) {
          CSVSaver saver = new CSVSaver();
          saver.setInstances(instances);
          ByteArrayOutputStream bos = new ByteArrayOutputStream();
          saver.setDestination(bos);
          saver.writeBatch();
          bytes = bos.toByteArray();
          writeDelimitedToOutputStream(bytes, outputStream);
        }

        String serverAck = receiveServerAck(inputStream);
        if (serverAck != null) {
          throw new WekaException("Transfer of instances failed: " + serverAck);
        }
      } catch (IOException e) {
        throw new WekaException(e);
      }
    } else if (debug) {
      outputCommandDebug(command, log);
    }
  }

  /**
   * Recieve a simple ack from the server. Returns a non-null string if the ack
   * received contains an error message
   *
   * @param inputStream the input stream to read the ack from
   * @return a non-null string if there was an error returned by the server
   * @throws IOException if a problem occurs
   */
  @SuppressWarnings("unchecked")
  protected static String receiveServerAck(InputStream inputStream)
    throws IOException {
    byte[] bytes = readDelimitedFromInputStream(inputStream);
    Map ack =
      MAPPER.readValue(bytes, new TypeReference>() {
      });

    String response = ack.get("response").toString();
    if (response.equals("ok")) {
      return null;
    }

    return ack.get("error_message").toString();
  }

  /**
   * Receives a PID ack from the server
   *
   * @param inputStream the input stream to read from
   * @return the process ID of the server
   * @throws IOException if a problem occurs
   */
  @SuppressWarnings("unchecked")
  protected static int receiveServerPIDAck(InputStream inputStream)
    throws IOException {
    byte[] bytes = readDelimitedFromInputStream(inputStream);
    Map ack =
      MAPPER.readValue(bytes, new TypeReference>() {
      });

    String response = ack.get("response").toString();
    if (response.equals("pid_response")) {
      return (Integer) ack.get("pid");
    } else {
      throw new IOException("Server did not send a pid_response");
    }
  }

  /**
   * Receives the current list of variables from Python
   *
   * @param outputStream the output stream to
   * @param inputStream the input stream to read from
   * @param log the log to use
   * @param debug true if debugging info is to be output
   * @return a list of variables set in python. Each entry in the list is a two
   *         element array, where the first element holds the variable name and
   *         the second the python type
   * @throws WekaException if a problem occurs
   */
  @SuppressWarnings("unchecked")
  protected static List receiveVariableList(
    OutputStream outputStream, InputStream inputStream, Logger log,
    boolean debug) throws WekaException {

    List results = new ArrayList();
    Map command = new HashMap();
    command.put("command", "get_variable_list");
    command.put("debug", debug);
    if (inputStream != null && outputStream != null) {
      try {
        byte[] bytes = MAPPER.writeValueAsBytes(command);
        if (debug) {
          outputCommandDebug(command, log);
        }

        // write the command
        writeDelimitedToOutputStream(bytes, outputStream);

        bytes = readDelimitedFromInputStream(inputStream);
        Map ack =
          MAPPER.readValue(bytes, new TypeReference>() {
          });
        if (!ack.get("response").toString().equals("ok")) {
          // fatal error
          throw new WekaException(ack.get("error_message").toString());
        }

        Object l = ack.get("variable_list");
        if (!(l instanceof List)) {
          throw new WekaException(
            "Was expecting the variable list to be a List " + "object!");
        }
        List> vList = (List>) l;
        for (Map v : vList) {
          String[] vEntry = new String[2];
          vEntry[0] = v.get("name");
          vEntry[1] = v.get("type");
          results.add(vEntry);
        }
      } catch (IOException ex) {
        throw new WekaException(ex);
      }
    } else if (debug) {
      outputCommandDebug(command, log);
    }

    return results;
  }

  /**
   * Receive the value of a variable in python in json form
   *
   * @param varName the name of the variable to get from python
   * @param outputStream the output stream to write to
   * @param inputStream the input stream to get server responses from
   * @param log optional log
   * @param debug true if debugging info is to be output
   * @return the value of the variable in json form
   * @throws WekaException if a problem occurs
   */
  @SuppressWarnings("unchecked")
  protected static Object receiveJsonVariableValue(String varName,
    OutputStream outputStream, InputStream inputStream, Logger log,
    boolean debug) throws WekaException {

    Object variableValue = "";
    Map command = new HashMap();
    command.put("command", "get_variable_value");
    command.put("variable_name", varName);
    command.put("variable_encoding", "json");
    command.put("debug", debug);
    if (inputStream != null && outputStream != null) {
      try {
        byte[] bytes = MAPPER.writeValueAsBytes(command);
        if (debug) {
          outputCommandDebug(command, log);
        }

        // write the command
        writeDelimitedToOutputStream(bytes, outputStream);

        bytes = readDelimitedFromInputStream(inputStream);
        Map ack =
          MAPPER.readValue(bytes, new TypeReference>() {
          });
        if (!ack.get("response").toString().equals("ok")) {
          // fatal error
          throw new WekaException(ack.get("error_message").toString());
        }
        if (!ack.get("variable_name").toString().equals(varName)) {
          throw new WekaException("Server sent back a value for a different "
            + "variable!");
        }
        if (!ack.get("variable_encoding").toString().equals("json")) {
          throw new WekaException("Encoding of variable value received from "
            + "server is not Json!");
        }
        variableValue = ack.get("variable_value");
      } catch (IOException ex) {
        throw new WekaException(ex);
      }
    } else if (debug) {
      outputCommandDebug(command, log);
    }

    return variableValue;
  }

  /**
   * Receive the value of a variable in pickled or plain string form. If getting
   * a pickled variable, then in python 2 this is the pickled string; in python
   * 3 pickle.dumps returns a byte object, so the value is converted to base64
   * before leaving the server.
   *
   * @param varName the name of the variable to get from the server
   * @param outputStream the output stream to write to
   * @param inputStream the input stream to get server responses from
   * @param plainString true if the plain string form of the variable is to be
   *          returned; otherwise the variable value is pickled (and further
   *          encoded to base64 in the case of python 3)
   * @param log optional log
   * @param debug true if debugging info is to be output
   * @return the variable value
   * @throws WekaException if a problem occurs
   */
  @SuppressWarnings("unchecked")
  protected static String receivePickledVariableValue(String varName,
    OutputStream outputStream, InputStream inputStream, boolean plainString,
    Logger log, boolean debug) throws WekaException {

    String objectValue = "";
    Map command = new HashMap();
    command.put("command", "get_variable_value");
    command.put("variable_name", varName);
    command.put("variable_encoding", plainString ? "string" : "pickled");
    command.put("debug", debug);

    if (inputStream != null && outputStream != null) {
      try {
        byte[] bytes = MAPPER.writeValueAsBytes(command);
        if (debug) {
          outputCommandDebug(command, log);
        }

        // write the command
        writeDelimitedToOutputStream(bytes, outputStream);

        bytes = readDelimitedFromInputStream(inputStream);
        Map ack =
          MAPPER.readValue(bytes, new TypeReference>() {
          });
        if (!ack.get("response").toString().equals("ok")) {
          // fatal error
          throw new WekaException(ack.get("error_message").toString());
        }
        if (!ack.get("variable_name").toString().equals(varName)) {
          throw new WekaException("Server sent back a value for a different "
            + "variable!");
        }
        objectValue = ack.get("variable_value").toString();
      } catch (IOException ex) {
        throw new WekaException(ex);
      }
    } else if (debug) {
      outputCommandDebug(command, log);
    }

    return objectValue;
  }

  /**
   * Std out and err are redirected to StringIO objects in the server. This
   * method retrieves the values of those buffers.
   *
   * @param outputStream the output stream to talk to the server on
   * @param inputStream the input stream to receive server responses from
   * @param log optional log
   * @param debug true to output debugging info
   * @return the std out and err strings as a two element list
   * @throws WekaException if a problem occurs
   */
  @SuppressWarnings("unchecked")
  protected static List receiveDebugBuffer(OutputStream outputStream,
    InputStream inputStream, Logger log, boolean debug) throws WekaException {
    List stdOutStdErr = new ArrayList();

    Map command = new HashMap();
    command.put("command", "get_debug_buffer");
    if (inputStream != null && outputStream != null) {
      try {
        byte[] bytes = MAPPER.writeValueAsBytes(command);
        if (debug) {
          outputCommandDebug(command, log);
        }

        // write the command
        writeDelimitedToOutputStream(bytes, outputStream);

        bytes = readDelimitedFromInputStream(inputStream);
        Map ack =
          MAPPER.readValue(bytes, new TypeReference>() {
          });
        if (!ack.get("response").toString().equals("ok")) {
          // fatal error
          throw new WekaException(ack.get("error_message").toString());
        }
        Object stOut = ack.get("std_out");
        stdOutStdErr.add(stOut != null ? stOut.toString() : "");
        Object stdErr = ack.get("std_err");
        stdOutStdErr.add(stdErr != null ? stdErr.toString() : "");
      } catch (IOException ex) {
        throw new WekaException(ex);
      }
    } else if (debug) {
      outputCommandDebug(command, log);
    }

    return stdOutStdErr;
  }

  /**
   * Send the pickled (and possibly base64 encoded) value of a variable to the
   * server. Sets the decoded value in python.
   *
   * @param varName the name of the variable to be set in python
   * @param varValue the value of the variable
   * @param outputStream the output stream to talk to the server on
   * @param inputStream the input stream to receive responses on
   * @param log optional log
   * @param debug true if debugging info is to be output
   * @throws WekaException if a problem occurs
   */
  protected static void sendPickledVariableValue(String varName,
    String varValue, OutputStream outputStream, InputStream inputStream,
    Logger log, boolean debug) throws WekaException {

    Map command = new HashMap();
    command.put("command", "set_variable_value");
    command.put("variable_name", varName);
    command.put("variable_encoding", "pickled");
    command.put("variable_value", varValue);
    command.put("debug", debug);
    if (inputStream != null && outputStream != null) {
      try {
        byte[] bytes = MAPPER.writeValueAsBytes(command);
        if (debug) {
          outputCommandDebug(command, log);
        }
        // write the command
        writeDelimitedToOutputStream(bytes, outputStream);

        String serverAck = receiveServerAck(inputStream);
        if (serverAck != null) {
          throw new WekaException(serverAck);
        }
      } catch (IOException ex) {
        throw new WekaException(ex);
      }
    } else if (debug) {
      outputCommandDebug(command, log);
    }
  }

  /**
   * Get an image from python. Assumes that the image is a
   * matplotlib.figure.Figure object. Retrieves this as png data and returns a
   * BufferedImage.
   *
   * @param varName the name of the variable containing the image in python
   * @param outputStream the output stream to talk to the server on
   * @param inputStream the input stream to receive server responses from
   * @param log an optional log
   * @param debug true to output debug info
   * @return a BufferedImage
   * @throws WekaException if a problem occurs
   */
  @SuppressWarnings("unchecked")
  protected static BufferedImage getPNGImageFromPython(String varName,
    OutputStream outputStream, InputStream inputStream, Logger log,
    boolean debug) throws WekaException {

    Map command = new HashMap();
    command.put("command", "get_image");
    command.put("variable_name", varName);
    command.put("debug", debug);

    if (inputStream != null && outputStream != null) {
      try {
        byte[] bytes = MAPPER.writeValueAsBytes(command);
        if (debug) {
          outputCommandDebug(command, log);
        }
        // write the command
        writeDelimitedToOutputStream(bytes, outputStream);

        bytes = readDelimitedFromInputStream(inputStream);
        Map ack =
          MAPPER.readValue(bytes, new TypeReference>() {
          });
        if (!ack.get("response").toString().equals("ok")) {
          // fatal error
          throw new WekaException(ack.get("error_message").toString());
        }

        if (!ack.get("variable_name").toString().equals(varName)) {
          throw new WekaException(
            "Server sent back a response for a different " + "variable!");
        }

        String encoding = ack.get("encoding").toString();
        String imageData = ack.get("image_data").toString();
        byte[] imageBytes;
        if (encoding.equals("base64")) {
          imageBytes = Base64.decodeBase64(imageData.getBytes());
        } else {
          imageBytes = imageData.getBytes();
        }
        return ImageIO.read(new BufferedInputStream(new ByteArrayInputStream(
          imageBytes)));
      } catch (IOException ex) {
        throw new WekaException(ex);
      }
    } else {
      outputCommandDebug(command, log);
    }
    return null;
  }

  /**
   * Get the type of a variable in python
   *
   * @param varName the name of the variable to check
   * @param outputStream the output stream to talk to the server on
   * @param inputStream the input stream to receive server responses from
   * @param log an optional log
   * @param debug true to output debug info
   * @return the type of the variable in python
   * @throws WekaException if a problem occurs
   */
  @SuppressWarnings("unchecked")
  protected static PythonSession.PythonVariableType getPythonVariableType(
    String varName, OutputStream outputStream, InputStream inputStream,
    Logger log, boolean debug) throws WekaException {

    Map command = new HashMap();
    command.put("command", "get_variable_type");
    command.put("variable_name", varName);
    command.put("debug", debug);

    if (inputStream != null && outputStream != null) {
      try {
        byte[] bytes = MAPPER.writeValueAsBytes(command);
        if (debug) {
          outputCommandDebug(command, log);
        }
        // write the command
        writeDelimitedToOutputStream(bytes, outputStream);

        bytes = readDelimitedFromInputStream(inputStream);
        Map ack =
          MAPPER.readValue(bytes, new TypeReference>() {
          });
        if (!ack.get("response").toString().equals("ok")) {
          // fatal error
          throw new WekaException(ack.get("error_message").toString());
        }

        if (!ack.get("variable_name").toString().equals(varName)) {
          throw new WekaException(
            "Server sent back a response for a different " + "variable!");
        }

        String varType = ack.get("type").toString();
        PythonSession.PythonVariableType pvt =
          PythonSession.PythonVariableType.Unknown;
        for (PythonSession.PythonVariableType t : PythonSession.PythonVariableType
          .values()) {
          if (t.toString().toLowerCase().equals(varType)) {
            pvt = t;
            break;
          }
        }
        return pvt;
      } catch (IOException ex) {
        throw new WekaException(ex);
      }
    } else {
      outputCommandDebug(command, log);
    }

    return PythonSession.PythonVariableType.Unknown;
  }

  /**
   * Check if a named variable has a value in python
   *
   * @param varName the name of the variable to check
   * @param outputStream the output stream to talk to the server on
   * @param inputStream the input stream to receive server responses from
   * @param log an optional log
   * @param debug true to output debug info
   * @return true if the named variable is set in python
   * @throws WekaException if a problem occurs
   */
  @SuppressWarnings("unchecked")
  protected static boolean checkIfPythonVariableIsSet(String varName,
    OutputStream outputStream, InputStream inputStream, Logger log,
    boolean debug) throws WekaException {

    Map command = new HashMap();
    command.put("command", "variable_is_set");
    command.put("variable_name", varName);
    command.put("debug", debug);

    if (inputStream != null && outputStream != null) {
      try {
        byte[] bytes = MAPPER.writeValueAsBytes(command);
        if (debug) {
          outputCommandDebug(command, log);
        }
        // write the command
        writeDelimitedToOutputStream(bytes, outputStream);

        bytes = readDelimitedFromInputStream(inputStream);
        Map ack =
          MAPPER.readValue(bytes, new TypeReference>() {
          });
        if (!ack.get("response").toString().equals("ok")) {
          // fatal error
          throw new WekaException(ack.get("error_message").toString());
        }

        if (!ack.get("variable_name").toString().equals(varName)) {
          throw new WekaException(
            "Server sent back a response for a different " + "variable!");
        }

        return (Boolean) ack.get("variable_exists");
      } catch (IOException ex) {
        throw new WekaException(ex);
      }
    } else if (debug) {
      outputCommandDebug(command, log);
    }

    // TODO
    return true;
  }

  /**
   * Retrieve a pandas data frame from python. The server sends the header
   * information in json form followed by CSV data (without a header row).
   *
   * @param frameName the name of the pandas data frame to get from the server
   * @param outputStream the output stream to talk to the server on
   * @param inputStream the input stream to receive responses on
   * @param log optional log
   * @param debug true if debugging info is to be output
   * @return the panadas data frame as a set of Instances
   * @throws WekaException if a problem occurs
   */
  @SuppressWarnings("unchecked")
  protected static Instances receiveInstances(String frameName,
    OutputStream outputStream, InputStream inputStream, Logger log,
    boolean debug) throws WekaException {

    Map command = new HashMap();
    command.put("command", "get_instances");
    command.put("frame_name", frameName);
    command.put("debug", debug);

    if (inputStream != null && outputStream != null) {
      try {
        byte[] bytes = MAPPER.writeValueAsBytes(command);
        if (debug) {
          outputCommandDebug(command, log);
        }

        // write the command
        writeDelimitedToOutputStream(bytes, outputStream);
        String serverAck = receiveServerAck(inputStream);
        if (serverAck != null) {
          throw new WekaException(serverAck);
        }

        // read the header
        bytes = readDelimitedFromInputStream(inputStream);
        Map headerResponse =
          MAPPER.readValue(bytes, new TypeReference>() {
          });
        if (headerResponse == null) {
          throw new WekaException("Map is null!");
        }
        if (headerResponse.get("response").toString()
          .equals("instances_header")) {
          if (debug) {
            if (log != null) {
              log.logMessage("Received header response command with "
                + headerResponse.get("num_instances") + " instances");
            } else {
              System.err.println("Received header response command with "
                + headerResponse.get("num_instances") + " instances");
            }
          }
        } else {
          throw new WekaException("Unknown response type from server");
        }

        Instances header =
          jsonToInstancesHeader((Map) headerResponse
            .get("header"));

        // receive the CSV data, append with header, and then create
        // instances
        bytes = readDelimitedFromInputStream(inputStream);
        String CSV = new String(bytes);
        StringBuilder b = new StringBuilder();
        b.append(header.toString()).append("\n");
        b.append(CSV).append("\n");

        return new Instances(new StringReader(b.toString()));
      } catch (IOException ex) {
        throw new WekaException(ex);
      }
    } else if (debug) {
      outputCommandDebug(command, log);
    }
    return null;
  }

  /**
   * Write length delimited data to the output stream
   *
   * @param bytes the bytes to write
   * @param outputStream the output stream to write to
   * @throws IOException if a problem occurs
   */
  protected static void writeDelimitedToOutputStream(byte[] bytes,
    OutputStream outputStream) throws IOException {

    // write the message length as a fixed size integer
    outputStream.write(ByteBuffer.allocate(4).putInt(bytes.length).array());

    // write the message itself
    outputStream.write(bytes);
  }

  /**
   * Read length delimited data from the input stream
   *
   * @param inputStream the input stream to read from
   * @return the bytes read
   * @throws IOException if a problem occurs
   */
  protected static byte[] readDelimitedFromInputStream(InputStream inputStream)
    throws IOException {
    byte[] sizeBytes = new byte[4];
    int numRead = inputStream.read(sizeBytes, 0, 4);
    if (numRead < 4) {
      throw new IOException(
        "Failed to read the message size from the input stream! Num bytes read: "
          + numRead);
    }

    int messageLength = ByteBuffer.wrap(sizeBytes).getInt();
    byte[] messageData = new byte[messageLength];
    // for (numRead = 0; numRead < messageLength; numRead +=
    // inputStream.read(messageData, numRead, messageLength - numRead));
    for (numRead = 0; numRead < messageLength;) {
      int currentNumRead =
        inputStream.read(messageData, numRead, messageLength - numRead);
      if (currentNumRead < 0) {
        throw new IOException("Unexpected end of stream!");
      }
      numRead += currentNumRead;
    }

    return messageData;
  }

  /**
   * Prints out a json command for debugging purposes
   *
   * @param command the command to print out
   * @param log optional log
   */
  protected static void outputCommandDebug(Map command,
    Logger log) {
    try {
      String serialized = MAPPER.writeValueAsString(command);
      if (log != null) {
        log.logMessage("Sending command:\n" + serialized);
      } else {
        System.err.println("Sending command: ");
        String indented =
          MAPPER.writerWithDefaultPrettyPrinter()
            .writeValueAsString(serialized);
        System.out.println(indented);
      }
    } catch (JsonProcessingException e) {
      e.printStackTrace();
    }
  }

  public static void main(String[] args) {
    try {
      Instances insts = new Instances(new FileReader(args[0]));
      insts.setClassIndex(insts.numAttributes() - 1);
      sendInstances(insts, "test", null, null, null, false);
    } catch (Exception ex) {
      ex.printStackTrace();
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy