All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.operator.scalar.IpAddressFunctions Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.operator.scalar;

import com.google.common.net.InetAddresses;
import com.google.common.primitives.Ints;
import io.airlift.slice.Slice;
import io.trino.spi.TrinoException;
import io.trino.spi.function.Description;
import io.trino.spi.function.ScalarFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.type.StandardTypes;

import java.math.BigInteger;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.util.regex.Pattern;

import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;

public final class IpAddressFunctions
{
    // Matches a dotted-quad IPv4 literal (four groups of 1-3 digits). Used to detect whether the CIDR base was
    // *written* in IPv4 form; see the comment in contains() for why instanceof Inet4Address is not sufficient.
    private static final Pattern IPV4_PATTERN = Pattern.compile("^\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}$");

    // Length in bytes of the address encoding produced by toBytes() (IPv6 / IPv4-mapped IPv6).
    private static final int IPV6_LENGTH_BYTES = 16;
    // Byte offset of the embedded IPv4 address inside an IPv4-mapped IPv6 address (::ffff:a.b.c.d).
    private static final int IPV4_OFFSET = 12;
    // Number of leading bits (the ::ffff:0:0/96 part) that precede the embedded IPv4 address.
    private static final int IPV4_MAPPED_PREFIX_BITS = IPV4_OFFSET * Byte.SIZE;

    private IpAddressFunctions() {}

    /**
     * Determines whether {@code address} lies inside the CIDR block {@code network}.
     * <p>
     * {@code network} must have the form {@code base/prefix}. If {@code base} is written as a dotted-quad IPv4
     * literal, {@code prefix} must be in [0, 32] and every host bit of {@code base} must be zero. The prefix is
     * then applied to the IPv4-mapped IPv6 form of the address, shifted by 96 bits. Otherwise {@code prefix}
     * must be in [0, 128].
     *
     * @param network CIDR text such as {@code 10.0.0.0/8} or {@code 2001:db8::/32}
     * @param address IPADDRESS value; presumably the 16-byte IPv6 / IPv4-mapped encoding, since it is compared
     * bit-for-bit against the 16-byte network bytes (TODO confirm against the IPADDRESS type)
     * @return true when the leading {@code prefix} bits of both addresses are equal
     * @throws TrinoException with {@code INVALID_FUNCTION_ARGUMENT} if {@code network} is not a valid CIDR
     */
    @Description("Determines whether given IP address exists in the CIDR")
    @ScalarFunction
    @SqlType(StandardTypes.BOOLEAN)
    public static boolean contains(@SqlType(StandardTypes.VARCHAR) Slice network, @SqlType(StandardTypes.IPADDRESS) Slice address)
    {
        String cidr = network.toStringUtf8();

        int separator = cidr.indexOf('/');
        if (separator == -1) {
            throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "Invalid CIDR");
        }

        String cidrBase = cidr.substring(0, separator);
        InetAddress cidrAddress;
        try {
            cidrAddress = InetAddresses.forString(cidrBase);
        }
        catch (IllegalArgumentException e) {
            throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "Invalid network IP address", e);
        }

        byte[] cidrBytes = toBytes(cidrAddress);

        int prefixLength = parsePrefixLength(cidr.substring(separator + 1));
        if (prefixLength < 0) {
            throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "Invalid prefix length");
        }

        // We do regex match instead of instanceof Inet4Address because InetAddresses.forString() normalizes
        // IPv4 mapped IPv6 addresses (e.g., ::ffff:0102:0304) to Inet4Address. We need to be able to
        // distinguish between the two formats in the CIDR string to be able to interpret the prefix length correctly.
        if (IPV4_PATTERN.matcher(cidrBase).matches()) {
            if (!isValidIpV4Cidr(cidrBytes, IPV4_OFFSET, prefixLength)) {
                throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "Invalid CIDR");
            }
            // An IPv4 prefix counts bits of the embedded IPv4 address, which starts after the ::ffff: prefix.
            prefixLength += IPV4_MAPPED_PREFIX_BITS;
        }
        // NOTE(review): unlike the IPv4 branch, host bits of an IPv6 base are not required to be zero,
        // so e.g. 2001:db8::1/32 is accepted. This is kept as-is to avoid rejecting previously accepted input.
        else if (!isValidIpV6Cidr(prefixLength)) {
            throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "Invalid CIDR");
        }

        // A zero-length prefix (only reachable for an IPv6 ::/0 style network) matches every address.
        if (prefixLength == 0) {
            return true;
        }

        byte[] ipAddress = address.getBytes();
        // Arithmetic right shifts keep only the top prefixLength bits (sign-extended identically for both
        // values), so the results are equal exactly when the two addresses share that prefix.
        BigInteger cidrPrefix = new BigInteger(cidrBytes).shiftRight(cidrBytes.length * Byte.SIZE - prefixLength);
        BigInteger addressPrefix = new BigInteger(ipAddress).shiftRight(ipAddress.length * Byte.SIZE - prefixLength);

        return cidrPrefix.equals(addressPrefix);
    }

    /**
     * Parses the text after the '/' of a CIDR as a decimal prefix length.
     * Empty, non-numeric or out-of-int-range text is reported as an invalid function argument
     * rather than escaping as a raw {@link NumberFormatException}.
     */
    private static int parsePrefixLength(String text)
    {
        try {
            return Integer.parseInt(text);
        }
        catch (NumberFormatException e) {
            throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "Invalid prefix length", e);
        }
    }

    /**
     * Returns whether {@code prefixLength} is a legal IPv6 prefix length, i.e. within [0, 128].
     */
    private static boolean isValidIpV6Cidr(int prefixLength)
    {
        return prefixLength >= 0 && prefixLength <= 128;
    }

    /**
     * Returns whether {@code prefix} is a legal IPv4 prefix length in [0, 32] and the four IPv4 bytes at
     * {@code address[offset..offset+3]} have all host bits (those after the prefix) set to zero.
     */
    private static boolean isValidIpV4Cidr(byte[] address, int offset, int prefix)
    {
        if (prefix < 0 || prefix > 32) {
            return false;
        }

        // Mask selecting the host bits: the low (32 - prefix) bits of the 32-bit address.
        long mask = 0xFFFFFFFFL >>> prefix;
        return (Ints.fromBytes(address[offset], address[offset + 1], address[offset + 2], address[offset + 3]) & mask) == 0;
    }

    /**
     * Returns the 16-byte form of {@code address}. An {@link Inet4Address} is converted to its IPv4-mapped
     * IPv6 encoding; any other address is returned as given by {@link InetAddress#getAddress()}.
     */
    private static byte[] toBytes(InetAddress address)
    {
        byte[] bytes = address.getAddress();

        if (address instanceof Inet4Address) {
            byte[] mapped = new byte[IPV6_LENGTH_BYTES];
            // IPv4 mapped addresses are encoded as ::ffff:a.b.c.d (bytes 10-11 are 0xFF, bytes 12-15 hold the IPv4 address)
            mapped[10] = (byte) 0xFF;
            mapped[11] = (byte) 0xFF;
            System.arraycopy(bytes, 0, mapped, IPV4_OFFSET, 4);
            bytes = mapped;
        }

        return bytes;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy