From bdddeba71d12c3166f72cb84f965cbd2aa911cb2 Mon Sep 17 00:00:00 2001
From: Dipam Chakraborty <dipamc77@gmail.com>
Date: Mon, 22 Jan 2024 18:04:06 +0530
Subject: [PATCH] add controlnet baseline

---
 models/colors.py           | 344 +++++++++++++++++++++++++++++++++++++
 models/controlnet_model.py | 199 +++++++++++++++++++++
 models/palette.py          |  37 ++++
 models/user_config.py      |   4 +-
 models/utils.py            |  83 +++++++++
 5 files changed, 666 insertions(+), 1 deletion(-)
 create mode 100644 models/colors.py
 create mode 100644 models/controlnet_model.py
 create mode 100644 models/palette.py
 create mode 100644 models/utils.py

diff --git a/models/colors.py b/models/colors.py
new file mode 100644
index 0000000..e9d46ce
--- /dev/null
+++ b/models/colors.py
@@ -0,0 +1,344 @@
+"""Color mappings"""
+from typing import List
+
+TRIVIA = {
+    "#B47878": "building;edifice",
+    "#06E6E6": "sky",
+    "#04C803": "tree",
+    "#8C8C8C": "road;route",
+    "#04FA07": "grass",
+    "#96053D": "person;individual;someone;somebody;mortal;soul",
+    "#CCFF04": "plant;flora;plant;life",
+    "#787846": "earth;ground",
+    "#FF09E0": "house",
+    "#0066C8": "car;auto;automobile;machine;motorcar",
+    "#3DE6FA": "water",
+    "#FF3D06": "railing;rail",
+    "#FF5C00": "arcade;machine",
+    "#FFE000": "stairs;steps",
+    "#00F5FF": "fan",
+    "#FF008F": "step;stair",
+    "#1F00FF": "stairway;staircase",
+    "#FFD600": "radiator",
+}
+
+OBJECTS = {
+    "#CC05FF": "bed",
+    "#FF0633": "painting;picture",
+    "#DCDCDC": "mirror",
+    "#00FF14": "box",
+    "#FF0000": "flower",
+    "#FFA300": "book",
+    "#00FFC2": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box",
+    "#F500FF": "pot;flowerpot",
+    "#00FFCC": "vase",
+    "#29FF00": "tray",
+    "#8FFF00": "poster;posting;placard;notice;bill;card",
+    "#5CFF00": "basket;handbasket",
+    "#00ADFF": "screen;door;screen",
+}
+
+
+SITTING = {
+    "#0B66FF": "sofa;couch;lounge",
+    "#CC4603": "chair",
+    "#07FFE0": "seat",
+    "#08FFD6": "armchair",
+    "#FFC207": "cushion",
+    "#00EBFF": "pillow",
+    "#00D6FF": "stool",
+    "#1400FF": "blanket;cover",
+    "#0A00FF": "swivel;chair",
+    "#FF9900": "ottoman;pouf;pouffe;puff;hassock",
+}
+
+LIGHTING = {
+    "#E0FF08": "lamp",
+    "#FFAD00": "light;light;source",
+    "#001FFF": "chandelier;pendant;pendent",
+}
+
+TABLES = {
+    "#FF0652": "table",
+    "#0AFF47": "desk",
+}
+
+CLOSETS = {
+    "#E005FF": "cabinet",
+    "#FF0747": "shelf",
+    "#07FFFF": "wardrobe;closet;press",
+    "#0633FF": "chest;of;drawers;chest;bureau;dresser",
+    "#0000FF": "case;display;case;showcase;vitrine",
+}
+
+
+BATHROOM = {
+    "#6608FF": "bathtub;bathing;tub;bath;tub",
+    "#00FF85": "toilet;can;commode;crapper;pot;potty;stool;throne",
+    "#0085FF": "shower",
+    "#FF0066": "towel",
+}
+
+WINDOWS = {
+    "#FF3307": "curtain;drape;drapery;mantle;pall",
+    "#E6E6E6": "windowpane;window",
+    "#00FF3D": "awning;sunshade;sunblind",
+    "#003DFF": "blind;screen",
+}
+
+FLOOR = {
+    "#FF095C": "rug;carpet;carpeting",
+    "#503232": "floor;flooring",
+}
+
+INTERIOR = {
+    "#787878": "wall",
+    "#787850": "ceiling",
+    "#08FF33": "door;double;door",
+}
+
+KITCHEN = {
+    "#00FF29": "kitchen;island",
+    "#14FF00": "refrigerator;icebox",
+    "#00A3FF": "sink",
+    "#EB0CFF": "counter",
+    "#D6FF00": "dishwasher;dish;washer;dishwashing;machine",
+    "#FF00EB": "microwave;microwave;oven",
+    "#47FF00": "oven",
+    "#66FF00": "clock",
+    "#00FFB8": "plate",
+    "#19C2C2": "glass;drinking;glass",
+    "#00FF99": "bar",
+    "#00FF0A": "bottle",
+    "#FF7000": "buffet;counter;sideboard",
+    "#B800FF": "washer;automatic;washer;washing;machine",
+    "#00FF70": "coffee;table;cocktail;table",
"coffee;table;cocktail;table", + "#008FFF": "countertop", + "#33FF00": "stove;kitchen;stove;range;kitchen;range;cooking;stove", +} + +LIVINGROOM = { + "#FA0A0F": "fireplace;hearth;open;fireplace", + "#FF4700": "pool;table;billiard;table;snooker;table", +} + +OFFICE = { + "#00FFAD": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system", + "#00FFF5": "bookcase", + "#0633FF": "chest;of;drawers;chest;bureau;dresser", + "#005CFF": "monitor;monitoring;device", +} + + +COLOR_MAPPING_CATEGORY_ = { + 'keep background': {'#FFFFFF': 'background'}, + 'trivia': TRIVIA, + 'objects': OBJECTS, + 'sitting': SITTING, + 'lighting': LIGHTING, + 'tables': TABLES, + 'closets': CLOSETS, + 'bathroom': BATHROOM, + 'windows': WINDOWS, + 'floor': FLOOR, + 'interior': INTERIOR, + 'kitchen': KITCHEN, + 'livingroom': LIVINGROOM, + 'office': OFFICE} + + +COLOR_MAPPING_ = { + '#FFFFFF': 'background', + "#787878": "wall", + "#B47878": "building;edifice", + "#06E6E6": "sky", + "#503232": "floor;flooring", + "#04C803": "tree", + "#787850": "ceiling", + "#8C8C8C": "road;route", + "#CC05FF": "bed", + "#E6E6E6": "windowpane;window", + "#04FA07": "grass", + "#E005FF": "cabinet", + "#EBFF07": "sidewalk;pavement", + "#96053D": "person;individual;someone;somebody;mortal;soul", + "#787846": "earth;ground", + "#08FF33": "door;double;door", + "#FF0652": "table", + "#8FFF8C": "mountain;mount", + "#CCFF04": "plant;flora;plant;life", + "#FF3307": "curtain;drape;drapery;mantle;pall", + "#CC4603": "chair", + "#0066C8": "car;auto;automobile;machine;motorcar", + "#3DE6FA": "water", + "#FF0633": "painting;picture", + "#0B66FF": "sofa;couch;lounge", + "#FF0747": "shelf", + "#FF09E0": "house", + "#0907E6": "sea", + "#DCDCDC": "mirror", + "#FF095C": "rug;carpet;carpeting", + "#7009FF": "field", + "#08FFD6": "armchair", + "#07FFE0": "seat", + "#FFB806": "fence;fencing", + "#0AFF47": "desk", + "#FF290A": "rock;stone", + "#07FFFF": "wardrobe;closet;press", + "#E0FF08": "lamp", + "#6608FF": "bathtub;bathing;tub;bath;tub", + "#FF3D06": "railing;rail", + "#FFC207": "cushion", + "#FF7A08": "base;pedestal;stand", + "#00FF14": "box", + "#FF0829": "column;pillar", + "#FF0599": "signboard;sign", + "#0633FF": "chest;of;drawers;chest;bureau;dresser", + "#EB0CFF": "counter", + "#A09614": "sand", + "#00A3FF": "sink", + "#8C8C8C": "skyscraper", + "#FA0A0F": "fireplace;hearth;open;fireplace", + "#14FF00": "refrigerator;icebox", + "#1FFF00": "grandstand;covered;stand", + "#FF1F00": "path", + "#FFE000": "stairs;steps", + "#99FF00": "runway", + "#0000FF": "case;display;case;showcase;vitrine", + "#FF4700": "pool;table;billiard;table;snooker;table", + "#00EBFF": "pillow", + "#00ADFF": "screen;door;screen", + "#1F00FF": "stairway;staircase", + "#0BC8C8": "river", + "#FF5200": "bridge;span", + "#00FFF5": "bookcase", + "#003DFF": "blind;screen", + "#00FF70": "coffee;table;cocktail;table", + "#00FF85": "toilet;can;commode;crapper;pot;potty;stool;throne", + "#FF0000": "flower", + "#FFA300": "book", + "#FF6600": "hill", + "#C2FF00": "bench", + "#008FFF": "countertop", + "#33FF00": "stove;kitchen;stove;range;kitchen;range;cooking;stove", + "#0052FF": "palm;palm;tree", + "#00FF29": "kitchen;island", + "#00FFAD": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system", + "#0A00FF": "swivel;chair", + "#ADFF00": "boat", + "#00FF99": "bar", + "#FF5C00": "arcade;machine", + "#FF00FF": "hovel;hut;hutch;shack;shanty", + "#FF00F5": 
"bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle", + "#FF0066": "towel", + "#FFAD00": "light;light;source", + "#FF0014": "truck;motortruck", + "#FFB8B8": "tower", + "#001FFF": "chandelier;pendant;pendent", + "#00FF3D": "awning;sunshade;sunblind", + "#0047FF": "streetlight;street;lamp", + "#FF00CC": "booth;cubicle;stall;kiosk", + "#00FFC2": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box", + "#00FF52": "airplane;aeroplane;plane", + "#000AFF": "dirt;track", + "#0070FF": "apparel;wearing;apparel;dress;clothes", + "#3300FF": "pole", + "#00C2FF": "land;ground;soil", + "#007AFF": "bannister;banister;balustrade;balusters;handrail", + "#00FFA3": "escalator;moving;staircase;moving;stairway", + "#FF9900": "ottoman;pouf;pouffe;puff;hassock", + "#00FF0A": "bottle", + "#FF7000": "buffet;counter;sideboard", + "#8FFF00": "poster;posting;placard;notice;bill;card", + "#5200FF": "stage", + "#A3FF00": "van", + "#FFEB00": "ship", + "#08B8AA": "fountain", + "#8500FF": "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter", + "#00FF5C": "canopy", + "#B800FF": "washer;automatic;washer;washing;machine", + "#FF001F": "plaything;toy", + "#00B8FF": "swimming;pool;swimming;bath;natatorium", + "#00D6FF": "stool", + "#FF0070": "barrel;cask", + "#5CFF00": "basket;handbasket", + "#00E0FF": "waterfall;falls", + "#70E0FF": "tent;collapsible;shelter", + "#46B8A0": "bag", + "#A300FF": "minibike;motorbike", + "#9900FF": "cradle", + "#47FF00": "oven", + "#FF00A3": "ball", + "#FFCC00": "food;solid;food", + "#FF008F": "step;stair", + "#00FFEB": "tank;storage;tank", + "#85FF00": "trade;name;brand;name;brand;marque", + "#FF00EB": "microwave;microwave;oven", + "#F500FF": "pot;flowerpot", + "#FF007A": "animal;animate;being;beast;brute;creature;fauna", + "#FFF500": "bicycle;bike;wheel;cycle", + "#0ABED4": "lake", + "#D6FF00": "dishwasher;dish;washer;dishwashing;machine", + "#00CCFF": "screen;silver;screen;projection;screen", + "#1400FF": "blanket;cover", + "#FFFF00": "sculpture", + "#0099FF": "hood;exhaust;hood", + "#0029FF": "sconce", + "#00FFCC": "vase", + "#2900FF": "traffic;light;traffic;signal;stoplight", + "#29FF00": "tray", + "#AD00FF": "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin", + "#00F5FF": "fan", + "#4700FF": "pier;wharf;wharfage;dock", + "#7A00FF": "crt;screen", + "#00FFB8": "plate", + "#005CFF": "monitor;monitoring;device", + "#B8FF00": "bulletin;board;notice;board", + "#0085FF": "shower", + "#FFD600": "radiator", + "#19C2C2": "glass;drinking;glass", + "#66FF00": "clock", + "#5C00FF": "flag", +} + + +def ade_palette() -> List[List[int]]: + """ADE20K palette that maps each class to RGB values.""" + return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 
+            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+            [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+            [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+            [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+            [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+            [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+            [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+            [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+            [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+            [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+            [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+            [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+            [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+            [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+            [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+            [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+            [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+            [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+            [102, 255, 0], [92, 0, 255]]
diff --git a/models/controlnet_model.py b/models/controlnet_model.py
new file mode 100644
index 0000000..33fba16
--- /dev/null
+++ b/models/controlnet_model.py
@@ -0,0 +1,199 @@
+from typing import Tuple, Union, List
+
+import numpy as np
+from PIL import Image
+
+import torch
+from diffusers import ControlNetModel, UniPCMultistepScheduler
+from diffusers.pipelines.controlnet import MultiControlNetModel, StableDiffusionControlNetInpaintPipeline
+from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
+
+from models.colors import ade_palette
+from models.utils import map_colors_rgb
+
+
+def filter_items(
+    colors_list: Union[List, np.ndarray],
+    items_list: Union[List, np.ndarray],
+    items_to_remove: Union[List, np.ndarray]
+) -> Tuple[Union[List, np.ndarray], Union[List, np.ndarray]]:
+    """
+    Filters items and their corresponding colors from given lists, excluding
+    specified items.
+
+    Args:
+        colors_list: A list or numpy array of colors corresponding to items.
+        items_list: A list or numpy array of items.
+        items_to_remove: A list or numpy array of items to be removed.
+
+    Returns:
+        A tuple of two lists or numpy arrays: filtered colors and filtered
+        items.
+    """
+    filtered_colors = []
+    filtered_items = []
+    for color, item in zip(colors_list, items_list):
+        if item not in items_to_remove:
+            filtered_colors.append(color)
+            filtered_items.append(item)
+    return filtered_colors, filtered_items
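+
+# Illustrative call (values made up): drop windows so they keep their pixels.
+#
+#     colors, items = filter_items(
+#         colors_list=[(230, 230, 230), (120, 120, 120)],
+#         items_list=["windowpane;window", "wall"],
+#         items_to_remove=["windowpane;window"],
+#     )
+#     # colors == [(120, 120, 120)], items == ["wall"]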
+ """ + filtered_colors = [] + filtered_items = [] + for color, item in zip(colors_list, items_list): + if item not in items_to_remove: + filtered_colors.append(color) + filtered_items.append(item) + return filtered_colors, filtered_items + +def get_segmentation_pipeline( +) -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: + """Method to load the segmentation pipeline + Returns: + Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: segmentation pipeline + """ + image_processor = AutoImageProcessor.from_pretrained( + "openmmlab/upernet-convnext-small" + ) + image_segmentor = UperNetForSemanticSegmentation.from_pretrained( + "openmmlab/upernet-convnext-small" + ) + return image_processor, image_segmentor + + +@torch.inference_mode() +@torch.autocast('cuda') +def segment_image( + image: Image, + image_processor: AutoImageProcessor, + image_segmentor: UperNetForSemanticSegmentation +) -> Image: + """ + Segments an image using a semantic segmentation model. + + Args: + image (Image): The input image to be segmented. + image_processor (AutoImageProcessor): The processor to prepare the + image for segmentation. + image_segmentor (UperNetForSemanticSegmentation): The semantic + segmentation model used to identify different segments in the image. + + Returns: + Image: The segmented image with each segment colored differently based + on its identified class. + """ + # image_processor, image_segmentor = get_segmentation_pipeline() + pixel_values = image_processor(image, return_tensors="pt").pixel_values + with torch.no_grad(): + outputs = image_segmentor(pixel_values) + + seg = image_processor.post_process_semantic_segmentation( + outputs, target_sizes=[image.size[::-1]])[0] + color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) + palette = np.array(ade_palette()) + for label, color in enumerate(palette): + color_seg[seg == label, :] = color + color_seg = color_seg.astype(np.uint8) + seg_image = Image.fromarray(color_seg).convert('RGB') + return seg_image + + +def resize_dimensions(dimensions, target_size): + """ + Resize PIL to target size while maintaining aspect ratio + If smaller than target size leave it as is + """ + width, height = dimensions + + # Check if both dimensions are smaller than the target size + if width < target_size and height < target_size: + return dimensions + + # Determine the larger side + if width > height: + # Calculate the aspect ratio + aspect_ratio = height / width + # Resize dimensions + return (target_size, int(target_size * aspect_ratio)) + else: + # Calculate the aspect ratio + aspect_ratio = width / height + # Resize dimensions + return (int(target_size * aspect_ratio), target_size) + + + +class ControlNetDesignModel: + """ Produces random noise images """ + def __init__(self): + """ Initialize your model(s) here """ + controlnet_seg = ControlNetModel.from_pretrained( + "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16) + + self.pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", + controlnet=controlnet_seg, + safety_checker=None, + torch_dtype=torch.float16 + ) + + self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config) + self.pipe.enable_xformers_memory_efficient_attention() + self.pipe = self.pipe.to("cuda") + + self.seed = 323 + self.neg_prompt = "lowres, watermark, banner, logo, watermark, contactinfo, text, deformed, blurry, blur, out of focus, out of frame, surreal, ugly" + self.control_items = ["windowpane;window"] + 
+
+
+class ControlNetDesignModel:
+    """Segmentation-conditioned ControlNet baseline: inpaints a design over
+    the empty room while the control items keep their original pixels."""
+    def __init__(self):
+        """ Initialize your model(s) here """
+        controlnet_seg = ControlNetModel.from_pretrained(
+            "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)
+
+        self.pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+            "runwayml/stable-diffusion-inpainting",
+            controlnet=controlnet_seg,
+            safety_checker=None,
+            torch_dtype=torch.float16
+        )
+
+        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+        self.pipe.enable_xformers_memory_efficient_attention()
+        self.pipe = self.pipe.to("cuda")
+
+        self.seed = 323
+        self.neg_prompt = "lowres, watermark, banner, logo, watermark, contactinfo, text, deformed, blurry, blur, out of focus, out of frame, surreal, ugly"
+        self.control_items = ["windowpane;window"]
+        self.additional_quality_suffix = "interior design, 4K, high resolution"
+
+        self.seg_image_processor, self.image_segmentor = get_segmentation_pipeline()
+
+    def generate_design(self, empty_room_image: Image, prompt: str) -> Image:
+        """
+        Given an image of an empty room and a prompt, generate the designed
+        room according to the prompt.
+        Inputs -
+            empty_room_image - An RGB PIL Image of the empty room
+            prompt - Text describing the target design elements of the room
+        Returns -
+            design_image - PIL Image of the same size as the empty room image;
+            if the size is not the same the submission will fail.
+        """
+        print(prompt)
+
+        pos_prompt = prompt + f', {self.additional_quality_suffix}'
+
+        orig_w, orig_h = empty_room_image.size
+        new_width, new_height = resize_dimensions(empty_room_image.size, 768)
+        input_image = empty_room_image.resize((new_width, new_height))
+        print((orig_w, orig_h), (new_width, new_height))
+        real_seg = np.array(segment_image(input_image,
+                                          self.seg_image_processor,
+                                          self.image_segmentor))
+        unique_colors = np.unique(real_seg.reshape(-1, real_seg.shape[2]), axis=0)
+        unique_colors = [tuple(color) for color in unique_colors]
+        segment_items = [map_colors_rgb(i) for i in unique_colors]
+        chosen_colors, segment_items = filter_items(
+            colors_list=unique_colors,
+            items_list=segment_items,
+            items_to_remove=self.control_items
+        )
+        # Inpaint everything except the control items (e.g. windows), which
+        # are left out of the mask and keep their original pixels
+        mask = np.zeros_like(real_seg)
+        for color in chosen_colors:
+            color_matches = (real_seg == color).all(axis=2)
+            mask[color_matches] = 1
+
+        image_np = np.array(input_image)
+        image = Image.fromarray(image_np).convert("RGB")
+        segmentation_cond_image = Image.fromarray(real_seg).convert("RGB")
+        mask_image = Image.fromarray((mask * 255).astype(np.uint8)).convert("RGB")
+
+        generated_image = self.pipe(
+            prompt=pos_prompt,
+            negative_prompt=self.neg_prompt,
+            num_inference_steps=50,
+            strength=1,
+            guidance_scale=7,
+            generator=[torch.Generator(device="cuda").manual_seed(self.seed)],
+            image=image,
+            mask_image=mask_image,
+            control_image=segmentation_cond_image,
+        ).images[0]
+
+        design_image = generated_image.resize(
+            (orig_w, orig_h), Image.Resampling.LANCZOS
+        )
+
+        return design_image
diff --git a/models/palette.py b/models/palette.py
new file mode 100644
index 0000000..c1efd32
--- /dev/null
+++ b/models/palette.py
@@ -0,0 +1,37 @@
+from typing import Dict
+from models.colors import COLOR_MAPPING_, COLOR_MAPPING_CATEGORY_
+
+
+def convert_hex_to_rgba(hex_code: str) -> str:
+    """Convert a '#RRGGBB' hex code to an 'rgba(r, g, b, 1.0)' string.
+    Args:
+        hex_code (str): hex string
+    Returns:
+        str: rgba string
+    """
+    hex_code = hex_code.lstrip('#')
+    r, g, b = (int(hex_code[i:i + 2], 16) for i in (0, 2, 4))
+    return f"rgba({r}, {g}, {b}, 1.0)"
+
+
+def convert_dict_to_rgba(color_dict: Dict) -> Dict:
+    """Convert the hex keys to rgba strings for all entries in a dictionary.
+    Args:
+        color_dict (Dict): color dictionary
+    Returns:
+        Dict: color dictionary with rgba keys
+    """
+    updated_dict = {}
+    for k, v in color_dict.items():
+        updated_dict[convert_hex_to_rgba(k)] = v
+    return updated_dict
+
+
+def convert_nested_dict_to_rgba(nested_dict):
+    """Apply convert_dict_to_rgba to every value of a nested dictionary."""
+    updated_dict = {}
+    for k, v in nested_dict.items():
+        updated_dict[k] = convert_dict_to_rgba(v)
+    return updated_dict
+
+
+COLOR_MAPPING = convert_dict_to_rgba(COLOR_MAPPING_)
+COLOR_MAPPING_CATEGORY = convert_nested_dict_to_rgba(COLOR_MAPPING_CATEGORY_)
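For reference, the hex-to-rgba conversion can be spot-checked in a REPL (values verified by hand against the mapping tables):

```python
from models.palette import convert_hex_to_rgba

print(convert_hex_to_rgba("#FF0652"))  # -> "rgba(255, 6, 82, 1.0)", the "table" color
```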
diff --git a/models/user_config.py b/models/user_config.py
index bf5de71..176ce46 100644
--- a/models/user_config.py
+++ b/models/user_config.py
@@ -1,3 +1,5 @@
 from models.random_model import RandomModel
+from models.controlnet_model import ControlNetDesignModel
 
-UserModel = RandomModel
\ No newline at end of file
+UserModel = RandomModel
+# UserModel = ControlNetDesignModel
\ No newline at end of file
diff --git a/models/utils.py b/models/utils.py
new file mode 100644
index 0000000..f6725e2
--- /dev/null
+++ b/models/utils.py
@@ -0,0 +1,83 @@
+import gc
+
+import numpy as np
+from PIL import Image
+import torch
+from scipy.signal import fftconvolve
+
+from models.palette import COLOR_MAPPING, COLOR_MAPPING_
+
+
+def to_rgb(color: str) -> tuple:
+    """Convert a '#RRGGBB' hex color to an (r, g, b) tuple.
+    Args:
+        color (str): hex color
+    Returns:
+        tuple: rgb color
+    """
+    return tuple(int(color[i:i + 2], 16) for i in (1, 3, 5))
+
+
+def map_colors(color: str) -> str:
+    """Map an 'rgba(...)' color string to its item name.
+    Args:
+        color (str): rgba color string
+    Returns:
+        str: item name
+    """
+    return COLOR_MAPPING[color]
+
+
+def map_colors_rgb(color: tuple) -> str:
+    """Map an (r, g, b) tuple to its item name."""
+    return COLOR_MAPPING_RGB[color]
+
+
+def convolution(mask: Image.Image, size=9) -> Image:
+    """Blur a mask with a normalized box filter, keeping the borders sharp.
+    Args:
+        mask (Image): masking image
+        size (int, optional): size of the blur kernel. Defaults to 9.
+    Returns:
+        Image: blurred mask
+    """
+    mask = np.array(mask.convert("L"))
+    conv = np.ones((size, size)) / size**2
+    mask_blended = fftconvolve(mask, conv, 'same')
+    mask_blended = mask_blended.astype(np.uint8).copy()
+
+    border = size
+
+    # replace borders with original values
+    mask_blended[:border, :] = mask[:border, :]
+    mask_blended[-border:, :] = mask[-border:, :]
+    mask_blended[:, :border] = mask[:, :border]
+    mask_blended[:, -border:] = mask[:, -border:]
+
+    return Image.fromarray(mask_blended).convert("L")
+
+
+def flush():
+    """Free Python and CUDA memory between pipeline runs."""
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+def postprocess_image_masking(inpainted: Image, image: Image,
+                              mask: Image) -> Image:
+    """Composite the inpainted image over the original through a mask.
+    Args:
+        inpainted (Image): inpainted image
+        image (Image): original image
+        mask (Image): mask
+    Returns:
+        Image: blended image
+    """
+    final_inpainted = Image.composite(inpainted.convert("RGBA"),
+                                      image.convert("RGBA"), mask)
+    return final_inpainted.convert("RGB")
+
+
+COLOR_NAMES = list(COLOR_MAPPING.keys())
+COLOR_RGB = [to_rgb(k) for k in COLOR_MAPPING_.keys()] + [(0, 0, 0),
+                                                          (255, 255, 255)]
+INVERSE_COLORS = {v: to_rgb(k) for k, v in COLOR_MAPPING_.items()}
+COLOR_MAPPING_RGB = {to_rgb(k): v for k, v in COLOR_MAPPING_.items()}
-- 
GitLab
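For anyone trying the baseline end-to-end, a minimal driver might look like the following. This is a sketch, not part of the patch: the image path is illustrative, a CUDA GPU plus the pretrained weights referenced above are assumed, and `UserModel` in models/user_config.py must be switched to `ControlNetDesignModel` for an actual submission.

```python
from PIL import Image

from models.controlnet_model import ControlNetDesignModel

model = ControlNetDesignModel()  # loads ControlNet + SD inpainting weights onto CUDA
room = Image.open("empty_room.jpg").convert("RGB")  # illustrative path

design = model.generate_design(room, "Scandinavian living room with a green sofa")
assert design.size == room.size  # the submission requires identical sizes
design.save("designed_room.png")
```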