All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.lightgbm.LightGBMUtil Maven / Gradle / Ivy

Go to download

Java library and command-line application for converting LightGBM models to PMML

There is a newer version: 1.5.5
Show newest version
/*
 * Copyright (c) 2017 Villu Ruusmann
 *
 * This file is part of JPMML-LightGBM
 *
 * JPMML-LightGBM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-LightGBM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-LightGBM.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.lightgbm;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.google.common.io.CharStreams;
import org.dmg.pmml.Interval;
import org.jpmml.converter.ValueUtil;

public class LightGBMUtil {

	private LightGBMUtil(){
	}

	static
	public GBDT loadGBDT(InputStream is) throws IOException {
		return loadGBDT(parseText(is));
	}

	static
	public GBDT loadGBDT(Iterator lines){
		List
sections = loadText(lines); GBDT gbdt = new GBDT(); gbdt.load(sections); return gbdt; } static private List
loadText(Iterator lines){ List
sections = new ArrayList<>(); Section section = new Section(); loop: while(lines.hasNext()){ String line = lines.next(); if(("").equals(line)){ if(section.size() > 0){ sections.add(section); section = new Section(); } continue loop; } section.put(line); } if(section.size() > 0){ sections.add(section); } return sections; } static public Iterator parseText(InputStream is) throws IOException { Reader reader = new InputStreamReader(is, "US-ASCII"); List lines = CharStreams.readLines(reader); return lines.iterator(); } static public String[] parseStringArray(String string, int length){ String[] result = string.split("\\s"); if(length > -1 && result.length != length){ throw new IllegalArgumentException("Expected " + length + " elements, got " + result.length + " elements"); } return result; } static public int[] parseIntArray(String string, int length){ String[] values = parseStringArray(string, length); int[] result = new int[values.length]; for(int i = 0; i < result.length; i++){ result[i] = parseInt(values[i]); } return result; } static public long[] parseUnsignedIntArray(String string, int length){ String[] values = parseStringArray(string, length); long[] result = new long[values.length]; for(int i = 0; i < result.length; i++){ result[i] = parseUnsignedInt(values[i]); } return result; } static public double[] parseDoubleArray(String string, int length){ String[] values = parseStringArray(string, length); double[] result = new double[values.length]; for(int i = 0; i < result.length; i++){ result[i] = parseDouble(values[i]); } return result; } static private int parseInt(String string){ return Integer.parseInt(string); } static private long parseUnsignedInt(String string){ return Long.parseLong(string); } static private double parseDouble(String string){ switch(string){ case "inf": return Double.POSITIVE_INFINITY; default: return Double.parseDouble(string); } } static public boolean isNone(String string){ return string.equals("none"); } static public boolean 
isInterval(String string){ return string.startsWith("[") && string.endsWith("]"); } static public boolean isBinaryInterval(String string){ return string.equals("[0:1]"); } static public boolean isValues(String string){ return !isInterval(string); } static public Interval parseInterval(String string){ if(string.length() < 3){ throw new IllegalArgumentException(); } String bounds = string.substring(0, 1) + string.substring(string.length() - 1, string.length()); String margins = string.substring(1, string.length() - 1); Interval.Closure closure; switch(bounds){ case "[]": closure = Interval.Closure.CLOSED_CLOSED; break; default: throw new IllegalArgumentException(string); } String[] values = margins.split(":"); if(values.length != 2){ throw new IllegalArgumentException(margins); } Double leftMargin = Double.valueOf(values[0]); Double rightMargin = Double.valueOf(values[1]); Interval interval = new Interval(closure) .setLeftMargin(leftMargin) .setRightMargin(rightMargin); return interval; } static public List parseValues(String string){ String[] values = string.split(":"); return Stream.of(values) .map(LightGBMUtil.CATEGORY_PARSER) .collect(Collectors.toList()); } static public String unescape(String string){ if(string == null || !string.contains("\\u")){ return string; } StringBuffer sb = new StringBuffer(string.length()); Matcher matcher = LightGBMUtil.PATTERN_UNICODE_ESCAPE.matcher(string); while(matcher.find()){ int c = Integer.parseInt(matcher.group(1), 16); matcher.appendReplacement(sb, Character.toString((char)c)); } matcher.appendTail(sb); return sb.toString(); } private static final Pattern PATTERN_UNICODE_ESCAPE = Pattern.compile("\\\\u([0-9A-Fa-f]{4})"); static final Function CATEGORY_PARSER = new Function(){ @Override public Integer apply(String string){ try { return Integer.valueOf(string); } catch(NumberFormatException nfe){ return ValueUtil.asInteger(Double.valueOf(string)); } } }; static final Function CATEGORY_FORMATTER = new Function(){ @Override 
public String apply(Integer integer){ return integer.toString(); } }; }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy