package io.dogboy.serializationisbad.core;

import io.dogboy.serializationisbad.core.config.PatchModule;
import io.dogboy.serializationisbad.shadow.asm.ClassReader;
import io.dogboy.serializationisbad.shadow.asm.ClassWriter;
import io.dogboy.serializationisbad.shadow.asm.Opcodes;
import io.dogboy.serializationisbad.shadow.asm.Type;
import io.dogboy.serializationisbad.shadow.asm.tree.AbstractInsnNode;
import io.dogboy.serializationisbad.shadow.asm.tree.ClassNode;
import io.dogboy.serializationisbad.shadow.asm.tree.InsnList;
import io.dogboy.serializationisbad.shadow.asm.tree.LdcInsnNode;
import io.dogboy.serializationisbad.shadow.asm.tree.MethodInsnNode;
import io.dogboy.serializationisbad.shadow.asm.tree.MethodNode;
import io.dogboy.serializationisbad.shadow.asm.tree.TypeInsnNode;
import io.dogboy.serializationisbad.shadow.asm.tree.VarInsnNode;
import java.io.ByteArrayInputStream;
import java.io.InputStream;

/* loaded from: input_file:io/dogboy/serializationisbad/core/Patches.class */
public class Patches {
    /** Internal name of the JDK stream whose construction is redirected. */
    private static final String OBJECT_INPUT_STREAM = "java/io/ObjectInputStream";
    /** Internal name of the filtering replacement stream. */
    private static final String FILTERING_STREAM = "io/dogboy/serializationisbad/core/ClassFilteringObjectInputStream";
    /** Internal name of the commons-lang3 helper whose deserialize calls are redirected. */
    private static final String SERIALIZATION_UTILS = "org/apache/commons/lang3/SerializationUtils";
    /** Internal name of this class, the target of redirected static calls. */
    private static final String PATCHES = "io/dogboy/serializationisbad/core/Patches";
    /** Descriptor of {@link #getPatchModuleForClass(String)}. */
    private static final String GET_PATCH_MODULE_DESC = "(Ljava/lang/String;)Lio/dogboy/serializationisbad/core/config/PatchModule;";
    /** Constructor descriptor used when no class loader argument is passed. */
    private static final String FILTERING_CTOR_DESC = "(Ljava/io/InputStream;Lio/dogboy/serializationisbad/core/config/PatchModule;)V";
    /** Constructor descriptor used when a class loader argument is also passed. */
    private static final String FILTERING_CTOR_WITH_LOADER_DESC = "(Ljava/io/InputStream;Lio/dogboy/serializationisbad/core/config/PatchModule;Ljava/lang/ClassLoader;)V";

    /**
     * Finds the first configured patch module that lists the given class.
     *
     * <p>Patched bytecode calls this method at runtime (see {@link #applyPatches}), so its
     * name and signature must not change.
     *
     * @param str the class name, in the same form stored in
     *     {@code PatchModule#getClassesToPatch()}
     * @return the matching module, or {@code null} if no module lists the class
     */
    public static PatchModule getPatchModuleForClass(String str) {
        for (PatchModule module : SerializationIsBad.getInstance().getConfig().getPatchModules()) {
            if (module.getClassesToPatch().contains(str)) {
                return module;
            }
        }
        return null;
    }

    /** Parses raw class bytes into an ASM tree. */
    private static ClassNode readClassNode(byte[] bArr) {
        ClassNode classNode = new ClassNode();
        new ClassReader(bArr).accept(classNode, 0);
        return classNode;
    }

    /** Serializes an ASM tree back to class bytes, recomputing max stack/locals only. */
    private static byte[] writeClassNode(ClassNode classNode) {
        ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS);
        classNode.accept(classWriter);
        return classWriter.toByteArray();
    }

    /**
     * Rewrites every method of {@code classNode} so that Java deserialization goes through
     * {@code ClassFilteringObjectInputStream}.
     *
     * @param str the class name, embedded as a constant and passed to
     *     {@link #getPatchModuleForClass(String)} at runtime
     * @param classNode the class to modify in place
     * @param z when {@code true}, a {@code ClassLoader} argument is also passed to the
     *     filtering stream's constructor
     * @throws RuntimeException if a {@code SerializationUtils.deserialize} overload with an
     *     unsupported descriptor is found
     */
    private static void applyPatches(String str, ClassNode classNode, boolean z) {
        SerializationIsBad.logger.info("Applying patches to " + str);
        for (MethodNode methodNode : classNode.methods) {
            InsnList instructions = methodNode.instructions;
            // Iterate over a snapshot so the instructions injected below are not re-scanned.
            for (AbstractInsnNode insn : instructions.toArray()) {
                if (insn.getOpcode() == Opcodes.NEW && insn instanceof TypeInsnNode
                        && OBJECT_INPUT_STREAM.equals(((TypeInsnNode) insn).desc)) {
                    ((TypeInsnNode) insn).desc = FILTERING_STREAM;
                    SerializationIsBad.logger.info("  (1/2) Redirecting ObjectInputStream to ClassFilteringObjectInputStream in method " + methodNode.name);
                } else if (isMethodCall(insn, Opcodes.INVOKESPECIAL, OBJECT_INPUT_STREAM, "<init>")) {
                    // NOTE(review): this also matches super(...) calls in classes that extend
                    // ObjectInputStream, not only calls paired with the NEW rewritten above.
                    redirectConstructor(str, classNode, methodNode, (MethodInsnNode) insn, z);
                } else if (isMethodCall(insn, Opcodes.INVOKESTATIC, SERIALIZATION_UTILS, "deserialize")) {
                    redirectSerializationUtils(str, methodNode, (MethodInsnNode) insn);
                }
            }
        }
    }

    /** Returns whether {@code insn} is a method call with the given opcode, owner and name. */
    private static boolean isMethodCall(AbstractInsnNode insn, int opcode, String owner, String name) {
        if (insn.getOpcode() != opcode || !(insn instanceof MethodInsnNode)) {
            return false;
        }
        MethodInsnNode call = (MethodInsnNode) insn;
        return owner.equals(call.owner) && name.equals(call.name);
    }

    /**
     * Retargets an {@code ObjectInputStream.<init>} call to the filtering stream and pushes
     * the extra constructor arguments (patch module, optionally a class loader) before it.
     */
    private static void redirectConstructor(String className, ClassNode classNode, MethodNode methodNode,
                                            MethodInsnNode call, boolean withClassLoader) {
        call.owner = FILTERING_STREAM;
        call.desc = withClassLoader ? FILTERING_CTOR_WITH_LOADER_DESC : FILTERING_CTOR_DESC;

        InsnList extraArgs = new InsnList();
        extraArgs.add(new LdcInsnNode(className));
        extraArgs.add(new MethodInsnNode(Opcodes.INVOKESTATIC, PATCHES, "getPatchModuleForClass", GET_PATCH_MODULE_DESC, false));
        if (withClassLoader) {
            if ((methodNode.access & Opcodes.ACC_STATIC) != 0) {
                // Static methods have no 'this' in local slot 0, so load the declaring class
                // as a constant. NOTE(review): class-constant LDC requires class-file version
                // 49+ (Java 5); older class files would fail verification.
                extraArgs.add(new LdcInsnNode(Type.getObjectType(classNode.name)));
            } else {
                extraArgs.add(new VarInsnNode(Opcodes.ALOAD, 0));
                extraArgs.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, "java/lang/Object", "getClass", "()Ljava/lang/Class;", false));
            }
            extraArgs.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, "java/lang/Class", "getClassLoader", "()Ljava/lang/ClassLoader;", false));
        }
        methodNode.instructions.insertBefore(call, extraArgs);
        SerializationIsBad.logger.info("  (2/2) Redirecting ObjectInputStream to ClassFilteringObjectInputStream in method " + methodNode.name);
    }

    /**
     * Retargets a {@code SerializationUtils.deserialize} call to the matching
     * {@code Patches.deserialize} overload and pushes the patch module argument before it.
     *
     * @throws RuntimeException if the call's descriptor is not one of the two supported overloads
     */
    private static void redirectSerializationUtils(String className, MethodNode methodNode, MethodInsnNode call) {
        String newDesc;
        if ("(Ljava/io/InputStream;)Ljava/lang/Object;".equals(call.desc)) {
            newDesc = "(Ljava/io/InputStream;Lio/dogboy/serializationisbad/core/config/PatchModule;)Ljava/lang/Object;";
        } else if ("([B)Ljava/lang/Object;".equals(call.desc)) {
            newDesc = "([BLio/dogboy/serializationisbad/core/config/PatchModule;)Ljava/lang/Object;";
        } else {
            throw new RuntimeException("Unknown desc for SerializationUtils.deserialize: " + call.desc);
        }
        call.owner = PATCHES;
        call.desc = newDesc;

        InsnList extraArgs = new InsnList();
        extraArgs.add(new LdcInsnNode(className));
        extraArgs.add(new MethodInsnNode(Opcodes.INVOKESTATIC, PATCHES, "getPatchModuleForClass", GET_PATCH_MODULE_DESC, false));
        methodNode.instructions.insertBefore(call, extraArgs);
        SerializationIsBad.logger.info("  Redirecting SerializationUtils.deserialize to Patches in method " + methodNode.name);
    }

    /**
     * Parses class bytes, applies the deserialization patches and returns the new bytes.
     *
     * @param bArr the original class file bytes
     * @param str the class name used to look up the patch module at runtime
     * @param z whether to pass a class loader to the filtering stream's constructor
     * @return the patched class file bytes
     */
    public static byte[] patchClass(byte[] bArr, String str, boolean z) {
        ClassNode classNode = readClassNode(bArr);
        applyPatches(str, classNode, z);
        return writeClassNode(classNode);
    }

    /**
     * Replacement for {@code SerializationUtils.deserialize(InputStream)} that reads through a
     * {@code ClassFilteringObjectInputStream}. The filtering stream is always closed, including
     * when reading fails.
     *
     * @param inputStream the serialized data
     * @param patchModule the module passed to the filtering stream
     * @return the deserialized object
     * @throws RuntimeException wrapping any exception raised while constructing, reading or
     *     closing the stream
     */
    public static <T> T deserialize(InputStream inputStream, PatchModule patchModule) {
        try (ClassFilteringObjectInputStream in = new ClassFilteringObjectInputStream(inputStream, patchModule)) {
            @SuppressWarnings("unchecked")
            T result = (T) in.readObject();
            return result;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Replacement for {@code SerializationUtils.deserialize(byte[])}; delegates to
     * {@link #deserialize(InputStream, PatchModule)}.
     *
     * @param bArr the serialized data
     * @param patchModule the module passed to the filtering stream
     * @return the deserialized object
     */
    public static <T> T deserialize(byte[] bArr, PatchModule patchModule) {
        return Patches.<T>deserialize(new ByteArrayInputStream(bArr), patchModule);
    }
}
