All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.text.gpt2.GPT2Model Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.text.gpt2;

import com.google.protobuf.InvalidProtocolBufferException;
import com.simiacryptus.text.GraphModifier;
import com.simiacryptus.text.LanguageCodeModel;
import com.simiacryptus.text.TextGenerator;
import org.apache.commons.io.FileUtils;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.framework.*;

import javax.annotation.Nullable;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.util.*;
import java.util.function.BiFunction;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;

/**
 * {@link LanguageCodeModel} backed by a GPT-2 TensorFlow graph.
 *
 * <p>The model remembers every token passed to {@link #eval(int)} ({@link #code_history}) and the
 * state tensor returned by the previous session run ({@link #tensor_state}). Each call feeds only
 * the newest token, plus that state as {@code input_past}. The first call (no state yet) runs the
 * {@code init/}-prefixed copy of the graph, which is fed no {@code input_past}.
 *
 * <p>Instances produced by {@link #copy()} share {@link #graph}, {@link #session} and
 * {@link #loadedSubnets} with the original. Session runs are serialized by locking the shared
 * session, and sub-graph imports by locking the shared {@code loadedSubnets} set.
 */
public class GPT2Model implements LanguageCodeModel {
  protected static final Logger logger = LoggerFactory.getLogger(GPT2Model.class);

  /** Name given to this model at construction. */
  public final String name;
  /** Serialized TensorFlow {@code GraphDef} that the graph (and its prefixed variants) is imported from. */
  protected final byte[] graphDef;
  /** Every token code passed to {@link #eval(int)} since construction or the last {@link #clear()}. */
  protected final ArrayList<Integer> code_history = new ArrayList<>();
  protected final GraphModifier graphModifier;
  protected final GPT2Codec codec;
  /** Name prefixes of the sub-graphs already imported into {@link #graph}; shared between copies. */
  public HashSet<String> loadedSubnets;
  public Graph graph;
  public Session session;
  /**
   * Step counter for the cached state. It is initialised from dimension 4 of the first output state
   * (presumably its sequence-length axis — TODO confirm) and incremented on each later step.
   * It is 0 while no state is cached.
   */
  protected int history_size = 0;
  /** State returned by the previous run, fed back as {@code input_past}; {@code null} when empty. */
  protected Tensor<Float> tensor_state = null;
  /** Predicate over (decoded history text, candidate token text) deciding whether a token is allowed. */
  private BiFunction<String, String, Boolean> filterFn = (a, b) -> true;

  /** Creates a model whose serialized graph definition is read from {@code file}. */
  public GPT2Model(String name, GraphModifier graphModifier, File file, GPT2Codec codec) {
    this(name, loadModel(file), graphModifier, codec);
  }

  /** Creates a model on a fresh {@link Graph}. */
  public GPT2Model(String name, byte[] graphDef, GraphModifier graphModifier, GPT2Codec codec) {
    this(name, graphDef, graphModifier, codec, new Graph());
  }

  /** Creates a model on {@code graph} with a new session configured by {@link #sessionConfig()}. */
  public GPT2Model(String name, byte[] graphDef, GraphModifier graphModifier, GPT2Codec codec, Graph graph) {
    this(name, graphDef, graphModifier, codec, graph, new Session(graph, sessionConfig()));
  }

  /** Creates a model using an existing graph and session, which may be shared with other models. */
  public GPT2Model(String name, byte[] graphDef, GraphModifier graphModifier, GPT2Codec codec, Graph graph, Session session) {
    this.name = name;
    this.graphDef = graphDef;
    this.graphModifier = graphModifier;
    this.codec = codec;
    this.graph = graph;
    this.session = session;
    loadedSubnets = new HashSet<>();
  }

  /**
   * Builds the serialized {@link ConfigProto} used for new sessions.
   *
   * <p>It enables constant folding, function inlining and common-subexpression elimination. It
   * also enables GPU allow-growth and force-GPU-compatible, with a per-process GPU memory fraction
   * of 0.5.
   */
  private static byte[] sessionConfig() {
    return ConfigProto.newBuilder()
        .setGraphOptions(GraphOptions.newBuilder()
            .setOptimizerOptions(OptimizerOptions.newBuilder()
                .setDoConstantFolding(true)
                .setDoFunctionInlining(true)
                .setDoCommonSubexpressionElimination(true)
                .build())
            .build())
        .setGpuOptions(GPUOptions.newBuilder()
            .setForceGpuCompatible(true)
            .setAllowGrowth(true)
            .setPerProcessGpuMemoryFraction(0.5)
            .build())
        .build().toByteArray();
  }

  /**
   * Reads the whole file as bytes.
   *
   * @throws UncheckedIOException if the file cannot be read
   */
  public static byte[] loadModel(File file) {
    try {
      return FileUtils.readFileToByteArray(file);
    } catch (IOException e) {
      throw new UncheckedIOException("Failed to read model graph from " + file, e);
    }
  }

  /**
   * Returns a new float tensor with the same shape and contents as {@code toCopy}.
   * The caller owns, and must close, the returned tensor.
   */
  @NotNull
  public static Tensor<Float> copy(Tensor<Float> toCopy) {
    FloatBuffer floatBuffer = FloatBuffer.allocate(toCopy.numElements());
    toCopy.writeTo(floatBuffer);
    floatBuffer.flip();
    return Tensor.create(toCopy.shape(), floatBuffer);
  }

  /**
   * Returns an independent copy of the decoding state, sharing the graph, session and loaded-subnet set.
   *
   * <p>The state tensor is deep-copied, so evaluating either model does not close the other's state.
   * Synchronized so the state cannot be closed or mutated by a concurrent eval while being copied.
   */
  @Override
  public synchronized LanguageCodeModel copy() {
    final GPT2Model duplicate = new GPT2Model(name, graphDef, graphModifier, this.codec, this.graph, this.session);
    duplicate.tensor_state = (null == this.tensor_state) ? null : copy(this.tensor_state);
    duplicate.history_size = this.history_size;
    duplicate.loadedSubnets = this.loadedSubnets;
    duplicate.code_history.addAll(this.code_history);
    duplicate.filterFn = this.filterFn;
    return duplicate;
  }

  /**
   * Converts raw logits into a probability distribution, in place.
   *
   * <p>Tokens rejected by {@link #getFilterFn()} receive probability 0. The filter is given the
   * decoded history and the candidate's decoded text, and the last index is always kept. A softmax
   * is then taken over the surviving logits, shifted by their maximum finite value for numerical
   * stability. Non-finite exponentials are treated as 0.
   *
   * @param logits raw scores, one per token code; overwritten with the result
   * @return the same {@code logits} array, now holding probabilities
   */
  public float[] logitsToProbabilities(float[] logits) {
    final String prefix = codec.decode(code_history.toArray(new Integer[0]));
    final int lastIndex = logits.length - 1;
    final int[] sortedIndices = Arrays.stream(TextGenerator.sortedIndices(logits, Integer.MAX_VALUE))
        .filter(item -> item == lastIndex || getFilterFn().apply(prefix, codec.decode(item)))
        .toArray();
    final double[] input = Arrays.stream(sortedIndices).mapToDouble(idx -> logits[idx]).toArray();
    assert 1 < input.length : "input.length() = " + input.length;

    // Matches DoubleSummaryStatistics.getMax(): -Infinity when there is no finite value.
    final double max = DoubleStream.of(input).filter(Double::isFinite).max().orElse(Double.NEGATIVE_INFINITY);
    final double[] exp = Arrays.stream(input).map(x -> {
      final double e = Math.exp(x - max);
      return Double.isFinite(e) ? e : 0;
    }).toArray();
    final double total = Arrays.stream(exp).sum();
    // Avoid dividing by zero when every surviving weight underflowed to 0.
    final double sum = 0 < total ? total : 1;
    assert Double.isFinite(sum);

    Arrays.fill(logits, 0f);
    for (int c = 0; c < sortedIndices.length; c++) {
      logits[sortedIndices[c]] = (float) (exp[c] / sum);
    }
    return logits;
  }

  /** Discards the cached state (closing its tensor) and the token history. */
  @Override
  public synchronized LanguageCodeModel clear() {
    logger.debug("Reset Language Model State");
    if (null != this.tensor_state) this.tensor_state.close();
    this.tensor_state = null;
    history_size = 0;
    code_history.clear();
    return this;
  }

  /**
   * Appends {@code data_X} to the history and returns the next-token probabilities.
   *
   * <p>Earlier context is carried by the cached state, so only the newest token is fed. Without a
   * cached state, the {@code init/} sub-graph is used.
   */
  @Override
  public synchronized float[] eval(int data_X) {
    logger.debug("Eval {}", data_X);
    final String prefix = (null == this.tensor_state) ? "init/" : "";
    ensureSubnetsLoaded(prefix);
    this.code_history.add(data_X);
    return eval(prefix, data_X);
  }

  /**
   * Imports the base graph and, for a non-empty {@code prefix}, the prefixed variant produced by
   * {@link #graphModifier}, each at most once per shared {@link #graph}.
   *
   * <p>It locks the {@code loadedSubnets} set shared by all copies, because the per-instance monitor
   * does not stop two copies from importing the same sub-graph concurrently. A subnet is recorded as
   * loaded only after its import succeeded.
   */
  private void ensureSubnetsLoaded(String prefix) {
    synchronized (loadedSubnets) {
      if (!loadedSubnets.contains("")) {
        graph.importGraphDef(this.graphDef);
        loadedSubnets.add("");
      }
      if (!prefix.isEmpty() && !loadedSubnets.contains(prefix)) {
        try {
          GraphModifier.importGraphDef(graph, this.graphModifier.edit(GraphDef.parseFrom(this.graphDef), prefix, false));
        } catch (InvalidProtocolBufferException e) {
          throw new RuntimeException(e);
        }
        loadedSubnets.add(prefix);
      }
    }
  }

  /**
   * Runs the graph on {@code data_X} and returns next-token probabilities.
   *
   * <p>The cached state, if present, is fed as {@code prefix + "input_past"}. The new state is fetched
   * from {@code model/stack} on the first step and from {@code output/concat} afterwards. Only after the
   * run succeeds does it replace the cached state, and the previous state tensor is closed at that
   * point. Input and logits tensors are closed even when the run fails.
   *
   * @param prefix  name prefix of the sub-graph to run ({@code "init/"} or {@code ""})
   * @param data_X  token codes to feed as a {@code [1, n]} int tensor
   * @return probabilities as produced by {@link #logitsToProbabilities(float[])}
   */
  public synchronized float[] eval(String prefix, int... data_X) {
    synchronized (session) {
      logger.debug("Eval({},{})", session, Arrays.toString(data_X));
      final Tensor<Float> prevState = this.tensor_state;
      final Tensor<Float> outputState;
      final float[] logits;
      try (Tensor<Integer> input_X = Tensor.create(new long[]{1, data_X.length}, IntBuffer.wrap(data_X))) {
        Session.Runner runner = session.runner().feed("input_X", input_X);
        if (null != prevState) runner = runner.feed(prefix + "input_past", prevState);
        logger.debug("Input Codes: {}", Arrays.toString(data_X));
        logger.debug("Input State: {}", (prevState == null ? "null" : Arrays.toString(prevState.shape())));
        final List<Tensor<?>> run = runner
            .fetch(prefix + "output/strided_slice_1")
            .fetch((0 == history_size) ? (prefix + "model/stack") : (prefix + "output/concat"))
            .run();
        try (Tensor<Float> tensor_next = run.get(0).expect(Float.class)) {
          outputState = run.get(1).expect(Float.class);
          logger.debug("Output Logits: {}", Arrays.toString(tensor_next.shape()));
          logger.debug("Output State: {}", Arrays.toString(outputState.shape()));
          logits = new float[tensor_next.numElements()];
          tensor_next.writeTo(FloatBuffer.wrap(logits));
        }
      }
      if (null == prevState) {
        this.history_size = (int) outputState.shape()[4];
      } else {
        this.history_size = this.history_size + 1;
        prevState.close();
      }
      this.tensor_state = outputState;
      return logitsToProbabilities(logits);
    }
  }

  /** Returns the token filter used by {@link #logitsToProbabilities(float[])}. */
  @Override
  public BiFunction<String, String, Boolean> getFilterFn() {
    return filterFn;
  }

  /** Replaces the token filter; returns {@code this} for chaining. */
  @Override
  public LanguageCodeModel setFilterFn(BiFunction<String, String, Boolean> filterFn) {
    this.filterFn = filterFn;
    return this;
  }

  /**
   * Returns the cached state tensor, or {@code null} before the first evaluation or after {@link #clear()}.
   * The tensor remains owned by this model and is closed by its next eval or clear.
   */
  @Override
  public Tensor<Float> state() {
    return this.tensor_state;
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy