From bdddeba71d12c3166f72cb84f965cbd2aa911cb2 Mon Sep 17 00:00:00 2001
From: Dipam Chakraborty <dipamc77@gmail.com>
Date: Mon, 22 Jan 2024 18:04:06 +0530
Subject: [PATCH] add controlnet baseline

---
 models/colors.py           | 347 +++++++++++++++++++++++++++++++++++++
 models/controlnet_model.py | 220 ++++++++++++++++++++++
 models/palette.py          |  46 +++++
 models/user_config.py      |   4 +-
 models/utils.py            |  94 ++++++++++
 5 files changed, 710 insertions(+), 1 deletion(-)
 create mode 100644 models/colors.py
 create mode 100644 models/controlnet_model.py
 create mode 100644 models/palette.py
 create mode 100644 models/utils.py

diff --git a/models/colors.py b/models/colors.py
new file mode 100644
index 0000000..e9d46ce
--- /dev/null
+++ b/models/colors.py
@@ -0,0 +1,347 @@
+"""Color mappings"""
+from typing import List
+
+TRIVIA = {
+    "#B47878": "building;edifice",
+    "#06E6E6": "sky",
+    "#04C803": "tree",
+    "#8C8C8C": "road;route",
+    "#04FA07": "grass",
+    "#96053D": "person;individual;someone;somebody;mortal;soul",
+    "#CCFF04": "plant;flora;plant;life",
+    "#787846": "earth;ground",
+    "#FF09E0": "house",
+    "#0066C8": "car;auto;automobile;machine;motorcar",
+    "#3DE6FA": "water",
+    "#FF3D06": "railing;rail",
+    "#FF5C00": "arcade;machine",
+    "#FFE000": "stairs;steps",
+    "#00F5FF": "fan",
+    "#FF008F": "step;stair",
+    "#1F00FF": "stairway;staircase",
+    "#FFD600": "radiator",
+}
+
+OBJECTS = {
+    "#CC05FF": "bed",
+    "#FF0633": "painting;picture",
+    "#DCDCDC": "mirror",
+    "#00FF14": "box",
+    "#FF0000": "flower",
+    "#FFA300": "book",
+    "#00FFC2": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box",
+    "#F500FF": "pot;flowerpot",
+    "#00FFCC": "vase",
+    "#29FF00": "tray",
+    "#8FFF00": "poster;posting;placard;notice;bill;card",
+    "#5CFF00": "basket;handbasket",
+    "#00ADFF": "screen;door;screen",
+}
+
+
+SITTING = {
+    "#0B66FF": "sofa;couch;lounge",
+    "#CC4603": "chair",
+    "#07FFE0": "seat",
+    "#08FFD6": "armchair",
+    "#FFC207": "cushion",
+    "#00EBFF": "pillow",
+    "#00D6FF": "stool",
+    "#1400FF": "blanket;cover",
+    "#0A00FF": "swivel;chair",
+    "#FF9900": "ottoman;pouf;pouffe;puff;hassock",
+}
+
+LIGHTING = {
+    "#E0FF08": "lamp",
+    "#FFAD00": "light;light;source",
+    "#001FFF": "chandelier;pendant;pendent",
+}
+
+TABLES = {
+    "#FF0652": "table",
+    "#0AFF47": "desk",
+}
+
+CLOSETS = {
+    "#E005FF": "cabinet",
+    "#FF0747": "shelf",
+    "#07FFFF": "wardrobe;closet;press",
+    "#0633FF": "chest;of;drawers;chest;bureau;dresser",
+    "#0000FF": "case;display;case;showcase;vitrine",
+}
+
+
+BATHROOM = {
+    "#6608FF": "bathtub;bathing;tub;bath;tub",
+    "#00FF85": "toilet;can;commode;crapper;pot;potty;stool;throne",
+    "#0085FF": "shower",
+    "#FF0066": "towel",
+}
+
+WINDOWS = {
+    "#FF3307": "curtain;drape;drapery;mantle;pall",
+    "#E6E6E6": "windowpane;window",
+    "#00FF3D": "awning;sunshade;sunblind",
+    "#003DFF": "blind;screen",
+}
+
+FLOOR = {
+    "#FF095C": "rug;carpet;carpeting",
+    "#503232": "floor;flooring",
+}
+
+INTERIOR = {
+    "#787878": "wall",
+    "#787850": "ceiling",
+    "#08FF33": "door;double;door",
+}
+
+KITCHEN = {
+    "#00FF29": "kitchen;island",
+    "#14FF00": "refrigerator;icebox",
+    "#00A3FF": "sink",
+    "#EB0CFF": "counter",
+    "#D6FF00": "dishwasher;dish;washer;dishwashing;machine",
+    "#FF00EB": "microwave;microwave;oven",
+    "#47FF00": "oven",
+    "#66FF00": "clock",
+    "#00FFB8": "plate",
+    "#19C2C2": "glass;drinking;glass",
+    "#00FF99": "bar",
+    "#00FF0A": "bottle",
+    "#FF7000": "buffet;counter;sideboard",
+    "#B800FF": "washer;automatic;washer;washing;machine",
+    "#00FF70": "coffee;table;cocktail;table",
+    "#008FFF": "countertop",
+    "#33FF00": "stove;kitchen;stove;range;kitchen;range;cooking;stove",
+}
+
+LIVINGROOM = {
+    "#FA0A0F": "fireplace;hearth;open;fireplace",
+    "#FF4700": "pool;table;billiard;table;snooker;table",
+}
+
+OFFICE = {
+    "#00FFAD": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system",
+    "#00FFF5": "bookcase",
+    "#0633FF": "chest;of;drawers;chest;bureau;dresser",
+    "#005CFF": "monitor;monitoring;device",
+}
+
+
+COLOR_MAPPING_CATEGORY_ = {
+ 'keep background': {'#FFFFFF': 'background'},
+ 'trivia': TRIVIA,
+ 'objects': OBJECTS,
+ 'sitting': SITTING,
+ 'lighting': LIGHTING,
+ 'tables': TABLES,
+ 'closets': CLOSETS,
+ 'bathroom': BATHROOM,
+ 'windows': WINDOWS,
+ 'floor': FLOOR,
+ 'interior': INTERIOR,
+ 'kitchen': KITCHEN,
+ 'livingroom': LIVINGROOM,
+ 'office': OFFICE}
+
+
+COLOR_MAPPING_ = {
+    '#FFFFFF': 'background',
+    "#787878": "wall",
+    "#B47878": "building;edifice",
+    "#06E6E6": "sky",
+    "#503232": "floor;flooring",
+    "#04C803": "tree",
+    "#787850": "ceiling",
+    "#8C8C8C": "road;route",
+    "#CC05FF": "bed",
+    "#E6E6E6": "windowpane;window",
+    "#04FA07": "grass",
+    "#E005FF": "cabinet",
+    "#EBFF07": "sidewalk;pavement",
+    "#96053D": "person;individual;someone;somebody;mortal;soul",
+    "#787846": "earth;ground",
+    "#08FF33": "door;double;door",
+    "#FF0652": "table",
+    "#8FFF8C": "mountain;mount",
+    "#CCFF04": "plant;flora;plant;life",
+    "#FF3307": "curtain;drape;drapery;mantle;pall",
+    "#CC4603": "chair",
+    "#0066C8": "car;auto;automobile;machine;motorcar",
+    "#3DE6FA": "water",
+    "#FF0633": "painting;picture",
+    "#0B66FF": "sofa;couch;lounge",
+    "#FF0747": "shelf",
+    "#FF09E0": "house",
+    "#0907E6": "sea",
+    "#DCDCDC": "mirror",
+    "#FF095C": "rug;carpet;carpeting",
+    "#7009FF": "field",
+    "#08FFD6": "armchair",
+    "#07FFE0": "seat",
+    "#FFB806": "fence;fencing",
+    "#0AFF47": "desk",
+    "#FF290A": "rock;stone",
+    "#07FFFF": "wardrobe;closet;press",
+    "#E0FF08": "lamp",
+    "#6608FF": "bathtub;bathing;tub;bath;tub",
+    "#FF3D06": "railing;rail",
+    "#FFC207": "cushion",
+    "#FF7A08": "base;pedestal;stand",
+    "#00FF14": "box",
+    "#FF0829": "column;pillar",
+    "#FF0599": "signboard;sign",
+    "#0633FF": "chest;of;drawers;chest;bureau;dresser",
+    "#EB0CFF": "counter",
+    "#A09614": "sand",
+    "#00A3FF": "sink",
+    "#8C8C8C": "skyscraper",
+    "#FA0A0F": "fireplace;hearth;open;fireplace",
+    "#14FF00": "refrigerator;icebox",
+    "#1FFF00": "grandstand;covered;stand",
+    "#FF1F00": "path",
+    "#FFE000": "stairs;steps",
+    "#99FF00": "runway",
+    "#0000FF": "case;display;case;showcase;vitrine",
+    "#FF4700": "pool;table;billiard;table;snooker;table",
+    "#00EBFF": "pillow",
+    "#00ADFF": "screen;door;screen",
+    "#1F00FF": "stairway;staircase",
+    "#0BC8C8": "river",
+    "#FF5200": "bridge;span",
+    "#00FFF5": "bookcase",
+    "#003DFF": "blind;screen",
+    "#00FF70": "coffee;table;cocktail;table",
+    "#00FF85": "toilet;can;commode;crapper;pot;potty;stool;throne",
+    "#FF0000": "flower",
+    "#FFA300": "book",
+    "#FF6600": "hill",
+    "#C2FF00": "bench",
+    "#008FFF": "countertop",
+    "#33FF00": "stove;kitchen;stove;range;kitchen;range;cooking;stove",
+    "#0052FF": "palm;palm;tree",
+    "#00FF29": "kitchen;island",
+    "#00FFAD": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system",
+    "#0A00FF": "swivel;chair",
+    "#ADFF00": "boat",
+    "#00FF99": "bar",
+    "#FF5C00": "arcade;machine",
+    "#FF00FF": "hovel;hut;hutch;shack;shanty",
+    "#FF00F5": "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle",
+    "#FF0066": "towel",
+    "#FFAD00": "light;light;source",
+    "#FF0014": "truck;motortruck",
+    "#FFB8B8": "tower",
+    "#001FFF": "chandelier;pendant;pendent",
+    "#00FF3D": "awning;sunshade;sunblind",
+    "#0047FF": "streetlight;street;lamp",
+    "#FF00CC": "booth;cubicle;stall;kiosk",
+    "#00FFC2": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box",
+    "#00FF52": "airplane;aeroplane;plane",
+    "#000AFF": "dirt;track",
+    "#0070FF": "apparel;wearing;apparel;dress;clothes",
+    "#3300FF": "pole",
+    "#00C2FF": "land;ground;soil",
+    "#007AFF": "bannister;banister;balustrade;balusters;handrail",
+    "#00FFA3": "escalator;moving;staircase;moving;stairway",
+    "#FF9900": "ottoman;pouf;pouffe;puff;hassock",
+    "#00FF0A": "bottle",
+    "#FF7000": "buffet;counter;sideboard",
+    "#8FFF00": "poster;posting;placard;notice;bill;card",
+    "#5200FF": "stage",
+    "#A3FF00": "van",
+    "#FFEB00": "ship",
+    "#08B8AA": "fountain",
+    "#8500FF": "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter",
+    "#00FF5C": "canopy",
+    "#B800FF": "washer;automatic;washer;washing;machine",
+    "#FF001F": "plaything;toy",
+    "#00B8FF": "swimming;pool;swimming;bath;natatorium",
+    "#00D6FF": "stool",
+    "#FF0070": "barrel;cask",
+    "#5CFF00": "basket;handbasket",
+    "#00E0FF": "waterfall;falls",
+    "#70E0FF": "tent;collapsible;shelter",
+    "#46B8A0": "bag",
+    "#A300FF": "minibike;motorbike",
+    "#9900FF": "cradle",
+    "#47FF00": "oven",
+    "#FF00A3": "ball",
+    "#FFCC00": "food;solid;food",
+    "#FF008F": "step;stair",
+    "#00FFEB": "tank;storage;tank",
+    "#85FF00": "trade;name;brand;name;brand;marque",
+    "#FF00EB": "microwave;microwave;oven",
+    "#F500FF": "pot;flowerpot",
+    "#FF007A": "animal;animate;being;beast;brute;creature;fauna",
+    "#FFF500": "bicycle;bike;wheel;cycle",
+    "#0ABED4": "lake",
+    "#D6FF00": "dishwasher;dish;washer;dishwashing;machine",
+    "#00CCFF": "screen;silver;screen;projection;screen",
+    "#1400FF": "blanket;cover",
+    "#FFFF00": "sculpture",
+    "#0099FF": "hood;exhaust;hood",
+    "#0029FF": "sconce",
+    "#00FFCC": "vase",
+    "#2900FF": "traffic;light;traffic;signal;stoplight",
+    "#29FF00": "tray",
+    "#AD00FF": "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin",
+    "#00F5FF": "fan",
+    "#4700FF": "pier;wharf;wharfage;dock",
+    "#7A00FF": "crt;screen",
+    "#00FFB8": "plate",
+    "#005CFF": "monitor;monitoring;device",
+    "#B8FF00": "bulletin;board;notice;board",
+    "#0085FF": "shower",
+    "#FFD600": "radiator",
+    "#19C2C2": "glass;drinking;glass",
+    "#66FF00": "clock",
+    "#5C00FF": "flag",
+}
+
+
+def ade_palette() -> List[List[int]]:
+    """ADE20K palette that maps each class to RGB values."""
+    return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+            [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+            [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+            [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+            [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+            [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+            [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+            [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+            [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+            [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+            [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+            [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+            [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+            [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+            [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+            [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+            [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+            [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+            [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+            [102, 255, 0], [92, 0, 255]]
diff --git a/models/controlnet_model.py b/models/controlnet_model.py
new file mode 100644
index 0000000..33fba16
--- /dev/null
+++ b/models/controlnet_model.py
@@ -0,0 +1,220 @@
+from typing import Tuple, Union, List
+
+import numpy as np
+from PIL import Image
+
+import torch
+from diffusers import ControlNetModel, UniPCMultistepScheduler
+from diffusers.pipelines.controlnet import StableDiffusionControlNetInpaintPipeline
+from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
+
+from models.colors import ade_palette
+from models.utils import map_colors_rgb
+
+def filter_items(
+    colors_list: Union[List, np.ndarray],
+    items_list: Union[List, np.ndarray],
+    items_to_remove: Union[List, np.ndarray]
+) -> Tuple[Union[List, np.ndarray], Union[List, np.ndarray]]:
+    """
+    Filters items and their corresponding colors from given lists, excluding
+    specified items.
+
+    Args:
+        colors_list: A list or numpy array of colors corresponding to items.
+        items_list: A list or numpy array of items.
+        items_to_remove: A list or numpy array of items to be removed.
+
+    Returns:
+        A tuple of two lists or numpy arrays: filtered colors and filtered
+        items.
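+
+    Example (illustrative values):
+        >>> filter_items(['#FF0000', '#04FA07'], ['flower', 'grass'], ['grass'])
+        (['#FF0000'], ['flower'])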
+    """
+    filtered_colors = []
+    filtered_items = []
+    for color, item in zip(colors_list, items_list):
+        if item not in items_to_remove:
+            filtered_colors.append(color)
+            filtered_items.append(item)
+    return filtered_colors, filtered_items
+
+def get_segmentation_pipeline(
+) -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
+    """Method to load the segmentation pipeline
+    Returns:
+        Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: segmentation pipeline
+    """
+    image_processor = AutoImageProcessor.from_pretrained(
+        "openmmlab/upernet-convnext-small"
+    )
+    image_segmentor = UperNetForSemanticSegmentation.from_pretrained(
+        "openmmlab/upernet-convnext-small"
+    )
+    return image_processor, image_segmentor
+
+
+@torch.inference_mode()
+@torch.autocast('cuda')
+def segment_image(
+        image: Image.Image,
+        image_processor: AutoImageProcessor,
+        image_segmentor: UperNetForSemanticSegmentation
+) -> Image.Image:
+    """
+    Segments an image using a semantic segmentation model.
+
+    Args:
+        image (Image): The input image to be segmented.
+        image_processor (AutoImageProcessor): The processor to prepare the
+            image for segmentation.
+        image_segmentor (UperNetForSemanticSegmentation): The semantic
+            segmentation model used to identify different segments in the image.
+
+    Returns:
+        Image: The segmented image with each segment colored differently based
+            on its identified class.
+    """
+    pixel_values = image_processor(image, return_tensors="pt").pixel_values
+    with torch.no_grad():
+        outputs = image_segmentor(pixel_values)
+
+    seg = image_processor.post_process_semantic_segmentation(
+        outputs, target_sizes=[image.size[::-1]])[0]
+    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
+    palette = np.array(ade_palette())
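+    # paint every pixel of class `label` with that label's palette color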
+    for label, color in enumerate(palette):
+        color_seg[seg == label, :] = color
+    color_seg = color_seg.astype(np.uint8)
+    seg_image = Image.fromarray(color_seg).convert('RGB')
+    return seg_image
+
+
+def resize_dimensions(dimensions, target_size):
+    """ 
+    Resize PIL to target size while maintaining aspect ratio 
+    If smaller than target size leave it as is
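+
+    Example (illustrative):
+        >>> resize_dimensions((1024, 768), 768)
+        (768, 576)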
+    """
+    width, height = dimensions
+
+    # Check if both dimensions are smaller than the target size
+    if width < target_size and height < target_size:
+        return dimensions
+
+    # Determine the larger side
+    if width > height:
+        # Calculate the aspect ratio
+        aspect_ratio = height / width
+        # Resize dimensions
+        return (target_size, int(target_size * aspect_ratio))
+    else:
+        # Calculate the aspect ratio
+        aspect_ratio = width / height
+        # Resize dimensions
+        return (int(target_size * aspect_ratio), target_size)
+
+
+class ControlNetDesignModel:
+    """Segmentation-guided ControlNet inpainting baseline."""
+    def __init__(self):
+        """ Initialize your model(s) here """
+        controlnet_seg = ControlNetModel.from_pretrained(
+            "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)
+
+        self.pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+            "runwayml/stable-diffusion-inpainting",
+            controlnet=controlnet_seg,
+            safety_checker=None,
+            torch_dtype=torch.float16
+        )
+
+        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+        self.pipe.enable_xformers_memory_efficient_attention()
+        self.pipe = self.pipe.to("cuda")
+
+        self.seed = 323
+        self.neg_prompt = "lowres, watermark, banner, logo, contactinfo, text, deformed, blurry, blur, out of focus, out of frame, surreal, ugly"
+        self.control_items = ["windowpane;window"]
+        self.additional_quality_suffix = "interior design, 4K, high resolution"
+
+        self.seg_image_processor, self.image_segmentor = get_segmentation_pipeline()
+
+    def generate_design(self, empty_room_image: Image.Image, prompt: str) -> Image.Image:
+        """
+        Given an image of an empty room and a prompt, generate the designed
+        room according to the prompt.
+        Args:
+            empty_room_image: An RGB PIL Image of the empty room
+            prompt: Text describing the target design elements of the room
+        Returns:
+            design_image: PIL Image of the same size as the empty room image.
+                If the size is not the same, the submission will fail.
+        """
+        print(prompt)
+
+        pos_prompt = prompt + f', {self.additional_quality_suffix}'
+
+        orig_w, orig_h = empty_room_image.size
+        new_width, new_height = resize_dimensions(empty_room_image.size, 768)
+        input_image = empty_room_image.resize((new_width, new_height))
+        print((orig_w, orig_h), (new_width, new_height))
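+        # Segment the resized room and collect the unique class colors present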
+        real_seg = np.array(segment_image(input_image,
+                                          self.seg_image_processor,
+                                          self.image_segmentor))
+        unique_colors = np.unique(real_seg.reshape(-1, real_seg.shape[2]), axis=0)
+        unique_colors = [tuple(color) for color in unique_colors]
+        segment_items = [map_colors_rgb(i) for i in unique_colors]
+        chosen_colors, segment_items = filter_items(
+            colors_list=unique_colors,
+            items_list=segment_items,
+            items_to_remove=self.control_items
+        )
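+        # Inpaint everything except the control items (windows), which the
+        # mask leaves at 0 so their layout is preserved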
+        mask = np.zeros_like(real_seg)
+        for color in chosen_colors:
+            color_matches = (real_seg == color).all(axis=2)
+            mask[color_matches] = 1
+
+        image = input_image.convert("RGB")
+        segmentation_cond_image = Image.fromarray(real_seg).convert("RGB")
+        mask_image = Image.fromarray((mask * 255).astype(np.uint8)).convert("RGB")
+
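+        # ControlNet-conditioned inpainting: the segmentation map steers the
+        # layout while the prompt drives the style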
+        generated_image = self.pipe(
+            prompt=pos_prompt,
+            negative_prompt=self.neg_prompt,
+            num_inference_steps=50,
+            strength=1,
+            guidance_scale=7,
+            generator=[torch.Generator(device="cuda").manual_seed(self.seed)],
+            image=image,
+            mask_image=mask_image,
+            control_image=segmentation_cond_image,
+        ).images[0]
+
+        design_image = generated_image.resize(
+            (orig_w, orig_h), Image.Resampling.LANCZOS
+        )
+
+        return design_image
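+
+
+if __name__ == "__main__":
+    # Minimal smoke test; assumes a CUDA GPU and a local `room.png`
+    # (both the image path and the prompt here are illustrative).
+    model = ControlNetDesignModel()
+    room = Image.open("room.png").convert("RGB")
+    design = model.generate_design(room, "a cozy scandinavian living room")
+    design.save("design.png")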
diff --git a/models/palette.py b/models/palette.py
new file mode 100644
index 0000000..c1efd32
--- /dev/null
+++ b/models/palette.py
@@ -0,0 +1,46 @@
+from typing import Dict
+from models.colors import COLOR_MAPPING_, COLOR_MAPPING_CATEGORY_
+
+
+def convert_hex_to_rgba(hex_code: str) -> str:
+    """Convert hex code to rgba.
+    Args:
+        hex_code (str): hex string
+    Returns:
+        str: rgba string
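+    Example (illustrative):
+        >>> convert_hex_to_rgba('#FF0000')
+        'rgba(255, 0, 0, 1.0)'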
+    """
+    hex_code = hex_code.lstrip('#')
+    return "rgba(" + str(int(hex_code[0:2], 16)) + ", " + str(int(hex_code[2:4], 16)) + ", " + str(int(hex_code[4:6], 16)) + ", 1.0)"
+
+
+def convert_dict_to_rgba(color_dict: Dict) -> Dict:
+    """Convert hex code to rgba for all elements in a dictionary.
+    Args:
+        color_dict (Dict): color dictionary
+    Returns:
+        Dict: color dictionary with rgba values
+    """
+    updated_dict = {}
+    for k, v in color_dict.items():
+        updated_dict[convert_hex_to_rgba(k)] = v
+    return updated_dict
+
+
+def convert_nested_dict_to_rgba(nested_dict: Dict) -> Dict:
+    """Convert hex keys to rgba in every category's color dictionary.
+    Args:
+        nested_dict (Dict): mapping of category name to color dictionary
+    Returns:
+        Dict: the same mapping with rgba keys in each sub-dictionary
+    """
+    updated_dict = {}
+    for k, v in nested_dict.items():
+        updated_dict[k] = convert_dict_to_rgba(v)
+    return updated_dict
+
+
+COLOR_MAPPING = convert_dict_to_rgba(COLOR_MAPPING_)
+COLOR_MAPPING_CATEGORY = convert_nested_dict_to_rgba(COLOR_MAPPING_CATEGORY_)
diff --git a/models/user_config.py b/models/user_config.py
index bf5de71..176ce46 100644
--- a/models/user_config.py
+++ b/models/user_config.py
@@ -1,3 +1,5 @@
 from models.random_model import RandomModel
+from models.controlnet_model import ControlNetDesignModel
 
-UserModel = RandomModel
\ No newline at end of file
+UserModel = RandomModel
+# UserModel = ControlNetDesignModel
diff --git a/models/utils.py b/models/utils.py
new file mode 100644
index 0000000..f6725e2
--- /dev/null
+++ b/models/utils.py
@@ -0,0 +1,94 @@
+import gc
+
+import numpy as np
+from PIL import Image
+import torch
+from scipy.signal import fftconvolve
+
+from models.palette import COLOR_MAPPING, COLOR_MAPPING_
+
+
+def to_rgb(color: str) -> tuple:
+    """Convert hex color to rgb.
+    Args:
+        color (str): hex color
+    Returns:
+        tuple: rgb color
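+    Example (illustrative):
+        >>> to_rgb('#787878')
+        (120, 120, 120)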
+    """
+    return tuple(int(color[i:i+2], 16) for i in (1, 3, 5))
+
+
+def map_colors(color: str) -> str:
+    """Map color to hex value.
+    Args:
+        color (str): color name
+    Returns:
+        str: hex value
+    """
+    return COLOR_MAPPING[color]
+
+
+def map_colors_rgb(color: tuple) -> str:
+    """Map an (R, G, B) tuple to its item name.
+    Args:
+        color (tuple): (R, G, B) color
+    Returns:
+        str: item name
+    """
+    return COLOR_MAPPING_RGB[color]
+
+
+def convolution(mask: Image.Image, size: int = 9) -> Image.Image:
+    """Blur a mask with a normalized box filter.
+    Args:
+        mask (Image): masking image
+        size (int, optional): size of the blur. Defaults to 9.
+    Returns:
+        Image: blurred mask
+    """
+    mask = np.array(mask.convert("L"))
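+    # normalized box kernel: every pixel in the window contributes equally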
+    conv = np.ones((size, size)) / size**2
+    mask_blended = fftconvolve(mask, conv, 'same')
+    mask_blended = mask_blended.astype(np.uint8).copy()
+
+    border = size
+
+    # replace borders with original values
+    mask_blended[:border, :] = mask[:border, :]
+    mask_blended[-border:, :] = mask[-border:, :]
+    mask_blended[:, :border] = mask[:, :border]
+    mask_blended[:, -border:] = mask[:, -border:]
+
+    return Image.fromarray(mask_blended).convert("L")
+
+
+def flush():
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+def postprocess_image_masking(inpainted: Image.Image, image: Image.Image,
+                              mask: Image.Image) -> Image.Image:
+    """Composite the inpainted image over the original through the mask.
+    Args:
+        inpainted (Image): inpainted image
+        image (Image): original image
+        mask (Image): mask selecting the inpainted regions
+    Returns:
+        Image: composited RGB image
+    """
+    final_inpainted = Image.composite(inpainted.convert("RGBA"),
+                                      image.convert("RGBA"), mask)
+    return final_inpainted.convert("RGB")
+
+
+COLOR_NAMES = list(COLOR_MAPPING.keys())
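+# all mapped RGB colors, plus pure black and white as extra entries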
+COLOR_RGB = [to_rgb(k) for k in COLOR_MAPPING_.keys()] + [(0, 0, 0),
+                                                          (255, 255, 255)]
+INVERSE_COLORS = {v: to_rgb(k) for k, v in COLOR_MAPPING_.items()}
+COLOR_MAPPING_RGB = {to_rgb(k): v for k, v in COLOR_MAPPING_.items()}
-- 
GitLab