All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.text.gpt2.GPT2Codec Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.text.gpt2;

import com.google.gson.GsonBuilder;
import com.google.gson.JsonObject;
import org.apache.commons.io.FileUtils;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Encodes text to token ids and decodes token ids back to text using a GPT-2 style vocabulary.
 *
 * <p>Encoding is greedy: at each position the longest vocabulary entry that is a prefix of the
 * remaining text is emitted. No BPE merge ranks are consulted, so the resulting ids may differ
 * from those produced by the reference GPT-2 tokenizer.
 *
 * <p>When the vocabulary is loaded with {@link #loadEncoder(File)}, every key is passed through
 * {@link #getCharacterTransformer()}, which maps GPT-2's printable byte alphabet back to one
 * {@code char} per byte (assuming the file uses that alphabet).
 */
public class GPT2Codec {
  protected static final Logger logger = LoggerFactory.getLogger(GPT2Codec.class);

  /** Token text to token id. Prefix lookups rely on this map's ordering (natural String order). */
  protected final TreeMap<String, Integer> encoder;
  /** Token id to token text; built once from {@link #encoder} at construction time. */
  protected final TreeMap<Integer, String> decoder;
  private final int vocabSize;

  /**
   * Creates a codec from an in-memory vocabulary.
   *
   * @param encoder   token text to token id; ids must be unique
   * @param vocabSize value reported by {@link #getVocabSize()}
   * @throws NullPointerException  if {@code encoder} is null
   * @throws IllegalStateException if two tokens share the same id
   */
  public GPT2Codec(TreeMap<String, Integer> encoder, int vocabSize) {
    this.encoder = Objects.requireNonNull(encoder, "encoder");
    this.vocabSize = vocabSize;
    this.decoder = buildDecoder(this.encoder);
  }

  /**
   * Creates a codec from a JSON vocabulary file (see {@link #loadEncoder(File)}).
   *
   * @param file      UTF-8 JSON object mapping token text to integer id
   * @param vocabSize value reported by {@link #getVocabSize()}
   */
  public GPT2Codec(File file, int vocabSize) {
    this(GPT2Codec.loadEncoder(file), vocabSize);
  }

  /**
   * Inverts a token-to-id table.
   *
   * @param encoder token text to token id
   * @return a new map from token id to token text
   * @throws IllegalStateException if two tokens share the same id
   */
  public static TreeMap<Integer, String> buildDecoder(TreeMap<String, Integer> encoder) {
    return encoder.entrySet().stream().collect(Collectors.toMap(
        Map.Entry::getValue,
        Map.Entry::getKey,
        (first, second) -> {
          throw new IllegalStateException(
              "Token id shared by '" + first + "' and '" + second + "'");
        },
        TreeMap::new));
  }

  /**
   * Reads a JSON vocabulary file and converts its keys with {@link #getCharacterTransformer()}.
   *
   * @param file UTF-8 JSON object mapping token text to integer id
   * @return token text to token id
   * @throws UncheckedIOException if the file cannot be read
   */
  public static TreeMap<String, Integer> loadEncoder(File file) {
    try {
      return toMap(FileUtils.readFileToString(file, "UTF-8"), getCharacterTransformer());
    } catch (IOException e) {
      throw new UncheckedIOException("Unable to read encoder file " + file, e);
    }
  }

  /**
   * Parses a JSON object of {@code token -> id} pairs into a sorted map.
   *
   * @param jsonTxt    JSON object text
   * @param keyEncoder applied to every key before insertion; if two keys map to the same
   *                   string, the first one encountered wins
   * @return transformed token text to token id
   * @throws IllegalArgumentException if {@code jsonTxt} is null or empty
   */
  @NotNull
  public static TreeMap<String, Integer> toMap(String jsonTxt, Function<String, String> keyEncoder) {
    JsonObject json = new GsonBuilder().create().fromJson(jsonTxt, JsonObject.class);
    if (null == json) {
      // Gson yields null for null/empty input; fail with a clear message instead of an NPE.
      throw new IllegalArgumentException("Encoder JSON is empty");
    }
    return json.keySet().stream().collect(Collectors.toMap(
        keyEncoder,
        key -> json.get(key).getAsInt(),
        (first, second) -> first,
        TreeMap::new));
  }

  /**
   * Returns a function that rewrites each character of a string through {@link #byteEncoder()}.
   * Characters absent from that table are left unchanged.
   *
   * @return the per-character transformation
   */
  @NotNull
  public static Function<String, String> getCharacterTransformer() {
    Map<Character, Character> byteEncoder = byteEncoder();
    return text -> {
      char[] chars = text.toCharArray();
      for (int i = 0; i < chars.length; i++) {
        chars[i] = byteEncoder.getOrDefault(chars[i], chars[i]);
      }
      return new String(chars);
    };
  }

  /**
   * Builds the inverse of GPT-2's byte-to-unicode table. Despite the name, the map goes from a
   * character of GPT-2's printable alphabet to the byte value (as a {@code char} in
   * U+0000..U+00FF) that it stands for.
   *
   * <p>Printable bytes ({@code '!'..'~'}, {@code '¡'..'¬'}, {@code '®'..'ÿ'}, all inclusive) map
   * to themselves. The remaining 68 bytes are represented, in ascending byte order, by the
   * characters U+0100..U+0143.
   *
   * @return a new mutable map with 256 entries
   */
  public static Map<Character, Character> byteEncoder() {
    Map<Character, Character> characterMap = new HashMap<>();
    int shifted = 0;
    for (int b = 0; b < 256; b++) {
      if (isPrintableByte(b)) {
        characterMap.put((char) b, (char) b);
      } else {
        // Non-printable bytes are numbered consecutively starting at U+0100.
        characterMap.put((char) (256 + shifted++), (char) b);
      }
    }
    return characterMap;
  }

  /** Whether GPT-2's byte alphabet keeps byte {@code b} as its own character. */
  private static boolean isPrintableByte(int b) {
    // Numeric bounds avoid depending on the source file's compilation encoding.
    return (b >= 0x21 && b <= 0x7E) || (b >= 0xA1 && b <= 0xAC) || (b >= 0xAE && b <= 0xFF);
  }

  /**
   * Concatenates the text of the given token ids. Unknown ids contribute nothing.
   *
   * @param msg token ids
   * @return decoded text (empty if {@code msg} is empty)
   */
  public String decode(Integer... msg) {
    return Arrays.stream(msg)
        .map(id -> decoder.getOrDefault(id, ""))
        .collect(Collectors.joining());
  }

  /**
   * Greedily tokenizes {@code msg}. At each position it emits the id of the longest vocabulary
   * entry that prefixes the remaining text. A character that starts no vocabulary entry is
   * skipped.
   *
   * @param msg text to encode; null or empty yields an empty list
   * @return token ids in order
   */
  public List<Integer> encode(String msg) {
    List<Integer> tokens = new ArrayList<>();
    if (null == msg || msg.isEmpty()) {
      return tokens;
    }
    int pos = 0;
    while (pos < msg.length()) {
      Optional<String> match = lookup(msg.substring(pos));
      if (match.isPresent()) {
        String key = match.get();
        tokens.add(encoder.get(key));
        pos += key.length();
      } else {
        pos++; // best-effort: no token starts here, so drop this character
      }
    }
    return tokens;
  }

  /**
   * Finds the longest non-empty key of {@link #encoder} that is a prefix of {@code searchStr}.
   *
   * <p>Any key that prefixes the probe sorts at or below the probe. If the floor key is not
   * itself a prefix, every prefix key must also prefix the common prefix of the probe and that
   * floor key. The probe is therefore narrowed to that common prefix and the search is
   * repeated; the probe strictly shrinks on each iteration.
   *
   * @param searchStr text to match against
   * @return the matching key, or empty if none exists or {@code searchStr} is null/empty
   */
  protected Optional<String> lookup(String searchStr) {
    if (null == searchStr || searchStr.isEmpty()) {
      return Optional.empty();
    }
    String probe = searchStr;
    while (!probe.isEmpty()) {
      String candidate = encoder.floorKey(probe);
      // An empty key would make encode() loop forever, so it never counts as a match.
      if (null == candidate || candidate.isEmpty()) {
        return Optional.empty();
      }
      if (probe.startsWith(candidate)) {
        return Optional.of(candidate);
      }
      probe = probe.substring(0, commonPrefixLength(probe, candidate));
    }
    return Optional.empty();
  }

  /** Number of leading characters shared by {@code a} and {@code b}. */
  private static int commonPrefixLength(String a, String b) {
    int limit = Math.min(a.length(), b.length());
    int i = 0;
    while (i < limit && a.charAt(i) == b.charAt(i)) {
      i++;
    }
    return i;
  }

  /**
   * Returns the vocabulary size supplied at construction.
   *
   * @return the vocabulary size
   */
  public int getVocabSize() {
    return vocabSize;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy